# Files
# PHANTOM/engine/train.py
#
# 522 lines
# 16 KiB
# Python

from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
import numpy as np
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
try:
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
HAS_SB3 = True
except ImportError:
HAS_SB3 = False
from .jax import JAX_AVAILABLE
# Baseline settings for every training run. `_cfg` copies this dict and then
# overlays caller-supplied (CLI / W&B) values that are not None.
DEFAULT_CFG = {
    # --- run selection and bookkeeping ---
    "project": "phantom-pricing",  # default for the --project CLI flag
    "algo": "ppo",  # this file routes "ppo", "a2c", "dqn" and "qtable"
    "seed": 42,
    "total_timesteps": 50_000,
    "eval_episodes": 5,
    "eval_freq": 1_000,  # passed to SB3's EvalCallback
    "log_freq": 100,  # W&B logging period (in steps) during training
    "revenue_weight": 0.01,  # weight of eval/revenue inside sweep/score
    # --- environment parameters, forwarded to PHANTOM(...) by make_env ---
    "n_products": 10,
    "N": 100,
    "alpha": 0.3,
    "lambda_coi": 0.2,
    "robust_radius": 0.15,
    "robust_points": 5,
    "info_value": 1.0,
    "price_low": 10.0,  # lower end of the price_bounds tuple
    "price_high": 150.0,  # upper end of the price_bounds tuple
    "action_levels": 9,
    "action_scale_low": 0.8,
    "action_scale_high": 1.2,
    # --- optimiser settings shared by the SB3 models ---
    "learning_rate": 3e-4,
    "gamma": 0.99,  # also used by the Q-table agent
    # --- DQN-specific (batch_size is also passed to PPO) ---
    "buffer_size": 50_000,
    "batch_size": 256,
    "tau": 0.005,  # NOTE(review): not read anywhere in this file
    "train_freq": 1,
    "learning_starts": 1_000,
    "target_update_interval": 1_000,
    "exploration_fraction": 0.2,
    "exploration_final_eps": 0.05,
    # --- PPO / A2C (A2C uses max(5, n_steps // 32) rollout steps) ---
    "n_steps": 2_048,
    "n_epochs": 10,
    "gae_lambda": 0.95,
    "clip_range": 0.2,
    "ent_coef": 0.0,
    # --- Q-table agent: learning rate and exploration schedule ---
    "q_lr": 0.1,
    "eps_start": 1.0,
    "eps_end": 0.05,  # floor for the decayed epsilon
    "eps_decay": 0.9995,  # multiplicative decay applied every step
    # --- model output and network shape ---
    "model_dir": "engine/models",  # directory train_sb3 saves the model into
    "arch": "small",  # preset name, "AxB" string, or list (see _net_arch)
    "activation": "relu",  # see _activation for the accepted names
    "q_bins": 6,  # passed to EventQTable as n_bins
    # --- more environment parameters forwarded by make_env ---
    "max_steps": 100,
    "margin_floor": 0.05,
    "margin_floor_patience": 5,
    # --- JAX backend; the jax_* keys are not read in this file and are
    # presumably consumed by the JAX trainer -- TODO confirm ---
    "use_jax": False,  # can also be switched on via PHANTOM_USE_JAX
    "jax_num_envs": 16,
    "jax_num_steps": 128,
    "jax_num_minibatches": 4,
    "jax_update_epochs": 4,
    "jax_anneal_lr": True,
    # --- checkpointing: interval passed to CheckpointArtifactCallback ---
    "checkpoint_interval": 10_000,
}
def _truthy(value: str | bool | None) -> bool:
    """Interpret a flag-like value as a boolean.

    ``None`` is false and real booleans are returned unchanged. Anything else
    is converted with ``str()``, stripped and lower-cased, and counts as true
    only when it is one of ``"1"``, ``"true"``, ``"yes"`` or ``"on"``.
    """
    if value is None:
        return False
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _cfg(raw: dict | None = None) -> dict:
    """Build the effective run configuration.

    Starts from a copy of ``DEFAULT_CFG`` and overlays every entry of *raw*
    whose value is not ``None``. The ``algo`` entry is lower-cased, and
    ``use_jax`` ends up true when either the merged config value or the
    ``PHANTOM_USE_JAX`` environment variable is truthy (see ``_truthy``).
    """
    merged = {**DEFAULT_CFG}
    for key, value in (raw or {}).items():
        if value is not None:
            merged[key] = value

    merged["algo"] = str(merged["algo"]).lower()

    # The environment variable can only switch JAX on, never off.
    jax_requested = _truthy(merged.get("use_jax"))
    merged["use_jax"] = jax_requested or _truthy(
        os.environ.get("PHANTOM_USE_JAX")
    )
    return merged
def _wandb_cfg_dict() -> dict:
    """Return the active W&B run's config as a plain dict.

    Yields an empty dict when wandb is not installed or no run is active.
    """
    if not (HAS_WANDB and wandb.run):
        return {}
    snapshot = {}
    for key in wandb.config.keys():
        snapshot[key] = wandb.config[key]
    return snapshot
def make_env(cfg: dict):
    """Construct the PHANTOM environment described by *cfg*.

    The raw environment is wrapped first in ``EconomicMetricsWrapper`` and
    then in gymnasium's ``FlattenObservation``; the outermost wrapper is
    returned. ``max_steps``, ``margin_floor`` and ``margin_floor_patience``
    fall back to 100, 0.05 and 5 when absent; every other key is required.
    """
    from gymnasium.wrappers import FlattenObservation
    from .wrapper import PHANTOM
    from .lib.wrappers import EconomicMetricsWrapper

    # Coerce each setting to the type the constructor is called with.
    env_kwargs = {
        "n_products": int(cfg["n_products"]),
        "alpha": float(cfg["alpha"]),
        "N": int(cfg["N"]),
        "price_bounds": (float(cfg["price_low"]), float(cfg["price_high"])),
        "lambda_coi": float(cfg["lambda_coi"]),
        "robust_radius": float(cfg["robust_radius"]),
        "robust_points": int(cfg["robust_points"]),
        "info_value": float(cfg["info_value"]),
        "action_levels": int(cfg["action_levels"]),
        "action_scale_low": float(cfg["action_scale_low"]),
        "action_scale_high": float(cfg["action_scale_high"]),
        "max_steps": int(cfg.get("max_steps", 100)),
        "margin_floor": float(cfg.get("margin_floor", 0.05)),
        "margin_floor_patience": int(cfg.get("margin_floor_patience", 5)),
        "render_mode": None,
    }
    base_env = PHANTOM(**env_kwargs)
    return FlattenObservation(EconomicMetricsWrapper(base_env))
def _net_arch(name) -> list[int]:
    """Translate an architecture spec into a list of hidden-layer widths.

    Accepted forms:
      * a list/tuple of widths, each converted with ``int``;
      * a preset name: ``tiny``, ``small``, ``medium`` or ``large``
        (case and surrounding whitespace are ignored);
      * an ``"AxB[xC...]"`` string such as ``"128x64"``.

    Anything unrecognised or unparsable falls back to the ``small`` preset.
    """
    if isinstance(name, (list, tuple)):
        return [int(width) for width in name]

    presets = {
        "tiny": [32, 32],
        "small": [64, 64],
        "medium": [128, 128],
        "large": [256, 256],
    }
    fallback = presets["small"]

    key = str(name).lower().strip()
    preset = presets.get(key)
    if preset is not None:
        return preset
    if "x" not in key:
        return fallback

    try:
        # Empty fragments (e.g. from "64x" or "x") are skipped.
        widths = [int(part) for part in key.split("x") if part]
    except ValueError:
        return fallback
    return widths or fallback
def _activation(name):
    """Map an activation name to the matching ``torch.nn`` module class.

    Recognised names (case and surrounding whitespace ignored) are ``relu``,
    ``tanh``, ``elu`` and ``leaky_relu``; any other name yields ``nn.ReLU``.
    Returns ``None`` when torch cannot be imported.
    """
    try:
        import torch.nn as nn
    except ImportError:
        return None

    table = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "elu": nn.ELU,
        "leaky_relu": nn.LeakyReLU,
    }
    key = str(name).lower().strip()
    if key in table:
        return table[key]
    return nn.ReLU
def _policy_kwargs(cfg: dict) -> dict:
    """Assemble the ``policy_kwargs`` dict handed to the SB3 model.

    Always contains ``net_arch`` (from ``cfg["arch"]``, default ``"small"``).
    ``activation_fn`` is added only when ``_activation`` resolves a class,
    i.e. it is omitted when ``_activation`` returns ``None``.
    """
    kwargs = {"net_arch": _net_arch(cfg.get("arch", "small"))}
    activation_cls = _activation(cfg.get("activation", "relu"))
    if activation_cls is None:
        return kwargs
    kwargs["activation_fn"] = activation_cls
    return kwargs
def _action(agent, obs, deterministic: bool = True):
    """Query *agent* for an action and normalise its shape.

    ``agent.predict`` may return either the action itself or a tuple whose
    first element is the action. A NumPy array holding exactly one element is
    unwrapped to a plain ``int``; everything else is returned untouched.
    """
    prediction = agent.predict(obs, deterministic=deterministic)
    if isinstance(prediction, tuple):
        prediction = prediction[0]

    single_element = isinstance(prediction, np.ndarray) and prediction.size == 1
    if not single_element:
        return prediction
    return int(prediction.reshape(-1)[0])
def evaluate(agent, env, episodes: int) -> dict:
    """Roll out *agent* deterministically and summarise the results.

    Runs ``int(episodes)`` full episodes on *env*, each until the step
    reports termination or truncation. Per-step revenue is read from
    ``info["economics"]["revenue"]``, falling back to ``info["revenue"]`` and
    then to 0.0.

    Returns:
        A dict with the mean and standard deviation of per-episode reward and
        revenue under the keys ``eval/reward``, ``eval/revenue``,
        ``eval/reward_std`` and ``eval/revenue_std``.
    """
    episode_rewards = []
    episode_revenues = []
    for _ in range(int(episodes)):
        obs, _ = env.reset()
        reward_total = 0.0
        revenue_total = 0.0
        while True:
            action = _action(agent, obs, True)
            obs, reward, terminated, truncated, info = env.step(action)
            reward_total += float(reward)
            economics = info.get("economics", {})
            revenue_total += float(
                economics.get("revenue", info.get("revenue", 0.0))
            )
            if terminated or truncated:
                break
        episode_rewards.append(reward_total)
        episode_revenues.append(revenue_total)

    summary = {}
    summary["eval/reward"] = float(np.mean(episode_rewards))
    summary["eval/revenue"] = float(np.mean(episode_revenues))
    summary["eval/reward_std"] = float(np.std(episode_rewards))
    summary["eval/revenue_std"] = float(np.std(episode_revenues))
    return summary
def build_model(cfg: dict, env):
    """Instantiate the SB3 model selected by ``cfg["algo"]`` on *env*.

    Supports ``ppo``, ``a2c`` and ``dqn``, all with ``"MlpPolicy"`` and the
    policy kwargs derived from *cfg*.

    Raises:
        ValueError: for ``sac`` (rejected with a dedicated message) and for
            any other unrecognised algorithm name.
    """
    algo = cfg["algo"]
    policy_kwargs = _policy_kwargs(cfg)

    if algo == "sac":
        raise ValueError("sac is not supported with the discrete core env")
    if algo not in ("ppo", "a2c", "dqn"):
        raise ValueError(f"unsupported algo '{algo}'")

    # Arguments every supported algorithm receives.
    shared = {
        "verbose": 1,
        "policy_kwargs": policy_kwargs,
        "seed": int(cfg["seed"]),
        "learning_rate": float(cfg["learning_rate"]),
        "gamma": float(cfg["gamma"]),
    }

    if algo == "ppo":
        specific = {
            "n_steps": int(cfg["n_steps"]),
            "batch_size": int(cfg["batch_size"]),
            "n_epochs": int(cfg["n_epochs"]),
            "gae_lambda": float(cfg["gae_lambda"]),
            "clip_range": float(cfg["clip_range"]),
            "ent_coef": float(cfg["ent_coef"]),
        }
        return PPO("MlpPolicy", env, **shared, **specific)

    if algo == "a2c":
        specific = {
            # A2C uses a much shorter rollout than PPO: 1/32 of n_steps,
            # but never fewer than 5 steps.
            "n_steps": max(5, int(cfg["n_steps"]) // 32),
            "gae_lambda": float(cfg["gae_lambda"]),
            "ent_coef": float(cfg["ent_coef"]),
        }
        return A2C("MlpPolicy", env, **shared, **specific)

    # Only "dqn" remains at this point.
    specific = {
        "buffer_size": int(cfg["buffer_size"]),
        "batch_size": int(cfg["batch_size"]),
        "train_freq": int(cfg["train_freq"]),
        "learning_starts": int(cfg["learning_starts"]),
        "target_update_interval": int(cfg["target_update_interval"]),
        "exploration_fraction": float(cfg["exploration_fraction"]),
        "exploration_final_eps": float(cfg["exploration_final_eps"]),
    }
    return DQN("MlpPolicy", env, **shared, **specific)
def _sb3_model_cls(algo: str):
    """Return the SB3 class (``PPO``, ``A2C`` or ``DQN``) for *algo*.

    Raises:
        ValueError: when *algo* is not ``"ppo"``, ``"a2c"`` or ``"dqn"``.
    """
    if algo not in ("ppo", "a2c", "dqn"):
        raise ValueError(f"unsupported algo '{algo}'")
    # The classes are only looked up once the name is known to be supported.
    return {"ppo": PPO, "a2c": A2C, "dqn": DQN}[algo]
def train_qtable(cfg: dict) -> tuple[object, dict]:
    """Train an ``EventQTable`` agent and evaluate it.

    Two environments are built from *cfg*: one for training and one for the
    final evaluation. Both are closed when the function exits, including when
    training or evaluation raises. Exploration epsilon starts at
    ``eps_start`` and is multiplied by ``eps_decay`` after every step, never
    dropping below ``eps_end``. When a W&B run is active, the step reward,
    revenue and epsilon are logged every ``log_freq`` steps.

    Args:
        cfg: Effective configuration (see ``_cfg``).

    Returns:
        ``(agent, metrics)`` where *metrics* is the output of ``evaluate``
        plus ``train/global_step`` set to ``total_timesteps``.
    """
    from contextlib import closing

    from .lib.discrete import EventQTable

    seed = int(cfg["seed"])
    np.random.seed(seed)

    # `closing` guarantees env.close() even if the loop or evaluation fails,
    # which matters when many runs share one process (e.g. a sweep agent).
    with closing(make_env(cfg)) as env, closing(make_env(cfg)) as eval_env:
        agent = EventQTable(
            env.action_space.n,
            int(cfg["n_products"]),
            (float(cfg["price_low"]), float(cfg["price_high"])),
            lr=float(cfg["q_lr"]),
            gamma=float(cfg["gamma"]),
            n_bins=int(cfg["q_bins"]),
        )

        # Parse loop-invariant settings once instead of on every step.
        eps = float(cfg["eps_start"])
        eps_end = float(cfg["eps_end"])
        eps_decay = float(cfg["eps_decay"])
        total_steps = int(cfg["total_timesteps"])

        # Only the first reset is seeded; later resets continue the stream.
        obs, _ = env.reset(seed=seed)
        for t in range(total_steps):
            a, s = agent.act(obs, eps)
            nxt, reward, term, trunc, info = env.step(a)
            done = term or trunc
            agent.update(s, a, float(reward), agent.encode(nxt), done)
            eps = max(eps_end, eps * eps_decay)

            if HAS_WANDB and wandb.run and (t + 1) % int(cfg["log_freq"]) == 0:
                econ = info.get("economics", {})
                wandb.log(
                    {
                        "train/reward": float(reward),
                        "train/revenue": float(econ.get("revenue", 0.0)),
                        "train/epsilon": float(eps),
                    },
                    step=t + 1,
                )

            # Start a fresh episode once the current one has ended.
            obs = env.reset()[0] if done else nxt

        metrics = evaluate(agent, eval_env, int(cfg["eval_episodes"]))
        metrics["train/global_step"] = total_steps
        return agent, metrics
def train_sb3(cfg: dict) -> tuple[object, dict]:
    """Train (or resume) a Stable-Baselines3 model and evaluate it.

    Steps:
      1. Build Monitor-wrapped training and evaluation environments.
      2. Build a fresh model. If a W&B run is active and a checkpoint
         artifact can be downloaded, load that checkpoint instead and
         advance ``num_timesteps`` to the recorded step.
      3. Train for the timesteps still missing to reach ``total_timesteps``
         with the metrics, checkpoint-artifact and evaluation callbacks.
      4. Save the model under ``cfg["model_dir"]`` and run ``evaluate``.

    Both environments are closed when the function exits, including when any
    of the steps above raises.

    Args:
        cfg: Effective configuration (see ``_cfg``).

    Returns:
        ``(model, metrics)`` where *metrics* is the output of ``evaluate``
        plus ``train/global_step`` set to the model's ``num_timesteps``.

    Raises:
        ImportError: if stable-baselines3 is not installed.
    """
    if not HAS_SB3:
        raise ImportError("stable-baselines3 is required for SB3 models")

    from contextlib import closing

    from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback

    # `closing` guarantees close() on every exit path, so a failed run does
    # not leak its environments into the next one.
    with closing(Monitor(make_env(cfg))) as env, \
            closing(Monitor(make_env(cfg))) as eval_env:
        model = build_model(cfg, env)

        resume_step = 0
        if HAS_WANDB and wandb.run is not None:
            sweep_id = getattr(wandb.run, "sweep_id", None)
            artifact_name = checkpoint_artifact_name(
                cfg, backend="sb3", sweep_id=sweep_id
            )
            checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
            restored = download_latest_checkpoint(
                artifact_name, file_name=checkpoint_file
            )
            if restored is not None:
                checkpoint_path, metadata = restored
                model = _sb3_model_cls(cfg["algo"]).load(
                    checkpoint_path.as_posix(), env=env
                )
                # Prefer the step stored with the artifact, but never move
                # the model's own counter backwards.
                resume_step = int(
                    metadata.get("step", getattr(model, "num_timesteps", 0))
                )
                model.num_timesteps = max(
                    int(getattr(model, "num_timesteps", 0)), resume_step
                )

        cbs = [
            MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"])),
            CheckpointArtifactCallback(
                cfg,
                interval=int(cfg.get("checkpoint_interval", 10_000)),
            ),
            EvalCallback(
                eval_env,
                eval_freq=int(cfg["eval_freq"]),
                n_eval_episodes=int(cfg["eval_episodes"]),
                deterministic=True,
                verbose=0,
            ),
        ]

        # Only train for what is left after a resume. reset_num_timesteps
        # is False so the step counter keeps counting from the checkpoint.
        target_steps = int(cfg["total_timesteps"])
        remaining_steps = max(
            0, target_steps - int(getattr(model, "num_timesteps", 0))
        )
        if remaining_steps > 0:
            model.learn(
                total_timesteps=remaining_steps,
                callback=cbs,
                reset_num_timesteps=False,
            )

        model_path = Path(cfg["model_dir"])
        model_path.mkdir(parents=True, exist_ok=True)
        model.save(str(model_path / f"phantom_{cfg['algo']}"))

        metrics = evaluate(model, eval_env, int(cfg["eval_episodes"]))
        metrics["train/global_step"] = int(model.num_timesteps)
        return model, metrics
def train_once(cfg: dict) -> dict:
    """Run one training job with the backend selected by *cfg*.

    ``use_jax`` selects the JAX trainer, ``algo == "qtable"`` the Q-table
    trainer, and everything else the SB3 trainer. The returned metrics gain a
    ``sweep/score`` entry: ``eval/reward + revenue_weight * eval/revenue``.

    Raises:
        ImportError: when the JAX backend is requested but JAX is missing or
            its trainer module fails to import.
    """
    algo = cfg["algo"]

    if cfg.get("use_jax"):
        if not JAX_AVAILABLE:
            raise ImportError(
                "JAX backend requested but JAX is not installed. "
                "Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
            )
        try:
            from .jax.train import train_jax
        except Exception as exc:  # pragma: no cover
            raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
        trainer = train_jax
    else:
        trainer = train_qtable if algo == "qtable" else train_sb3

    _trained, metrics = trainer(cfg)

    reward_part = metrics["eval/reward"]
    revenue_part = float(cfg["revenue_weight"]) * metrics["eval/revenue"]
    metrics["sweep/score"] = float(reward_part + revenue_part)
    return metrics
def run_wandb(
    project: str, overrides: dict, mode: str = "online", sweep_mode: bool = False
) -> dict:
    """Run one training job inside a W&B run and log its final metrics.

    In normal mode the run is initialised with *project* and *overrides* as
    its config. In sweep mode ``wandb.init`` is called without them, and
    *overrides* only fill in keys that the W&B config does not already
    contain.

    In both modes the combined values pass through ``_cfg``, so overrides
    get the same normalisation as every other input: ``None`` values are
    ignored, ``algo`` is lower-cased, and ``PHANTOM_USE_JAX`` can still
    enable the JAX backend.

    Args:
        project: W&B project name (ignored in sweep mode).
        overrides: User-supplied config values.
        mode: W&B mode passed to ``wandb.init``.
        sweep_mode: Whether this call is one trial of a sweep agent.

    Returns:
        The metrics dict produced by ``train_once``.

    Raises:
        ImportError: if wandb is not installed.
    """
    if not HAS_WANDB:
        raise ImportError("wandb is required for sweep runs")

    init_kwargs = {"mode": mode}
    if sweep_mode:
        run = wandb.init(**init_kwargs)
    else:
        run = wandb.init(project=project, config=overrides, **init_kwargs)

    try:
        raw = _wandb_cfg_dict()
        if sweep_mode:
            # W&B config values win; overrides only supply missing keys.
            # Merging before _cfg (not after) keeps them normalised.
            for key, value in overrides.items():
                raw.setdefault(key, value)
        cfg = _cfg(raw)

        metrics = train_once(cfg)
        step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
        wandb.log(metrics, step=step)
        for k, v in metrics.items():
            run.summary[k] = v
        return metrics
    finally:
        # Always end the run so the next sweep trial starts clean.
        if wandb.run is not None:
            wandb.finish()
def run_local(overrides: dict) -> dict:
    """Train once without W&B and print the metrics as indented JSON.

    Returns the same metrics dict that was printed.
    """
    metrics = train_once(_cfg(overrides))
    report = json.dumps(metrics, indent=2)
    print(report)
    return metrics
def main():
    """Command-line entry point for PHANTOM training and W&B sweeps.

    Parses the flags, gathers the ones the user actually set into an
    overrides dict, and dispatches to one of three paths:

      * ``--sweep-agent``: run trials of an existing sweep via ``wandb.agent``;
      * ``--no-wandb`` (or wandb not installed): a single local run;
      * otherwise: a single W&B-tracked run (offline with ``--offline``).

    Raises:
        ValueError: ``--sweep-agent`` combined with ``--no-wandb``, or given
            without ``--sweep-id``.
        ImportError: ``--sweep-agent`` requested but wandb is not installed.
    """
    p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
    p.add_argument("--project", default=DEFAULT_CFG["project"])
    p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
    p.add_argument("--total-timesteps", type=int)
    p.add_argument("--alpha", type=float)
    p.add_argument("--n-products", type=int)
    p.add_argument("--lambda-coi", type=float)
    p.add_argument("--robust-radius", type=float)
    p.add_argument("--robust-points", type=int)
    p.add_argument("--learning-rate", type=float)
    p.add_argument("--gamma", type=float)
    p.add_argument("--revenue-weight", type=float)
    p.add_argument("--max-steps", type=int)
    p.add_argument("--margin-floor", type=float)
    p.add_argument("--margin-floor-patience", type=int)
    p.add_argument("--arch", type=str)
    p.add_argument("--activation", type=str)
    p.add_argument("--jax", action="store_true")
    p.add_argument("--jax-num-envs", type=int)
    p.add_argument("--jax-num-steps", type=int)
    p.add_argument("--jax-num-minibatches", type=int)
    p.add_argument("--jax-update-epochs", type=int)
    p.add_argument("--jax-anneal-lr", type=str)
    p.add_argument("--checkpoint-interval", type=int)
    p.add_argument("--sweep-agent", action="store_true")
    p.add_argument("--sweep-id", type=str)
    p.add_argument("--count", type=int, default=0)
    p.add_argument("--offline", action="store_true")
    p.add_argument("--no-wandb", action="store_true")
    args = p.parse_args()

    overrides = {
        "algo": args.algo,
        "total_timesteps": args.total_timesteps,
        "alpha": args.alpha,
        "n_products": args.n_products,
        "lambda_coi": args.lambda_coi,
        "robust_radius": args.robust_radius,
        "robust_points": args.robust_points,
        "learning_rate": args.learning_rate,
        "gamma": args.gamma,
        "revenue_weight": args.revenue_weight,
        "max_steps": args.max_steps,
        "margin_floor": args.margin_floor,
        "margin_floor_patience": args.margin_floor_patience,
        "arch": args.arch,
        "activation": args.activation,
        # store_true yields False when the flag is absent; map that to None
        # so "not passed" is not recorded as an explicit "disable JAX".
        "use_jax": args.jax or None,
        "jax_num_envs": args.jax_num_envs,
        "jax_num_steps": args.jax_num_steps,
        "jax_num_minibatches": args.jax_num_minibatches,
        "jax_update_epochs": args.jax_update_epochs,
        "checkpoint_interval": args.checkpoint_interval,
        "jax_anneal_lr": _truthy(args.jax_anneal_lr)
        if args.jax_anneal_lr is not None
        else None,
    }
    # Keep only the values the user actually supplied.
    overrides = {k: v for k, v in overrides.items() if v is not None}

    offline_mode = "offline" if args.offline else "online"

    if args.sweep_agent:
        if args.no_wandb:
            raise ValueError("sweep agent requires wandb")
        if not args.sweep_id:
            raise ValueError("--sweep-id is required with --sweep-agent")
        if not HAS_WANDB:
            # Without this guard the wandb.agent call below would fail with
            # a NameError, because the guarded import never bound `wandb`.
            raise ImportError("wandb is required for sweep runs")
        wandb.agent(
            args.sweep_id,
            function=lambda: run_wandb(
                args.project, overrides, mode=offline_mode, sweep_mode=True
            ),
            # A non-positive --count means "no limit" (count=None).
            count=args.count if args.count > 0 else None,
        )
        return

    if args.no_wandb or not HAS_WANDB:
        run_local(overrides)
        return

    run_wandb(args.project, overrides, mode=offline_mode)
# Run the CLI only when this module is executed as the main program.
if __name__ == "__main__":
    main()