mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
130 lines
3.1 KiB
Python
130 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any, Sequence
|
|
|
|
from ..spec import TrainSpec, run_metadata, run_name
|
|
from ..telemetry.wandb import (
|
|
finish_run,
|
|
get_wandb_module,
|
|
init_run,
|
|
log_metrics,
|
|
update_run_config,
|
|
update_summary,
|
|
)
|
|
from ..train_core import run_train
|
|
|
|
|
|
def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list[str]:
|
|
tags = [
|
|
kind,
|
|
spec.algorithm.name,
|
|
spec.runtime.backend,
|
|
"vanilla" if spec.study.no_robust else "robust",
|
|
]
|
|
tags.extend([tag for tag in extra_tags if tag])
|
|
return tags
|
|
|
|
|
|
def _print_local_metrics(metrics: dict[str, Any]) -> None:
|
|
print(json.dumps(metrics, indent=2))
|
|
print("PHANTOM_METRICS:" + json.dumps(metrics))
|
|
|
|
|
|
def _should_print_local(spec: TrainSpec) -> bool:
|
|
if not spec.runtime.use_jax:
|
|
return True
|
|
try:
|
|
import jax
|
|
|
|
return int(jax.process_index()) == 0
|
|
except Exception:
|
|
return True
|
|
|
|
|
|
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool:
|
|
if not spec.runtime.use_jax:
|
|
return False
|
|
try:
|
|
import jax
|
|
|
|
return int(jax.process_count()) > 1 and int(jax.process_index()) != 0
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def run_train_once(
    spec: TrainSpec,
    *,
    project: str,
    offline: bool,
    no_wandb: bool,
    kind: str,
    scenario: str,
    group: str | None,
    extra_tags: Sequence[str],
) -> dict[str, Any]:
    """Run one training job, optionally tracked as its own W&B run.

    W&B tracking is skipped when any of these holds:
      - ``no_wandb`` is set;
      - ``get_wandb_module()`` returns ``None``;
      - this process is a non-primary worker of a multi-process JAX job.

    When skipped, the job trains directly. Its metrics are printed locally if
    ``_should_print_local`` allows it.

    Otherwise a new run is initialised via ``init_run``:
      - ``mode`` is ``"offline"`` or ``"online"``;
      - ``config`` is the spec's flat config merged with run metadata;
      - ``sweep_mode`` is ``False``.

    After training, metrics are logged at ``train/global_step``, falling back
    to ``spec.runtime.total_timesteps``, and copied into the run summary.
    ``finish_run()`` is always called afterwards, even if training raises.

    Returns:
        The ``metrics`` mapping produced by ``run_train``.
    """
    wandb_module = get_wandb_module()
    skip_tracking = (
        no_wandb or wandb_module is None or _is_non_primary_jax_worker(spec)
    )
    if skip_tracking:
        local_metrics = run_train(spec).metrics
        if _should_print_local(spec):
            _print_local_metrics(local_metrics)
        return local_metrics

    run_tags = _tags_for_run(spec, kind, extra_tags)
    meta = run_metadata(
        spec,
        kind=kind,
        scenario=scenario,
        group=group,
        tags=run_tags,
    )
    run_config = spec.to_flat_dict()
    run_config.update(meta)

    init_run(
        mode="offline" if offline else "online",
        project=project,
        config=run_config,
        name=run_name(spec, kind=kind, scenario=scenario),
        tags=run_tags,
        group=group,
        sweep_mode=False,
    )

    try:
        tracked_metrics = run_train(spec).metrics
        global_step = tracked_metrics.get(
            "train/global_step", spec.runtime.total_timesteps
        )
        log_metrics(tracked_metrics, step=int(global_step))
        update_summary(tracked_metrics)
        return tracked_metrics
    finally:
        # Close the run whether training succeeded or raised.
        finish_run()
|
|
|
|
|
|
def run_with_active_sweep_run(
    spec: TrainSpec,
    *,
    kind: str,
    scenario: str,
    group: str | None,
    extra_tags: Sequence[str],
) -> dict[str, Any]:
    """Train ``spec`` and report into the W&B run that is already active.

    No run is created or finished here. The active run's config is updated
    with the spec's flat config merged with run metadata; metadata keys win on
    conflict. After training, metrics are logged at ``train/global_step``,
    falling back to ``spec.runtime.total_timesteps``, and copied into the run
    summary.

    Returns:
        The ``metrics`` mapping produced by ``run_train``.
    """
    run_tags = _tags_for_run(spec, kind, extra_tags)
    meta = run_metadata(
        spec,
        kind=kind,
        scenario=scenario,
        group=group,
        tags=run_tags,
    )
    update_run_config({**spec.to_flat_dict(), **meta})

    sweep_metrics = run_train(spec).metrics
    global_step = sweep_metrics.get("train/global_step", spec.runtime.total_timesteps)
    log_metrics(sweep_metrics, step=int(global_step))
    update_summary(sweep_metrics)
    return sweep_metrics
|