cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -4,9 +4,7 @@ import json
from pathlib import Path
from typing import Any, Mapping
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback
from ..telemetry.wandb import get_wandb_module
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
from ..lib.callbacks import MetricsCallback
from .common import evaluate, make_env
@@ -52,21 +50,6 @@ def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
return kwargs
def _sb3_model_cls(algo: str):
try:
from stable_baselines3 import A2C, DQN, PPO
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def build_model(cfg: Mapping[str, Any], env: Any):
try:
from stable_baselines3 import A2C, DQN, PPO
@@ -132,29 +115,7 @@ def build_model(cfg: Mapping[str, Any], env: Any):
raise ValueError(f"unsupported algo '{algo}'")
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any):
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return model
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is None:
return model
checkpoint_path, metadata = restored
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
checkpoint_path.as_posix(),
env=env,
)
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
return resumed
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
try:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
@@ -182,15 +143,10 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
except Exception:
pass
model = _maybe_resume_model(cfg, env, model)
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
callbacks.append(
CheckpointArtifactCallback(
dict(cfg),
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
metrics_callback = MetricsCallback(
log_histograms=False, log_freq=int(cfg["log_freq"])
)
callbacks = [metrics_callback]
callbacks.append(
EvalCallback(
eval_env,
@@ -215,13 +171,14 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path))
metrics: dict[str, float | int | str] = evaluate(
metrics: dict[str, Any] = evaluate(
model,
eval_env,
int(cfg["eval_episodes"]),
)
metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip"))
metrics["_train_events"] = list(metrics_callback.events)
env.close()
eval_env.close()