mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -31,26 +31,20 @@ def _print_local_metrics(metrics: dict[str, Any]) -> None:
|
||||
print("PHANTOM_METRICS:" + json.dumps(metrics))
|
||||
|
||||
|
||||
def _should_print_local(spec: TrainSpec) -> bool:
|
||||
if not spec.runtime.use_jax:
|
||||
return True
|
||||
try:
|
||||
import jax
|
||||
|
||||
return int(jax.process_index()) == 0
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool:
|
||||
if not spec.runtime.use_jax:
|
||||
return False
|
||||
try:
|
||||
import jax
|
||||
|
||||
return int(jax.process_count()) > 1 and int(jax.process_index()) != 0
|
||||
except Exception:
|
||||
return False
|
||||
def _log_train_events(events: list[dict[str, Any]], log_freq: int) -> None:
|
||||
if not events:
|
||||
return
|
||||
period = max(1, int(log_freq))
|
||||
last_logged_step = -period
|
||||
for event in sorted(
|
||||
[evt for evt in events if isinstance(evt, dict)],
|
||||
key=lambda evt: int(evt.get("train/global_step", 0)),
|
||||
):
|
||||
step = int(event.get("train/global_step", 0))
|
||||
if step <= 0 or (step - last_logged_step) < period:
|
||||
continue
|
||||
log_metrics(event, step=step)
|
||||
last_logged_step = step
|
||||
|
||||
|
||||
def run_train_once(
|
||||
@@ -65,10 +59,9 @@ def run_train_once(
|
||||
extra_tags: Sequence[str],
|
||||
) -> dict[str, Any]:
|
||||
wandb = get_wandb_module()
|
||||
if no_wandb or wandb is None or _is_non_primary_jax_worker(spec):
|
||||
if no_wandb or wandb is None:
|
||||
result = run_train(spec)
|
||||
if _should_print_local(spec):
|
||||
_print_local_metrics(result.metrics)
|
||||
_print_local_metrics(result.metrics)
|
||||
return result.metrics
|
||||
|
||||
mode = "offline" if offline else "online"
|
||||
@@ -95,6 +88,7 @@ def run_train_once(
|
||||
|
||||
try:
|
||||
result = run_train(spec)
|
||||
_log_train_events(result.events, spec.runtime.log_freq)
|
||||
metrics = result.metrics
|
||||
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
||||
log_metrics(metrics, step=step)
|
||||
@@ -122,6 +116,7 @@ def run_with_active_sweep_run(
|
||||
)
|
||||
update_run_config({**spec.to_flat_dict(), **metadata})
|
||||
result = run_train(spec)
|
||||
_log_train_events(result.events, spec.runtime.log_freq)
|
||||
metrics = result.metrics
|
||||
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
||||
log_metrics(metrics, step=step)
|
||||
|
||||
Reference in New Issue
Block a user