Update SB3 training engine: add eval metrics callback and W&B checkpoint artifact logging

This commit is contained in:
2026-03-15 21:14:11 +01:00
parent 19b47aa699
commit 52b4dcdce3
13 changed files with 544 additions and 160 deletions

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Mapping
from ..lib.callbacks import MetricsCallback
from ..lib.callbacks import EvalMetricsCallback, MetricsCallback
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
from .common import evaluate, make_env
@@ -117,7 +119,6 @@ def build_model(cfg: Mapping[str, Any], env: Any):
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
try:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 models") from exc
@@ -144,20 +145,20 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
pass
metrics_callback = MetricsCallback(
log_histograms=False,
log_histograms=True,
log_freq=int(cfg["log_freq"]),
hist_freq=int(cfg.get("hist_freq", 500)),
step_offset=int(cfg.get("wandb_step_offset", 0)),
)
callbacks = [metrics_callback]
callbacks.append(
EvalCallback(
eval_env,
eval_freq=int(cfg["eval_freq"]),
n_eval_episodes=int(cfg["eval_episodes"]),
deterministic=True,
verbose=0,
)
eval_callback = EvalMetricsCallback(
eval_env,
eval_freq=int(cfg["eval_freq"]),
n_eval_episodes=int(cfg["eval_episodes"]),
step_offset=int(cfg.get("wandb_step_offset", 0)),
deterministic=True,
verbose=0,
)
callbacks = [metrics_callback, eval_callback]
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
@@ -173,6 +174,29 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path))
artifact_name = checkpoint_artifact_name(
cfg,
backend="sb3",
sweep_id=os.getenv("WANDB_SWEEP_ID"),
)
artifact_logged = False
try:
artifact_logged = bool(
log_checkpoint_file(
artifact_name,
file_path=model_path.with_suffix(".zip"),
artifact_file_name="model.zip",
metadata={
"algo": str(cfg.get("algo", "ppo")),
"backend": "sb3",
"seed": int(cfg.get("seed", 0)),
"step": int(getattr(model, "num_timesteps", 0)),
},
)
)
except Exception:
artifact_logged = False
metrics: dict[str, Any] = evaluate(
model,
eval_env,
@@ -181,7 +205,12 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
)
metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip"))
metrics["_train_events"] = list(metrics_callback.events)
metrics["model/artifact_name"] = str(artifact_name)
metrics["model/artifact_logged"] = float(artifact_logged)
metrics["_train_events"] = sorted(
[*metrics_callback.events, *eval_callback.events],
key=lambda event: int(event.get("train/global_step", 0)),
)
env.close()
eval_env.close()