mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
fixing models for gcp
This commit is contained in:
@@ -6,6 +6,8 @@ import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
@@ -78,6 +80,7 @@ DEFAULT_CFG = {
|
||||
"jax_num_minibatches": 4,
|
||||
"jax_update_epochs": 4,
|
||||
"jax_anneal_lr": True,
|
||||
"checkpoint_interval": 10_000,
|
||||
}
|
||||
|
||||
|
||||
@@ -262,6 +265,16 @@ def build_model(cfg: dict, env):
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def _sb3_model_cls(algo: str):
|
||||
if algo == "ppo":
|
||||
return PPO
|
||||
if algo == "a2c":
|
||||
return A2C
|
||||
if algo == "dqn":
|
||||
return DQN
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||
from .lib.discrete import EventQTable
|
||||
|
||||
@@ -305,14 +318,36 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
if not HAS_SB3:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models")
|
||||
from .lib.callbacks import MetricsCallback
|
||||
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
||||
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
env = Monitor(env)
|
||||
eval_env = Monitor(eval_env)
|
||||
model = build_model(cfg, env)
|
||||
resume_step = 0
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
||||
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
||||
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
||||
if restored is not None:
|
||||
checkpoint_path, metadata = restored
|
||||
model = _sb3_model_cls(cfg["algo"]).load(
|
||||
checkpoint_path.as_posix(), env=env
|
||||
)
|
||||
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
|
||||
model.num_timesteps = max(
|
||||
int(getattr(model, "num_timesteps", 0)), resume_step
|
||||
)
|
||||
|
||||
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
|
||||
cbs.append(
|
||||
CheckpointArtifactCallback(
|
||||
cfg,
|
||||
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
||||
)
|
||||
)
|
||||
cbs.append(
|
||||
EvalCallback(
|
||||
eval_env,
|
||||
@@ -322,7 +357,15 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
verbose=0,
|
||||
)
|
||||
)
|
||||
model.learn(total_timesteps=int(cfg["total_timesteps"]), callback=cbs)
|
||||
target_steps = int(cfg["total_timesteps"])
|
||||
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||
if remaining_steps > 0:
|
||||
model.learn(
|
||||
total_timesteps=remaining_steps,
|
||||
callback=cbs,
|
||||
reset_num_timesteps=False,
|
||||
)
|
||||
|
||||
model_path = Path(cfg["model_dir"])
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
model.save(str(model_path / f"phantom_{cfg['algo']}"))
|
||||
@@ -413,6 +456,7 @@ def main():
|
||||
p.add_argument("--jax-num-minibatches", type=int)
|
||||
p.add_argument("--jax-update-epochs", type=int)
|
||||
p.add_argument("--jax-anneal-lr", type=str)
|
||||
p.add_argument("--checkpoint-interval", type=int)
|
||||
p.add_argument("--sweep-agent", action="store_true")
|
||||
p.add_argument("--sweep-id", type=str)
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
@@ -441,6 +485,7 @@ def main():
|
||||
"jax_num_steps": args.jax_num_steps,
|
||||
"jax_num_minibatches": args.jax_num_minibatches,
|
||||
"jax_update_epochs": args.jax_update_epochs,
|
||||
"checkpoint_interval": args.checkpoint_interval,
|
||||
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
||||
if args.jax_anneal_lr is not None
|
||||
else None,
|
||||
|
||||
Reference in New Issue
Block a user