mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
125 lines
3.2 KiB
Python
125 lines
3.2 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any, Sequence
|
|
|
|
from ..spec import TrainSpec, run_metadata, run_name
|
|
from ..telemetry.wandb import (
|
|
finish_run,
|
|
get_wandb_module,
|
|
init_run,
|
|
log_metrics,
|
|
update_run_config,
|
|
update_summary,
|
|
)
|
|
from ..train_core import run_train
|
|
|
|
|
|
def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list[str]:
    """Build the ordered tag list attached to a run.

    The result starts with four fixed tags -- ``kind``, the algorithm name,
    the runtime backend, and ``"baseline"`` or ``"defended"`` depending on
    ``spec.study.no_robust`` -- followed by every truthy entry of
    ``extra_tags`` in its original order.

    Args:
        spec: Training specification; ``algorithm.name``, ``runtime.backend``
            and ``study.no_robust`` are read from it.
        kind: Run kind, used verbatim as the first tag.
        extra_tags: Additional tags; falsy entries (e.g. ``""``) are dropped.

    Returns:
        A new list of tags.
    """
    fixed_tags = [kind, spec.algorithm.name, spec.runtime.backend]
    if spec.study.no_robust:
        fixed_tags.append("baseline")
    else:
        fixed_tags.append("defended")
    # filter(None, ...) keeps only truthy entries, preserving their order.
    return [*fixed_tags, *filter(None, extra_tags)]
|
|
|
|
|
|
def _print_local_metrics(metrics: dict[str, Any]) -> None:
    """Print ``metrics`` to stdout twice: pretty-printed, then on one line.

    The first output is the JSON encoding indented by two spaces. The second
    is a single line consisting of the literal prefix ``PHANTOM_METRICS:``
    followed by the compact JSON encoding.

    Args:
        metrics: JSON-serializable mapping of metric names to values.
    """
    human_readable = json.dumps(metrics, indent=2)
    print(human_readable)
    # NOTE(review): the prefixed line is presumably parsed by an external
    # consumer of stdout -- confirm before changing its format.
    tagged_line = "PHANTOM_METRICS:" + json.dumps(metrics)
    print(tagged_line)
|
|
|
|
|
|
def _log_train_events(events: list[dict[str, Any]], log_freq: int) -> None:
    """Forward training events to ``log_metrics``, throttled by step distance.

    Non-dict entries are ignored. The remaining events are processed in
    ascending order of their ``"train/global_step"`` value (missing values
    count as 0). An event is logged only when its step is positive and lies
    at least ``max(1, int(log_freq))`` steps after the previously logged one;
    the first event with a positive step is always logged.

    Args:
        events: Event payloads collected during training.
        log_freq: Minimum step gap between two logged events; values below 1
            are treated as 1.
    """
    if not events:
        return

    min_gap = max(1, int(log_freq))

    # Pair every dict event with its integer step so it is extracted once.
    stepped_events = [
        (int(payload.get("train/global_step", 0)), payload)
        for payload in events
        if isinstance(payload, dict)
    ]
    # Sort on the step only; list.sort is stable, so ties keep input order.
    stepped_events.sort(key=lambda pair: pair[0])

    # Starting one full gap "before zero" lets the first positive step through.
    previous_step = -min_gap
    for current_step, payload in stepped_events:
        if current_step > 0 and (current_step - previous_step) >= min_gap:
            log_metrics(payload, step=current_step)
            previous_step = current_step
|
|
|
|
|
|
def run_train_once(
    spec: TrainSpec,
    *,
    project: str,
    offline: bool,
    no_wandb: bool,
    kind: str,
    scenario: str,
    group: str | None,
    extra_tags: Sequence[str],
) -> dict[str, Any]:
    """Run one training job, reporting it to a tracking run when available.

    When ``no_wandb`` is set or ``get_wandb_module()`` returns ``None``, the
    job runs without tracking and its metrics are printed to stdout.
    Otherwise a run is opened via ``init_run``, throttled training events and
    the final metrics are logged, the summary is updated, and the run is
    always closed with ``finish_run`` -- even if training raises.

    Args:
        spec: Training specification passed to ``run_train``.
        project: Project name forwarded to ``init_run``.
        offline: Selects mode ``"offline"`` instead of ``"online"``.
        no_wandb: Force the local, untracked path.
        kind: Run kind used for tags, metadata and the run name.
        scenario: Scenario label used for metadata and the run name.
        group: Optional run group forwarded to metadata and ``init_run``.
        extra_tags: Additional tags; falsy entries are dropped.

    Returns:
        The metrics mapping produced by ``run_train``.
    """
    wandb_module = get_wandb_module()
    tracking_enabled = not no_wandb and wandb_module is not None

    if not tracking_enabled:
        # Local-only path: no run is opened, metrics go to stdout.
        local_result = run_train(spec)
        _print_local_metrics(local_result.metrics)
        return local_result.metrics

    run_tags = _tags_for_run(spec, kind, extra_tags)
    run_meta = run_metadata(
        spec,
        kind=kind,
        scenario=scenario,
        group=group,
        tags=run_tags,
    )
    # Metadata keys override same-named keys of the flattened spec.
    run_config = spec.to_flat_dict()
    run_config.update(run_meta)
    display_name = run_name(spec, kind=kind, scenario=scenario)

    init_run(
        mode="offline" if offline else "online",
        project=project,
        config=run_config,
        name=display_name,
        tags=run_tags,
        group=group,
        sweep_mode=False,
    )

    try:
        outcome = run_train(spec)
        _log_train_events(outcome.events, spec.runtime.log_freq)
        final_metrics = outcome.metrics
        # Log final metrics at the reported step, else at the step budget.
        final_step = int(final_metrics.get("train/global_step", spec.runtime.total_timesteps))
        log_metrics(final_metrics, step=final_step)
        update_summary(final_metrics)
        return final_metrics
    finally:
        finish_run()
|
|
|
|
|
|
def run_with_active_sweep_run(
    spec: TrainSpec,
    *,
    kind: str,
    scenario: str,
    group: str | None,
    extra_tags: Sequence[str],
) -> dict[str, Any]:
    """Run one training job and report it without opening or closing a run.

    This function calls neither ``init_run`` nor ``finish_run``; it only
    updates the run config, logs throttled training events and the final
    metrics, and updates the summary.
    NOTE(review): presumably the caller (e.g. a sweep agent) owns the run
    lifecycle -- confirm against the call site.

    Args:
        spec: Training specification passed to ``run_train``.
        kind: Run kind used for tags and metadata.
        scenario: Scenario label used for metadata.
        group: Optional run group forwarded to metadata.
        extra_tags: Additional tags; falsy entries are dropped.

    Returns:
        The metrics mapping produced by ``run_train``.
    """
    run_tags = _tags_for_run(spec, kind, extra_tags)
    run_meta = run_metadata(
        spec,
        kind=kind,
        scenario=scenario,
        group=group,
        tags=run_tags,
    )
    # Metadata keys override same-named keys of the flattened spec.
    flat_spec = spec.to_flat_dict()
    update_run_config({**flat_spec, **run_meta})

    outcome = run_train(spec)
    _log_train_events(outcome.events, spec.runtime.log_freq)

    final_metrics = outcome.metrics
    # Log final metrics at the reported step, else at the step budget.
    final_step = int(final_metrics.get("train/global_step", spec.runtime.total_timesteps))
    log_metrics(final_metrics, step=final_step)
    update_summary(final_metrics)
    return final_metrics
|