# File: PHANTOM/engine/backends/sb3.py
# Listing metadata captured by the file viewer: 218 lines, 6.8 KiB, Python.

from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Mapping
from ..lib.callbacks import EvalMetricsCallback, MetricsCallback
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
from .common import evaluate, make_env
def _net_arch(name: Any) -> list[int]:
presets = {
"tiny": [32, 32],
"small": [64, 64],
"medium": [128, 128],
"large": [256, 256],
}
if isinstance(name, (list, tuple)):
return [int(v) for v in name]
raw = str(name).lower().strip()
if raw in presets:
return presets[raw]
if "x" in raw:
try:
parsed = [int(v) for v in raw.split("x") if v]
return parsed if parsed else presets["small"]
except ValueError:
return presets["small"]
return presets["small"]
def _activation(name: Any):
try:
import torch.nn as nn
except ImportError:
return None
return {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"elu": nn.ELU,
"leaky_relu": nn.LeakyReLU,
}.get(str(name).lower().strip(), nn.ReLU)
def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
    """Build the ``policy_kwargs`` dict from the ``arch`` and ``activation`` settings.

    ``arch`` defaults to ``"small"`` and ``activation`` to ``"relu"``. The
    ``activation_fn`` entry is left out when ``_activation`` returns ``None``.
    """
    layer_widths = _net_arch(cfg.get("arch", "small"))
    activation_cls = _activation(cfg.get("activation", "relu"))
    if activation_cls is None:
        return {"net_arch": layer_widths}
    return {"net_arch": layer_widths, "activation_fn": activation_cls}
def build_model(cfg: Mapping[str, Any], env: Any):
    """Instantiate the Stable-Baselines3 algorithm selected by ``cfg["algo"]``.

    Supported values are ``"ppo"``, ``"a2c"`` and ``"dqn"``; each is built with
    the ``"MlpPolicy"`` policy on ``env``. All algorithms read ``seed``,
    ``learning_rate`` and ``gamma`` from ``cfg``, plus the optional ``device``
    (default ``"auto"``) and the policy settings consumed by ``_policy_kwargs``.
    The remaining hyperparameters are algorithm-specific and required.

    Raises:
        ImportError: stable-baselines3 is not installed.
        KeyError: a required configuration key is missing.
        ValueError: ``algo`` is ``"sac"`` or any other unsupported name.
    """
    try:
        from stable_baselines3 import A2C, DQN, PPO
    except ImportError as exc:
        raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc

    algo = str(cfg["algo"])

    # Arguments shared by every algorithm, evaluated in the order the
    # configuration has always been read (policy, device, seed).
    shared: dict[str, Any] = {
        "verbose": 1,
        "policy_kwargs": _policy_kwargs(cfg),
        "device": str(cfg.get("device", "auto")),
        "seed": int(cfg["seed"]),
    }

    if algo == "sac":
        raise ValueError("sac is not supported with the discrete core env")

    if algo == "ppo":
        model_cls = PPO
        tuned: dict[str, Any] = {
            "learning_rate": float(cfg["learning_rate"]),
            "n_steps": int(cfg["n_steps"]),
            "batch_size": int(cfg["batch_size"]),
            "n_epochs": int(cfg["n_epochs"]),
            "gamma": float(cfg["gamma"]),
            "gae_lambda": float(cfg["gae_lambda"]),
            "clip_range": float(cfg["clip_range"]),
            "ent_coef": float(cfg["ent_coef"]),
        }
    elif algo == "a2c":
        model_cls = A2C
        tuned = {
            "learning_rate": float(cfg["learning_rate"]),
            # A2C runs with 1/32 of the configured rollout length, but never
            # fewer than 5 steps.
            "n_steps": max(5, int(cfg["n_steps"]) // 32),
            "gamma": float(cfg["gamma"]),
            "gae_lambda": float(cfg["gae_lambda"]),
            "ent_coef": float(cfg["ent_coef"]),
        }
    elif algo == "dqn":
        model_cls = DQN
        tuned = {
            "learning_rate": float(cfg["learning_rate"]),
            "buffer_size": int(cfg["buffer_size"]),
            "batch_size": int(cfg["batch_size"]),
            "gamma": float(cfg["gamma"]),
            "train_freq": int(cfg["train_freq"]),
            "learning_starts": int(cfg["learning_starts"]),
            "target_update_interval": int(cfg["target_update_interval"]),
            "exploration_fraction": float(cfg["exploration_fraction"]),
            "exploration_final_eps": float(cfg["exploration_final_eps"]),
        }
    else:
        raise ValueError(f"unsupported algo '{algo}'")

    return model_cls("MlpPolicy", env, **shared, **tuned)
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
    """Train, checkpoint and evaluate a Stable-Baselines3 model described by ``cfg``.

    Steps, in order:

    1. Create a training and an evaluation environment with ``make_env``, each
       wrapped in SB3's ``Monitor``.
    2. Build the model with ``build_model`` and print a ``PHANTOM_DEVICE:``
       JSON line describing the torch/SB3 device (best effort).
    3. Train until ``cfg["total_timesteps"]`` is reached, with the metrics and
       evaluation callbacks attached.
    4. Save the model to ``<model_dir>/phantom_<algo>`` and hand the resulting
       ``.zip`` to ``log_checkpoint_file`` (best effort).
    5. Run ``evaluate`` and add bookkeeping entries to the returned metrics.

    Required keys read here: ``log_freq``, ``eval_freq``, ``eval_episodes``,
    ``total_timesteps``, ``model_dir``, ``algo`` (plus whatever ``build_model``
    and ``make_env`` need). Optional keys: ``device``, ``hist_freq`` (default
    500), ``wandb_step_offset`` (default 0), ``seed``.

    Returns:
        ``(model, metrics)``. ``metrics`` is the dict returned by ``evaluate``
        extended with ``train/global_step``, ``model/path``,
        ``model/artifact_name``, ``model/artifact_logged`` (1.0 or 0.0) and
        ``_train_events`` (callback events sorted by ``train/global_step``).

    Raises:
        ImportError: stable-baselines3 is not installed.

    Both environments are closed before this function returns or raises.
    """
    import logging

    try:
        from stable_baselines3.common.monitor import Monitor
    except ImportError as exc:
        raise ImportError("stable-baselines3 is required for SB3 models") from exc

    logger = logging.getLogger(__name__)

    env = Monitor(make_env(cfg))
    eval_env = None
    try:
        eval_env = Monitor(make_env(cfg))
        model = build_model(cfg, env)

        # Device diagnostics are informational only; a failure here must not
        # abort training, but it is recorded for debugging.
        try:
            import torch

            print(
                "PHANTOM_DEVICE: "
                + json.dumps(
                    {
                        "requested": str(cfg.get("device", "auto")),
                        "torch_cuda_available": bool(torch.cuda.is_available()),
                        "torch_device_count": int(torch.cuda.device_count()),
                        "sb3_device": str(getattr(model, "device", "unknown")),
                    }
                )
            )
        except Exception:
            logger.debug("could not report torch device information", exc_info=True)

        metrics_callback = MetricsCallback(
            log_histograms=True,
            log_freq=int(cfg["log_freq"]),
            hist_freq=int(cfg.get("hist_freq", 500)),
            step_offset=int(cfg.get("wandb_step_offset", 0)),
        )
        eval_callback = EvalMetricsCallback(
            eval_env,
            eval_freq=int(cfg["eval_freq"]),
            n_eval_episodes=int(cfg["eval_episodes"]),
            step_offset=int(cfg.get("wandb_step_offset", 0)),
            deterministic=True,
            verbose=0,
        )
        callbacks = [metrics_callback, eval_callback]

        # Only train for the steps the model has not already taken.
        target_steps = int(cfg["total_timesteps"])
        remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
        if remaining_steps > 0:
            model.learn(
                total_timesteps=remaining_steps,
                callback=callbacks,
                reset_num_timesteps=False,
            )

        model_dir = Path(str(cfg["model_dir"]))
        model_dir.mkdir(parents=True, exist_ok=True)
        model_path = model_dir / f"phantom_{cfg['algo']}"
        model.save(str(model_path))
        # The checkpoint is referenced by its ".zip" path from here on.
        zip_path = model_path.with_suffix(".zip")

        artifact_name = checkpoint_artifact_name(
            cfg,
            backend="sb3",
            sweep_id=os.getenv("WANDB_SWEEP_ID"),
        )
        # Logging the checkpoint is best effort: a failure is reported through
        # the logger and reflected in metrics, but never fails the run.
        artifact_logged = False
        try:
            artifact_logged = bool(
                log_checkpoint_file(
                    artifact_name,
                    file_path=zip_path,
                    artifact_file_name="model.zip",
                    metadata={
                        "algo": str(cfg.get("algo", "ppo")),
                        "backend": "sb3",
                        "seed": int(cfg.get("seed", 0)),
                        "step": int(getattr(model, "num_timesteps", 0)),
                    },
                )
            )
        except Exception:
            artifact_logged = False
            logger.warning(
                "failed to log checkpoint artifact %r from %s",
                artifact_name,
                zip_path,
                exc_info=True,
            )

        metrics: dict[str, Any] = evaluate(
            model,
            eval_env,
            int(cfg["eval_episodes"]),
            cfg=cfg,
        )
        metrics["train/global_step"] = int(model.num_timesteps)
        metrics["model/path"] = str(zip_path)
        metrics["model/artifact_name"] = str(artifact_name)
        metrics["model/artifact_logged"] = float(artifact_logged)
        metrics["_train_events"] = sorted(
            [*metrics_callback.events, *eval_callback.events],
            key=lambda event: int(event.get("train/global_step", 0)),
        )
        return model, metrics
    finally:
        # Release the environments on every exit path, including failures in
        # model construction, training, saving or evaluation.
        env.close()
        if eval_env is not None:
            eval_env.close()