cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -31,26 +31,20 @@ def _print_local_metrics(metrics: dict[str, Any]) -> None:
print("PHANTOM_METRICS:" + json.dumps(metrics))
def _should_print_local(spec: TrainSpec) -> bool:
    """Decide whether this process should print metrics locally.

    Returns ``True`` when ``spec.runtime.use_jax`` is falsy, when
    ``jax.process_index()`` converts to ``0``, or when importing or querying
    JAX raises any ``Exception``. Returns ``False`` only when JAX reports a
    non-zero process index.
    """
    if spec.runtime.use_jax:
        try:
            import jax

            process_rank = int(jax.process_index())
        except Exception:
            # Best effort: if JAX cannot be imported or queried, print anyway.
            return True
        return process_rank == 0
    return True
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool:
    """Report whether this process is a secondary worker of a multi-process JAX run.

    Returns ``True`` only when ``spec.runtime.use_jax`` is truthy,
    ``jax.process_count()`` is greater than 1 and ``jax.process_index()`` is
    not 0. Returns ``False`` in every other case, including when importing or
    querying JAX raises any ``Exception``.
    """
    if spec.runtime.use_jax:
        try:
            import jax

            # The process index is only queried when there is more than one
            # process, mirroring a short-circuiting ``and``.
            if int(jax.process_count()) <= 1:
                return False
            return int(jax.process_index()) != 0
        except Exception:
            # Best effort: an unusable JAX is treated as "not a secondary worker".
            return False
    return False
def _log_train_events(events: list[dict[str, Any]], log_freq: int) -> None:
    """Forward training events to ``log_metrics``, thinned to one per ``log_freq`` steps.

    Entries that are not dicts are ignored. The remaining events are visited
    in ascending order of their ``"train/global_step"`` value (a missing key
    counts as 0). An event is forwarded only when its step is positive and at
    least ``max(1, int(log_freq))`` steps past the previously forwarded event;
    the first positive-step event always qualifies.

    Args:
        events: Recorded training events; a falsy value makes this a no-op.
        log_freq: Minimum step distance between forwarded events; values
            below 1 are treated as 1.
    """
    if not events:
        return

    def _step_of(entry: dict[str, Any]) -> int:
        return int(entry.get("train/global_step", 0))

    min_gap = max(1, int(log_freq))
    dict_events = [entry for entry in events if isinstance(entry, dict)]
    dict_events.sort(key=_step_of)

    # Starting one full gap before step 0 lets the first positive step through.
    previous_step = -min_gap
    for entry in dict_events:
        current_step = _step_of(entry)
        if current_step > 0 and (current_step - previous_step) >= min_gap:
            log_metrics(entry, step=current_step)
            previous_step = current_step
def run_train_once(
@@ -65,10 +59,9 @@ def run_train_once(
extra_tags: Sequence[str],
) -> dict[str, Any]:
wandb = get_wandb_module()
if no_wandb or wandb is None or _is_non_primary_jax_worker(spec):
if no_wandb or wandb is None:
result = run_train(spec)
if _should_print_local(spec):
_print_local_metrics(result.metrics)
_print_local_metrics(result.metrics)
return result.metrics
mode = "offline" if offline else "online"
@@ -95,6 +88,7 @@ def run_train_once(
try:
result = run_train(spec)
_log_train_events(result.events, spec.runtime.log_freq)
metrics = result.metrics
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
log_metrics(metrics, step=step)
@@ -122,6 +116,7 @@ def run_with_active_sweep_run(
)
update_run_config({**spec.to_flat_dict(), **metadata})
result = run_train(spec)
_log_train_events(result.events, spec.runtime.log_freq)
metrics = result.metrics
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
log_metrics(metrics, step=step)