fixing models for gcp

This commit is contained in:
2026-02-17 16:54:55 +01:00
parent 802f31b4a1
commit 9acc998cc9
5 changed files with 497 additions and 193 deletions

View File

@@ -6,6 +6,8 @@ import os
from pathlib import Path
import numpy as np
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
try:
import wandb
@@ -78,6 +80,7 @@ DEFAULT_CFG = {
"jax_num_minibatches": 4,
"jax_update_epochs": 4,
"jax_anneal_lr": True,
"checkpoint_interval": 10_000,
}
@@ -262,6 +265,16 @@ def build_model(cfg: dict, env):
raise ValueError(f"unsupported algo '{algo}'")
def _sb3_model_cls(algo: str):
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
from .lib.discrete import EventQTable
@@ -305,14 +318,36 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
def train_sb3(cfg: dict) -> tuple[object, dict]:
if not HAS_SB3:
raise ImportError("stable-baselines3 is required for SB3 models")
from .lib.callbacks import MetricsCallback
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
env = make_env(cfg)
eval_env = make_env(cfg)
env = Monitor(env)
eval_env = Monitor(eval_env)
model = build_model(cfg, env)
resume_step = 0
if HAS_WANDB and wandb.run is not None:
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is not None:
checkpoint_path, metadata = restored
model = _sb3_model_cls(cfg["algo"]).load(
checkpoint_path.as_posix(), env=env
)
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
model.num_timesteps = max(
int(getattr(model, "num_timesteps", 0)), resume_step
)
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
cbs.append(
CheckpointArtifactCallback(
cfg,
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
)
cbs.append(
EvalCallback(
eval_env,
@@ -322,7 +357,15 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
verbose=0,
)
)
model.learn(total_timesteps=int(cfg["total_timesteps"]), callback=cbs)
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
if remaining_steps > 0:
model.learn(
total_timesteps=remaining_steps,
callback=cbs,
reset_num_timesteps=False,
)
model_path = Path(cfg["model_dir"])
model_path.mkdir(parents=True, exist_ok=True)
model.save(str(model_path / f"phantom_{cfg['algo']}"))
@@ -413,6 +456,7 @@ def main():
p.add_argument("--jax-num-minibatches", type=int)
p.add_argument("--jax-update-epochs", type=int)
p.add_argument("--jax-anneal-lr", type=str)
p.add_argument("--checkpoint-interval", type=int)
p.add_argument("--sweep-agent", action="store_true")
p.add_argument("--sweep-id", type=str)
p.add_argument("--count", type=int, default=0)
@@ -441,6 +485,7 @@ def main():
"jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs,
"checkpoint_interval": args.checkpoint_interval,
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
if args.jax_anneal_lr is not None
else None,