From 73246d7dd8468f64f1b47f826541034c6c97fd52 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Sun, 8 Mar 2026 18:30:53 +0100 Subject: [PATCH] refactoring training spc setup and benchmarking --- .env.sweep.example | 6 +- .gitignore | 1 + Makefile | 17 +- engine/backends/__init__.py | 1 + engine/backends/common.py | 81 +++ engine/backends/jax.py | 18 + engine/backends/qtable.py | 53 ++ engine/backends/sb3.py | 228 ++++++++ engine/benchmark.py | 456 ++++++++++++++++ engine/jax/train.py | 40 +- engine/lib/__init__.py | 75 +-- engine/lib/callbacks.py | 24 +- engine/lib/tiers.py | 101 ++++ engine/orchestrators/__init__.py | 5 + engine/orchestrators/benchmark.py | 7 + engine/orchestrators/sweep_agent.py | 60 +++ engine/orchestrators/train.py | 129 +++++ engine/project.json | 20 + engine/spec.py | 340 ++++++++++++ engine/studies/local_comparison.py | 136 +++++ engine/sweeps/model_mix.yaml | 2 +- engine/sweeps/models_only.yaml | 2 +- engine/sweeps/sac_tune.yaml | 2 +- engine/sweeps/small_arch_compare.yaml | 2 +- engine/sweeps/tpu_jax.yaml | 2 +- engine/sweeps/tpu_pod.yaml | 2 +- engine/telemetry/__init__.py | 23 + engine/telemetry/metrics.py | 57 ++ engine/telemetry/wandb.py | 98 ++++ engine/train.py | 722 +++++++------------------- engine/train_core.py | 40 ++ engine/wrapper.py | 2 +- nx.json | 3 + package.json | 1 + scripts/nx_research.sh | 31 +- scripts/tpu_vm_sweep_agent.py | 6 +- 36 files changed, 2180 insertions(+), 613 deletions(-) create mode 100644 engine/backends/__init__.py create mode 100644 engine/backends/common.py create mode 100644 engine/backends/jax.py create mode 100644 engine/backends/qtable.py create mode 100644 engine/backends/sb3.py create mode 100644 engine/benchmark.py create mode 100644 engine/lib/tiers.py create mode 100644 engine/orchestrators/__init__.py create mode 100644 engine/orchestrators/benchmark.py create mode 100644 engine/orchestrators/sweep_agent.py create mode 100644 engine/orchestrators/train.py create mode 100644 engine/spec.py create mode 100644 engine/studies/local_comparison.py create mode 100644 engine/telemetry/__init__.py create mode 100644 engine/telemetry/metrics.py create mode 100644 engine/telemetry/wandb.py create mode 100644 engine/train_core.py diff --git a/.env.sweep.example b/.env.sweep.example index 1cfb168..680f9e7 100644 --- a/.env.sweep.example +++ b/.env.sweep.example @@ -3,7 +3,7 @@ # Required for wandb runs and sweep agent workers. WANDB_API_KEY= WANDB_ENTITY= -WANDB_PROJECT=phantom-pricing +WANDB_PROJECT=capstone # Required for private repo bootstrap workers. GITHUB_TOKEN= @@ -16,3 +16,7 @@ GITHUB_TOKEN= # AGENT_COUNT=0 # AGENT_LOOP=1 # RETRY_SECONDS=20 + +# Optional local benchmark defaults. +# LOCAL_BENCHMARK_ARGS=--tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu +# BENCHMARK_AGENT_ARGS=--tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3,0.6 --episodes 5 diff --git a/.gitignore b/.gitignore index a03acca..95fc1cf 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,7 @@ sim/case/thesis_simplified/runs*/ # model binaries engine/models/*.zip +engine/studies/results/* *.zip # wandb local state diff --git a/Makefile b/Makefile index 90969f5..22a67db 100644 --- a/Makefile +++ b/Makefile @@ -13,9 +13,11 @@ NX := npx nx SWEEP_ENV_FILE ?= .env.sweep WANDB_ENTITY ?= -WANDB_PROJECT ?= phantom-pricing +WANDB_PROJECT ?= capstone SWEEP_ID ?= LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000 +LOCAL_BENCHMARK_ARGS ?= --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu +BENCHMARK_AGENT_ARGS ?= AGENT_COUNT ?= 0 REPO_URL ?= @@ -36,7 +38,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || .PHONY: help help: - @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines" + @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines" @echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish" @echo "" @echo "Build general public version:" @@ -45,6 +47,9 @@ help: @echo "Local wandb run:" @echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'" @echo "" + @echo "Local benchmark run:" + @echo " make benchmark LOCAL_BENCHMARK_ARGS='--tiers static,surge,linear --alpha-values 0.0,0.3 --episodes 3 --no-wandb'" + @echo "" @echo "Local sweep agent from this repo:" @echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5" @echo "" @@ -104,6 +109,14 @@ install: train: @WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train +.PHONY: benchmark +benchmark: + @WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_BENCHMARK_ARGS="$(LOCAL_BENCHMARK_ARGS)" $(NX) run research:benchmark + +.PHONY: benchmark.agent +benchmark.agent: + @WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" BENCHMARK_AGENT_ARGS="$(BENCHMARK_AGENT_ARGS)" $(NX) run research:benchmark-agent + .PHONY: train.agent train.agent: @WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-agent diff --git a/engine/backends/__init__.py b/engine/backends/__init__.py new file mode 100644 index 0000000..014450a --- /dev/null +++ b/engine/backends/__init__.py @@ -0,0 +1 @@ +__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"] diff --git a/engine/backends/common.py b/engine/backends/common.py new file mode 100644 index 0000000..9b916ab --- /dev/null +++ b/engine/backends/common.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Any, Mapping + +import numpy as np + + +def make_env(cfg: Mapping[str, Any]): + from gymnasium.wrappers import FlattenObservation + + from ..lib.wrappers import EconomicMetricsWrapper + from ..wrapper import PHANTOM + + env = PHANTOM( + n_products=int(cfg["n_products"]), + alpha=float(cfg["alpha"]), + N=int(cfg["N"]), + price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])), + lambda_coi=float(cfg["lambda_coi"]), + robust_radius=float(cfg["robust_radius"]), + robust_points=int(cfg["robust_points"]), + info_value=float(cfg["info_value"]), + action_levels=int(cfg["action_levels"]), + action_scale_low=float(cfg["action_scale_low"]), + action_scale_high=float(cfg["action_scale_high"]), + max_steps=int(cfg.get("max_steps", 100)), + margin_floor=float(cfg.get("margin_floor", 0.05)), + margin_floor_patience=int(cfg.get("margin_floor_patience", 5)), + render_mode=None, + ) + env = EconomicMetricsWrapper(env) + return FlattenObservation(env) + + +def _action(agent: Any, obs: Any, deterministic: bool = True): + out = agent.predict(obs, deterministic=deterministic) + action = out[0] if isinstance(out, tuple) else out + if isinstance(action, np.ndarray) and action.size == 1: + return int(action.reshape(-1)[0]) + return action + + +def evaluate(agent: Any, env: Any, episodes: int) -> dict[str, float]: + rewards: list[float] = [] + revenues: list[float] = [] + margins: list[float] = [] + coi_levels: list[float] = [] + + for _ in range(int(episodes)): + obs, _ = env.reset() + done = False + ep_reward = 0.0 + ep_revenue = 0.0 + ep_margin = 0.0 + ep_coi = 0.0 + steps = 0 + + while not done: + obs, reward, term, trunc, info = env.step(_action(agent, obs, True)) + done = bool(term or trunc) + econ = info.get("economics", {}) + ep_reward += float(reward) + ep_revenue += float(econ.get("revenue", info.get("revenue", 0.0))) + ep_margin += float(econ.get("margin", 0.0)) + ep_coi += float(econ.get("coi_level", 0.0)) + steps += 1 + + rewards.append(ep_reward) + revenues.append(ep_revenue) + denom = max(steps, 1) + margins.append(ep_margin / denom) + coi_levels.append(ep_coi / denom) + + return { + "eval/reward_mean": float(np.mean(rewards)) if rewards else 0.0, + "eval/reward_std": float(np.std(rewards)) if rewards else 0.0, + "eval/revenue_mean": float(np.mean(revenues)) if revenues else 0.0, + "eval/revenue_std": float(np.std(revenues)) if revenues else 0.0, + "eval/margin_mean": float(np.mean(margins)) if margins else 0.0, + "eval/coi_level_mean": float(np.mean(coi_levels)) if coi_levels else 0.0, + } diff --git a/engine/backends/jax.py b/engine/backends/jax.py new file mode 100644 index 0000000..980c01f --- /dev/null +++ b/engine/backends/jax.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Any, Mapping + +from ..jax import JAX_AVAILABLE + + +def train_jax_backend( + cfg: Mapping[str, Any], +) -> tuple[dict[str, Any], dict[str, float | int | str]]: + if not JAX_AVAILABLE: + raise ImportError( + "JAX backend requested but JAX is not installed. " + "Install engine/jax/requirements.txt and jax[tpu] for TPU runs." + ) + from ..jax.train import train_jax + + return train_jax(dict(cfg)) diff --git a/engine/backends/qtable.py b/engine/backends/qtable.py new file mode 100644 index 0000000..9a6e3fe --- /dev/null +++ b/engine/backends/qtable.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Any, Mapping + +import numpy as np + +from .common import evaluate, make_env + + +def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]: + from ..lib.discrete import EventQTable + + np.random.seed(int(cfg["seed"])) + env = make_env(cfg) + eval_env = make_env(cfg) + agent = EventQTable( + env.action_space.n, + int(cfg["n_products"]), + (float(cfg["price_low"]), float(cfg["price_high"])), + lr=float(cfg["q_lr"]), + gamma=float(cfg["gamma"]), + n_bins=int(cfg["q_bins"]), + ) + + total_reward = 0.0 + total_revenue = 0.0 + steps = 0 + epsilon = float(cfg["eps_start"]) + obs, _ = env.reset(seed=int(cfg["seed"])) + + for _ in range(int(cfg["total_timesteps"])): + action, state = agent.act(obs, epsilon) + nxt, reward, term, trunc, info = env.step(action) + done = bool(term or trunc) + agent.update(state, action, float(reward), agent.encode(nxt), done) + + total_reward += float(reward) + total_revenue += float(info.get("economics", {}).get("revenue", 0.0)) + steps += 1 + epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"])) + obs = env.reset()[0] if done else nxt + + metrics: dict[str, float | int] = { + "train/reward_mean": total_reward / max(steps, 1), + "train/revenue_mean": total_revenue / max(steps, 1), + "train/epsilon": float(epsilon), + "train/global_step": int(cfg["total_timesteps"]), + } + metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"]))) + + env.close() + eval_env.close() + return agent, metrics diff --git a/engine/backends/sb3.py b/engine/backends/sb3.py new file mode 100644 index 0000000..ad17e0b --- /dev/null +++ b/engine/backends/sb3.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Mapping + +from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback +from ..telemetry.wandb import get_wandb_module +from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint +from .common import evaluate, make_env + + +def _net_arch(name: Any) -> list[int]: + presets = { + "tiny": [32, 32], + "small": [64, 64], + "medium": [128, 128], + "large": [256, 256], + } + if isinstance(name, (list, tuple)): + return [int(v) for v in name] + raw = str(name).lower().strip() + if raw in presets: + return presets[raw] + if "x" in raw: + try: + parsed = [int(v) for v in raw.split("x") if v] + return parsed if parsed else presets["small"] + except ValueError: + return presets["small"] + return presets["small"] + + +def _activation(name: Any): + try: + import torch.nn as nn + except ImportError: + return None + return { + "relu": nn.ReLU, + "tanh": nn.Tanh, + "elu": nn.ELU, + "leaky_relu": nn.LeakyReLU, + }.get(str(name).lower().strip(), nn.ReLU) + + +def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]: + kwargs: dict[str, Any] = {"net_arch": _net_arch(cfg.get("arch", "small"))} + activation = _activation(cfg.get("activation", "relu")) + if activation is not None: + kwargs["activation_fn"] = activation + return kwargs + + +def _sb3_model_cls(algo: str): + try: + from stable_baselines3 import A2C, DQN, PPO + except ImportError as exc: + raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc + + if algo == "ppo": + return PPO + if algo == "a2c": + return A2C + if algo == "dqn": + return DQN + raise ValueError(f"unsupported algo '{algo}'") + + +def build_model(cfg: Mapping[str, Any], env: Any): + try: + from stable_baselines3 import A2C, DQN, PPO + except ImportError as exc: + raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc + + algo = str(cfg["algo"]) + policy_kwargs = _policy_kwargs(cfg) + device = str(cfg.get("device", "auto")) + seed = int(cfg["seed"]) + + if algo == "sac": + raise ValueError("sac is not supported with the discrete core env") + if algo == "ppo": + return PPO( + "MlpPolicy", + env, + verbose=1, + device=device, + policy_kwargs=policy_kwargs, + seed=seed, + learning_rate=float(cfg["learning_rate"]), + n_steps=int(cfg["n_steps"]), + batch_size=int(cfg["batch_size"]), + n_epochs=int(cfg["n_epochs"]), + gamma=float(cfg["gamma"]), + gae_lambda=float(cfg["gae_lambda"]), + clip_range=float(cfg["clip_range"]), + ent_coef=float(cfg["ent_coef"]), + ) + if algo == "a2c": + return A2C( + "MlpPolicy", + env, + verbose=1, + device=device, + policy_kwargs=policy_kwargs, + seed=seed, + learning_rate=float(cfg["learning_rate"]), + n_steps=max(5, int(cfg["n_steps"]) // 32), + gamma=float(cfg["gamma"]), + gae_lambda=float(cfg["gae_lambda"]), + ent_coef=float(cfg["ent_coef"]), + ) + if algo == "dqn": + return DQN( + "MlpPolicy", + env, + verbose=1, + device=device, + policy_kwargs=policy_kwargs, + seed=seed, + learning_rate=float(cfg["learning_rate"]), + buffer_size=int(cfg["buffer_size"]), + batch_size=int(cfg["batch_size"]), + gamma=float(cfg["gamma"]), + train_freq=int(cfg["train_freq"]), + learning_starts=int(cfg["learning_starts"]), + target_update_interval=int(cfg["target_update_interval"]), + exploration_fraction=float(cfg["exploration_fraction"]), + exploration_final_eps=float(cfg["exploration_final_eps"]), + ) + raise ValueError(f"unsupported algo '{algo}'") + + +def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any): + wandb = get_wandb_module() + if wandb is None or wandb.run is None: + return model + + sweep_id = getattr(wandb.run, "sweep_id", None) + artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id) + checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip" + restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file) + if restored is None: + return model + + checkpoint_path, metadata = restored + resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load( + checkpoint_path.as_posix(), + env=env, + ) + resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0))) + resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step) + return resumed + + +def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]: + try: + from stable_baselines3.common.callbacks import EvalCallback + from stable_baselines3.common.monitor import Monitor + except ImportError as exc: + raise ImportError("stable-baselines3 is required for SB3 models") from exc + + env = Monitor(make_env(cfg)) + eval_env = Monitor(make_env(cfg)) + model = build_model(cfg, env) + + try: + import torch + + print( + "PHANTOM_DEVICE: " + + json.dumps( + { + "requested": str(cfg.get("device", "auto")), + "torch_cuda_available": bool(torch.cuda.is_available()), + "torch_device_count": int(torch.cuda.device_count()), + "sb3_device": str(getattr(model, "device", "unknown")), + } + ) + ) + except Exception: + pass + + model = _maybe_resume_model(cfg, env, model) + + callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))] + callbacks.append( + CheckpointArtifactCallback( + dict(cfg), + interval=int(cfg.get("checkpoint_interval", 10_000)), + ) + ) + callbacks.append( + EvalCallback( + eval_env, + eval_freq=int(cfg["eval_freq"]), + n_eval_episodes=int(cfg["eval_episodes"]), + deterministic=True, + verbose=0, + ) + ) + + target_steps = int(cfg["total_timesteps"]) + remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0))) + if remaining_steps > 0: + model.learn( + total_timesteps=remaining_steps, + callback=callbacks, + reset_num_timesteps=False, + ) + + model_dir = Path(str(cfg["model_dir"])) + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / f"phantom_{cfg['algo']}" + model.save(str(model_path)) + + metrics: dict[str, float | int | str] = evaluate( + model, + eval_env, + int(cfg["eval_episodes"]), + ) + metrics["train/global_step"] = int(model.num_timesteps) + metrics["model/path"] = str(model_path.with_suffix(".zip")) + + env.close() + eval_env.close() + return model, metrics diff --git a/engine/benchmark.py b/engine/benchmark.py new file mode 100644 index 0000000..65b6e47 --- /dev/null +++ b/engine/benchmark.py @@ -0,0 +1,456 @@ +from __future__ import annotations + +import argparse +import json +import os +from datetime import datetime, UTC +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from .lib.tiers import LinearElasticityPolicy, StaticPolicy, SurgePolicy +from .spec import TrainSpec +from .telemetry.wandb import get_wandb_module + +wandb = get_wandb_module() +HAS_WANDB = wandb is not None + + +def _parse_list(raw: str) -> list[str]: + return [x.strip().lower() for x in str(raw).split(",") if x.strip()] + + +def _parse_float_list(raw: str) -> list[float]: + return [float(x.strip()) for x in str(raw).split(",") if x.strip()] + + +def _truthy(value: str | bool | None) -> bool: + if isinstance(value, bool): + return value + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def _action(policy, obs: np.ndarray): + out = policy.predict(obs, deterministic=True) + action = out[0] if isinstance(out, tuple) else out + if isinstance(action, np.ndarray) and action.size == 1: + return int(action.reshape(-1)[0]) + return int(action) + + +def _run_eval_episode(env, policy) -> dict: + obs, _ = env.reset() + done = False + total_reward = 0.0 + total_revenue = 0.0 + total_margin = 0.0 + total_coi = 0.0 + price_trace: list[float] = [] + step_count = 0 + + while not done: + action = _action(policy, obs) + obs, reward, term, trunc, info = env.step(action) + done = bool(term or trunc) + econ = info.get("economics", {}) + total_reward += float(reward) + total_revenue += float(econ.get("revenue", 0.0)) + total_margin += float(econ.get("margin", 0.0)) + total_coi += float(econ.get("coi_level", 0.0)) + prices = np.asarray(info.get("prices", []), dtype=np.float32) + if prices.size > 0: + price_trace.append(float(np.mean(prices))) + step_count += 1 + + denom = max(step_count, 1) + return { + "reward": total_reward, + "revenue": total_revenue, + "mean_margin": total_margin / denom, + "mean_coi": total_coi / denom, + "price_trace": price_trace, + } + + +def _build_tier(name: str, cfg: dict, alpha: float): + from .backends.common import make_env + from .backends.qtable import train_qtable + from .backends.sb3 import train_sb3 + + tier = name.lower().strip() + run_cfg = dict(cfg) + run_cfg["alpha"] = float(alpha) + + if tier == "static": + return StaticPolicy(int(run_cfg["action_levels"])) + + if tier == "surge": + return SurgePolicy( + n_actions=int(run_cfg["action_levels"]), + n_products=int(run_cfg["n_products"]), + ) + + if tier == "linear": + warmup_env = make_env(run_cfg) + policy = LinearElasticityPolicy( + n_actions=int(run_cfg["action_levels"]), + n_products=int(run_cfg["n_products"]), + price_low=float(run_cfg["price_low"]), + price_high=float(run_cfg["price_high"]), + ) + policy.fit( + warmup_env, + warmup_steps=int(run_cfg.get("linear_warmup_steps", 800)), + seed=int(run_cfg["seed"]), + ) + warmup_env.close() + return policy + + if tier == "qtable": + agent, _ = train_qtable(run_cfg) + return agent + + if tier in {"ppo", "a2c", "dqn"}: + run_cfg["algo"] = tier + agent, _ = train_sb3(run_cfg) + return agent + + raise ValueError(f"unsupported tier '{name}'") + + +def run_benchmark( + cfg: dict, tiers: list[str], alpha_values: list[float], n_episodes: int +): + from .backends.common import make_env + + rows: list[dict] = [] + traces: list[dict] = [] + + for alpha in alpha_values: + for tier_name in tiers: + policy = _build_tier(tier_name, cfg, alpha) + env = make_env({**cfg, "alpha": float(alpha)}) + eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))] + env.close() + + row = { + "tier": tier_name, + "alpha": float(alpha), + "episodes": int(n_episodes), + "mean_reward": float(np.mean([e["reward"] for e in eps])), + "mean_revenue": float(np.mean([e["revenue"] for e in eps])), + "mean_margin": float(np.mean([e["mean_margin"] for e in eps])), + "mean_coi": float(np.mean([e["mean_coi"] for e in eps])), + "std_revenue": float(np.std([e["revenue"] for e in eps])), + } + row["objective_score"] = ( + row["mean_reward"] + + float(cfg.get("revenue_weight", 0.01)) * row["mean_revenue"] + ) + rows.append(row) + + max_len = max((len(e["price_trace"]) for e in eps), default=0) + step_means = [] + for step in range(max_len): + vals = [ + e["price_trace"][step] for e in eps if step < len(e["price_trace"]) + ] + step_means.append(float(np.mean(vals)) if vals else np.nan) + traces.append( + { + "tier": tier_name, + "alpha": float(alpha), + "mean_price_trace": step_means, + } + ) + + if HAS_WANDB and wandb.run is not None: + wandb.log( + { + "study/alpha": float(alpha), + "eval/reward_mean": row["mean_reward"], + "eval/revenue_mean": row["mean_revenue"], + "eval/margin_mean": row["mean_margin"], + "objective/score": row["objective_score"], + "objective/coi_preserved": row["mean_coi"], + } + ) + + return pd.DataFrame(rows), traces + + +def _plot_outputs(df: pd.DataFrame, traces: list[dict], out_dir: Path, stamp: str): + fig1 = plt.figure(figsize=(11, 4.5)) + if "mode" in df.columns: + groups = sorted(df[["tier", "mode"]].drop_duplicates().values.tolist()) + for tier, mode in groups: + sub = df[(df["tier"] == tier) & (df["mode"] == mode)].sort_values("alpha") + plt.plot( + sub["alpha"], + sub["mean_revenue"], + marker="o", + label=f"{tier}:{mode}", + ) + else: + for tier in sorted(df["tier"].unique()): + sub = df[df["tier"] == tier].sort_values("alpha") + plt.plot(sub["alpha"], sub["mean_revenue"], marker="o", label=tier) + plt.xlabel("contamination alpha") + plt.ylabel("mean episode revenue") + plt.title("Revenue under contamination") + plt.grid(alpha=0.3) + plt.legend() + fig1.tight_layout() + rev_path = out_dir / f"benchmark_revenue_{stamp}.png" + fig1.savefig(rev_path, dpi=220) + plt.close(fig1) + + fig2 = plt.figure(figsize=(11, 4.5)) + if "mode" in df.columns: + groups = sorted(df[["tier", "mode"]].drop_duplicates().values.tolist()) + for tier, mode in groups: + sub = df[(df["tier"] == tier) & (df["mode"] == mode)].sort_values("alpha") + plt.plot( + sub["alpha"], + sub["mean_coi"], + marker="s", + label=f"{tier}:{mode}", + ) + else: + for tier in sorted(df["tier"].unique()): + sub = df[df["tier"] == tier].sort_values("alpha") + plt.plot(sub["alpha"], sub["mean_coi"], marker="s", label=tier) + plt.xlabel("contamination alpha") + plt.ylabel("mean COI level") + plt.title("COI preservation") + plt.grid(alpha=0.3) + plt.legend() + fig2.tight_layout() + coi_path = out_dir / f"benchmark_coi_{stamp}.png" + fig2.savefig(coi_path, dpi=220) + plt.close(fig2) + + focus_alpha = float(df["alpha"].min()) if not df.empty else 0.0 + alpha_traces = [t for t in traces if abs(float(t["alpha"]) - focus_alpha) < 1e-9] + fig3 = plt.figure(figsize=(11, 4.5)) + for item in alpha_traces: + xs = np.arange(len(item["mean_price_trace"])) + ys = np.asarray(item["mean_price_trace"], dtype=np.float32) + mode = item.get("mode") + label = f"{item['tier']}:{mode}" if mode is not None else str(item["tier"]) + plt.plot(xs, ys, label=label) + plt.xlabel("step") + plt.ylabel("mean price") + plt.title(f"Price evolution (alpha={focus_alpha:.2f})") + plt.grid(alpha=0.3) + plt.legend() + fig3.tight_layout() + price_path = out_dir / f"benchmark_price_trace_{stamp}.png" + fig3.savefig(price_path, dpi=220) + plt.close(fig3) + + return rev_path, coi_path, price_path + + +def _run_with_args(args): + compare_robust = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) + robust_modes = [False, True] if compare_robust else [bool(args.no_robust)] + + base_overrides = { + "seed": args.seed, + "total_timesteps": args.total_timesteps, + "n_products": args.n_products, + "N": args.N, + "lambda_coi": args.lambda_coi, + "robust_radius": args.robust_radius, + "robust_points": args.robust_points, + "price_low": args.price_low, + "price_high": args.price_high, + "action_levels": args.action_levels, + "action_scale_low": args.action_scale_low, + "action_scale_high": args.action_scale_high, + "max_steps": args.max_steps, + "learning_rate": args.learning_rate, + "batch_size": args.batch_size, + "n_steps": args.n_steps, + "linear_warmup_steps": args.linear_warmup_steps, + "device": args.device, + } + tiers = _parse_list(args.tiers) + alpha_values = _parse_float_list(args.alpha_values) + + all_frames: list[pd.DataFrame] = [] + all_traces: list[dict] = [] + for no_robust in robust_modes: + overrides = dict(base_overrides) + overrides["no_robust"] = bool(no_robust) + cfg = TrainSpec.from_flat( + {k: v for k, v in overrides.items() if v is not None} + ).to_flat_dict() + cfg["linear_warmup_steps"] = int(args.linear_warmup_steps) + df_mode, traces_mode = run_benchmark(cfg, tiers, alpha_values, args.episodes) + mode_label = "no_robust" if no_robust else "robust" + df_mode["mode"] = mode_label + for trace in traces_mode: + trace["mode"] = mode_label + all_frames.append(df_mode) + all_traces.extend(traces_mode) + + df = pd.concat(all_frames, ignore_index=True) if all_frames else pd.DataFrame() + traces = all_traces + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + stamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") + csv_path = out_dir / f"benchmark_{stamp}.csv" + trace_path = out_dir / f"benchmark_traces_{stamp}.json" + df.to_csv(csv_path, index=False) + trace_path.write_text(json.dumps(traces, indent=2)) + rev_path, coi_path, price_path = _plot_outputs(df, traces, out_dir, stamp) + + if not df.empty: + best_idx = int(df["mean_revenue"].idxmax()) + best = df.iloc[best_idx] + print( + "BEST_TIER=" + + json.dumps( + { + "tier": best["tier"], + "mode": best.get("mode", "robust"), + "alpha": float(best["alpha"]), + "mean_revenue": float(best["mean_revenue"]), + "mean_coi": float(best["mean_coi"]), + } + ) + ) + print(f"BENCHMARK_CSV={csv_path}") + print(f"BENCHMARK_TRACES={trace_path}") + print(f"BENCHMARK_PLOT_REVENUE={rev_path}") + print(f"BENCHMARK_PLOT_COI={coi_path}") + print(f"BENCHMARK_PLOT_PRICE={price_path}") + + +def run_cli(raw_args: list[str] | None = None): + parser = argparse.ArgumentParser(description="PHANTOM benchmark orchestrator") + parser.add_argument("--project", default="capstone") + parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo") + parser.add_argument("--alpha-values", default="0.0,0.3,0.6") + parser.add_argument("--episodes", type=int, default=10) + parser.add_argument("--output-dir", default="engine/studies/results") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--total-timesteps", type=int, default=25_000) + parser.add_argument("--n-products", type=int, default=10) + parser.add_argument("--N", type=int, default=100) + parser.add_argument("--lambda-coi", type=float, default=0.2) + parser.add_argument("--robust-radius", type=float, default=0.15) + parser.add_argument("--robust-points", type=int, default=5) + parser.add_argument("--price-low", type=float, default=10.0) + parser.add_argument("--price-high", type=float, default=150.0) + parser.add_argument("--action-levels", type=int, default=9) + parser.add_argument("--action-scale-low", type=float, default=0.8) + parser.add_argument("--action-scale-high", type=float, default=1.2) + parser.add_argument("--max-steps", type=int, default=100) + parser.add_argument("--learning-rate", type=float, default=3e-4) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--n-steps", type=int, default=2048) + parser.add_argument("--linear-warmup-steps", type=int, default=800) + parser.add_argument("--device", type=str, default="auto") + parser.add_argument("--no-robust", action="store_true") + parser.add_argument("--no-wandb", action="store_true") + parser.add_argument("--offline", action="store_true") + parser.add_argument("--sweep-agent", action="store_true") + parser.add_argument("--sweep-id", type=str) + parser.add_argument("--count", type=int, default=0) + args = parser.parse_args(raw_args) + + if args.sweep_agent: + if args.no_wandb or not HAS_WANDB: + raise ValueError("sweep agent requires wandb") + if not args.sweep_id: + raise ValueError("--sweep-id is required with --sweep-agent") + + def _sweep_run(): + run = wandb.init(mode="offline" if args.offline else "online") + try: + key_to_attr = { + "tiers": "tiers", + "alpha_values": "alpha_values", + "episodes": "episodes", + "total_timesteps": "total_timesteps", + "lambda_coi": "lambda_coi", + "robust_radius": "robust_radius", + "robust_points": "robust_points", + "learning_rate": "learning_rate", + "batch_size": "batch_size", + "n_steps": "n_steps", + "no_robust": "no_robust", + "device": "device", + } + for key in ( + "tiers", + "alpha_values", + "episodes", + "total_timesteps", + "lambda_coi", + "robust_radius", + "robust_points", + "learning_rate", + "batch_size", + "n_steps", + "no_robust", + "device", + ): + if key in wandb.config: + setattr(args, key_to_attr[key], wandb.config[key]) + _run_with_args(args) + finally: + if run is not None: + wandb.finish() + + wandb.agent( + args.sweep_id, + function=_sweep_run, + count=args.count if args.count > 0 else None, + ) + return + + if args.no_wandb or not HAS_WANDB: + _run_with_args(args) + return + + run = wandb.init( + project=args.project, + name=f"benchmark-{datetime.now(UTC).strftime('%m%d-%H%M%S')}", + tags=[ + "benchmark", + "robust-compare" + if _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) + else "single-mode", + ], + config={ + "run.kind": "benchmark", + "tiers": args.tiers, + "alpha_values": args.alpha_values, + "episodes": args.episodes, + "total_timesteps": args.total_timesteps, + "lambda_coi": args.lambda_coi, + "robust_radius": args.robust_radius, + "robust_points": args.robust_points, + "learning_rate": args.learning_rate, + "device": args.device, + }, + mode="offline" if args.offline else "online", + ) + try: + _run_with_args(args) + finally: + if run is not None: + wandb.finish() + + +if __name__ == "__main__": + run_cli() diff --git a/engine/jax/train.py b/engine/jax/train.py index 3860d8b..5ec637c 100644 --- a/engine/jax/train.py +++ b/engine/jax/train.py @@ -624,8 +624,8 @@ def evaluate_policy( revenues.append(ep_revenue) return { - "eval/reward": float(np.mean(rewards)), - "eval/revenue": float(np.mean(revenues)), + "eval/reward_mean": float(np.mean(rewards)), + "eval/revenue_mean": float(np.mean(revenues)), "eval/reward_std": float(np.std(rewards)), "eval/revenue_std": float(np.std(revenues)), } @@ -665,8 +665,8 @@ def _evaluate_q_network( revenues.append(ep_revenue) return { - "eval/reward": float(np.mean(rewards)), - "eval/revenue": float(np.mean(revenues)), + "eval/reward_mean": float(np.mean(rewards)), + "eval/revenue_mean": float(np.mean(revenues)), "eval/reward_std": float(np.std(rewards)), "eval/revenue_std": float(np.std(revenues)), } @@ -713,8 +713,8 @@ def _evaluate_q_table( revenues.append(ep_revenue) return { - "eval/reward": float(np.mean(rewards)), - "eval/revenue": float(np.mean(revenues)), + "eval/reward_mean": float(np.mean(rewards)), + "eval/revenue_mean": float(np.mean(revenues)), "eval/reward_std": float(np.std(rewards)), "eval/revenue_std": float(np.std(revenues)), } @@ -831,8 +831,8 @@ def _train_actor_critic( if is_primary and HAS_WANDB and wandb.run is not None: wandb.log( { - "train/reward": float(segment_values["reward"].mean()), - "train/revenue": float(segment_values["revenue"].mean()), + "train/reward_mean": float(segment_values["reward"].mean()), + "train/revenue_mean": float(segment_values["revenue"].mean()), "train/agent_prob": float(segment_values["agent_prob"].mean()), "train/alpha_adv": float(segment_values["alpha_adv"].mean()), "train/coi_leakage": float(segment_values["coi_leakage"].mean()), @@ -873,8 +873,8 @@ def _train_actor_critic( train_state = final_runner[0] denom = float(metric_count) if metric_count > 0 else 1.0 metrics = { - "train/reward": float(metric_sums["reward"] / denom), - "train/revenue": float(metric_sums["revenue"] / denom), + "train/reward_mean": float(metric_sums["reward"] / denom), + "train/revenue_mean": float(metric_sums["revenue"] / denom), "train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), @@ -1052,14 +1052,14 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: ): wandb.log( { - "train/reward": metric_sums["reward"] / max(metric_count, 1), - "train/revenue": metric_sums["revenue"] / max(metric_count, 1), + "train/reward_mean": metric_sums["reward"] / max(metric_count, 1), + "train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1), "train/agent_prob": metric_sums["agent_prob"] / max(metric_count, 1), "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), "train/coi_leakage": metric_sums["coi_leakage"] / max(metric_count, 1), - "train/dqn_loss": metric_sums["loss"] / max(loss_count, 1), + "train/loss": metric_sums["loss"] / max(loss_count, 1), "train/epsilon": epsilon_value, "train/global_step": global_step, }, @@ -1090,12 +1090,12 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: denom = float(metric_count) if metric_count > 0 else 1.0 metrics = { - "train/reward": float(metric_sums["reward"] / denom), - "train/revenue": float(metric_sums["revenue"] / denom), + "train/reward_mean": float(metric_sums["reward"] / denom), + "train/revenue_mean": float(metric_sums["revenue"] / denom), "train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), - "train/dqn_loss": float(metric_sums["loss"] / max(loss_count, 1)), + "train/loss": float(metric_sums["loss"] / max(loss_count, 1)), "train/global_step": total_steps, } @@ -1236,8 +1236,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float] ): wandb.log( { - "train/reward": metric_sums["reward"] / max(metric_count, 1), - "train/revenue": metric_sums["revenue"] / max(metric_count, 1), + "train/reward_mean": metric_sums["reward"] / max(metric_count, 1), + "train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1), "train/agent_prob": metric_sums["agent_prob"] / max(metric_count, 1), "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), @@ -1269,8 +1269,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float] denom = float(metric_count) if metric_count > 0 else 1.0 metrics = { - "train/reward": float(metric_sums["reward"] / denom), - "train/revenue": float(metric_sums["revenue"] / denom), + "train/reward_mean": float(metric_sums["reward"] / denom), + "train/revenue_mean": float(metric_sums["revenue"] / denom), "train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), diff --git a/engine/lib/__init__.py b/engine/lib/__init__.py index 4bfb923..823c572 100644 --- a/engine/lib/__init__.py +++ b/engine/lib/__init__.py @@ -1,38 +1,39 @@ -from .demand import estimate_demand, estimate_weighted_demand, generate_demand_for_actor -from .behavior import sample_behavior, get_transition_models, trajectory_to_events -from .render import DashboardRenderer, style_axis -from .wrappers import EconomicMetricsWrapper -from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback -from .providers import ( - ProviderBenchmark, - ProviderResult, - BenchmarkConfig, - RandomBaseline, - SurgeBaseline, -) -from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability -from .discrete import EventQTable +from __future__ import annotations -__all__ = [ - "estimate_demand", - "estimate_weighted_demand", - "generate_demand_for_actor", - "sample_behavior", - "get_transition_models", - "trajectory_to_events", - "DashboardRenderer", - "style_axis", - "EconomicMetricsWrapper", - "MetricsCallback", - "EvalMetricsCallback", - "CheckpointArtifactCallback", - "ProviderBenchmark", - "ProviderResult", - "BenchmarkConfig", - "RandomBaseline", - "SurgeBaseline", - "compute_uplift_coi", - "extract_purchases", - "compute_agent_probability", - "EventQTable", -] +from importlib import import_module + +_EXPORTS: dict[str, tuple[str, str]] = { + "estimate_demand": (".demand", "estimate_demand"), + "estimate_weighted_demand": (".demand", "estimate_weighted_demand"), + "generate_demand_for_actor": (".demand", "generate_demand_for_actor"), + "sample_behavior": (".behavior", "sample_behavior"), + "get_transition_models": (".behavior", "get_transition_models"), + "trajectory_to_events": (".behavior", "trajectory_to_events"), + "DashboardRenderer": (".render", "DashboardRenderer"), + "style_axis": (".render", "style_axis"), + "EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"), + "MetricsCallback": (".callbacks", "MetricsCallback"), + "EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"), + "CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"), + "ProviderBenchmark": (".providers", "ProviderBenchmark"), + "ProviderResult": (".providers", "ProviderResult"), + "BenchmarkConfig": (".providers", "BenchmarkConfig"), + "RandomBaseline": (".providers", "RandomBaseline"), + "SurgeBaseline": (".providers", "SurgeBaseline"), + "compute_uplift_coi": (".coi", "compute_uplift_coi"), + "extract_purchases": (".coi", "extract_purchases"), + "compute_agent_probability": (".coi", "compute_agent_probability"), + "EventQTable": (".discrete", "EventQTable"), +} + +__all__ = sorted(_EXPORTS) + + +def __getattr__(name: str): + if name not in _EXPORTS: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module_name, attr_name = _EXPORTS[name] + module = import_module(module_name, package=__name__) + value = getattr(module, attr_name) + globals()[name] = value + return value diff --git a/engine/lib/callbacks.py b/engine/lib/callbacks.py index 05e77a0..a21fdfe 100644 --- a/engine/lib/callbacks.py +++ b/engine/lib/callbacks.py @@ -38,19 +38,19 @@ class MetricsCallback(BaseCallback): t = self.num_timesteps payload = { - "economics/revenue": econ["revenue"], - "economics/margin": econ["margin"], - "coi/level": econ["coi_level"], - "economics/regret": econ["regret"], + "train/revenue_step": econ["revenue"], + "train/margin_step": econ["margin"], + "train/coi_level": econ["coi_level"], + "train/regret_step": econ["regret"], } if "coi_mix" in econ: - payload["coi/mix"] = econ["coi_mix"] + payload["train/coi_mix"] = econ["coi_mix"] if "coi_base" in econ: - payload["coi/base"] = econ["coi_base"] + payload["train/coi_base"] = econ["coi_base"] if "coi_leakage" in econ: - payload["coi/leakage"] = econ["coi_leakage"] + payload["train/coi_leakage"] = econ["coi_leakage"] if "coi_penalty" in econ: - payload["coi/penalty"] = econ["coi_penalty"] + payload["train/coi_penalty"] = econ["coi_penalty"] wandb.log(payload, step=t) self._episode_revenues.append(econ["revenue"]) @@ -76,8 +76,8 @@ class MetricsCallback(BaseCallback): return wandb.log( { - "episode/mean_revenue": np.mean(self._episode_revenues), - "episode/total_revenue": np.sum(self._episode_revenues), + "train/revenue_rollout_mean": np.mean(self._episode_revenues), + "train/revenue_rollout_total": np.sum(self._episode_revenues), }, step=self.num_timesteps, ) @@ -164,8 +164,8 @@ class EvalMetricsCallback(EvalCallback): if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): wandb.log( { - "eval/mean_reward": self.last_mean_reward, - "eval/mean_revenue": np.mean(self._eval_revenues) + "eval/reward_mean": self.last_mean_reward, + "eval/revenue_mean": np.mean(self._eval_revenues) if self._eval_revenues else 0, }, diff --git a/engine/lib/tiers.py b/engine/lib/tiers.py new file mode 100644 index 0000000..3760b34 --- /dev/null +++ b/engine/lib/tiers.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + +import numpy as np + + +class PolicyLike(Protocol): + def predict(self, obs: np.ndarray, deterministic: bool = True): ... + + +class StaticPolicy: + def __init__(self, n_actions: int): + self._action = int(max(0, n_actions // 2)) + + def predict(self, obs: np.ndarray, deterministic: bool = True): + return self._action, None + + +class SurgePolicy: + def __init__( + self, + n_actions: int, + n_products: int, + high_threshold: float = 60.0, + low_threshold: float = 30.0, + ): + self.n_actions = int(n_actions) + self.n_products = int(n_products) + self.mid = self.n_actions // 2 + self.high_t = float(high_threshold) + self.low_t = float(low_threshold) + + def predict(self, obs: np.ndarray, deterministic: bool = True): + obs_arr = np.asarray(obs, dtype=np.float32) + demand = obs_arr[: self.n_products] + demand_mean = float(np.mean(demand)) if demand.size > 0 else 0.0 + if demand_mean >= self.high_t: + return min(self.mid + 2, self.n_actions - 1), None + if demand_mean <= self.low_t: + return max(self.mid - 2, 0), None + return self.mid, None + + +@dataclass +class LinearElasticityPolicy: + n_actions: int + n_products: int + price_low: float + price_high: float + + def __post_init__(self): + self.n_actions = int(self.n_actions) + self.n_products = int(self.n_products) + self.price_low = float(self.price_low) + self.price_high = float(self.price_high) + self._target_price = 0.5 * (self.price_low + self.price_high) + self._action_scales = np.linspace(0.8, 1.2, self.n_actions) + + def fit(self, env, warmup_steps: int = 800, seed: int = 42): + rng = np.random.default_rng(int(seed)) + obs, _ = env.reset(seed=int(seed)) + prices: list[float] = [] + demands: list[float] = [] + + for _ in range(int(max(10, warmup_steps))): + action = int(rng.integers(0, self.n_actions)) + obs, _, term, trunc, info = env.step(action) + done = bool(term or trunc) + + p = np.asarray(info.get("prices", []), dtype=np.float32) + d = np.asarray(info.get("demand", []), dtype=np.float32) + if p.size > 0 and d.size > 0: + prices.append(float(np.mean(p))) + demands.append(float(np.mean(d))) + + if done: + obs, _ = env.reset() + + if len(prices) < 8: + self._target_price = 0.5 * (self.price_low + self.price_high) + return self + + slope, intercept = np.polyfit(np.asarray(prices), np.asarray(demands), 1) + if slope < -1e-6: + p_star = -intercept / (2.0 * slope) + self._target_price = float(np.clip(p_star, self.price_low, self.price_high)) + else: + self._target_price = 0.5 * (self.price_low + self.price_high) + return self + + def predict(self, obs: np.ndarray, deterministic: bool = True): + obs_arr = np.asarray(obs, dtype=np.float32) + cur_prices = obs_arr[self.n_products : 2 * self.n_products] + cur_mean = ( + float(np.mean(cur_prices)) if cur_prices.size > 0 else self._target_price + ) + scale = self._target_price / max(cur_mean, 1e-6) + action = int(np.argmin(np.abs(self._action_scales - scale))) + return int(np.clip(action, 0, self.n_actions - 1)), None diff --git a/engine/orchestrators/__init__.py b/engine/orchestrators/__init__.py new file mode 100644 index 0000000..1304822 --- /dev/null +++ b/engine/orchestrators/__init__.py @@ -0,0 +1,5 @@ +from .benchmark import run_benchmark_cli +from .sweep_agent import run_sweep_agent +from .train import run_train_once + +__all__ = ["run_benchmark_cli", "run_sweep_agent", "run_train_once"] diff --git a/engine/orchestrators/benchmark.py b/engine/orchestrators/benchmark.py new file mode 100644 index 0000000..eae938c --- /dev/null +++ b/engine/orchestrators/benchmark.py @@ -0,0 +1,7 @@ +from __future__ import annotations + + +def run_benchmark_cli(raw_args: list[str] | None = None) -> None: + from ..benchmark import run_cli + + run_cli(raw_args) diff --git a/engine/orchestrators/sweep_agent.py b/engine/orchestrators/sweep_agent.py new file mode 100644 index 0000000..9f3dcfc --- /dev/null +++ b/engine/orchestrators/sweep_agent.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Any, Mapping, Sequence + +from ..spec import TrainSpec, run_name +from ..telemetry.wandb import ( + current_config, + finish_run, + get_wandb_module, + init_run, + run_agent, +) +from .train import run_with_active_sweep_run + + +def run_sweep_agent( + *, + project: str, + sweep_id: str, + count: int, + offline: bool, + no_wandb: bool, + base_overrides: Mapping[str, Any], + kind: str, + scenario: str, + group: str | None, + extra_tags: Sequence[str], +) -> None: + if no_wandb: + raise ValueError("sweep agent requires wandb") + if not sweep_id: + raise ValueError("--sweep-id is required with --sweep-agent") + if get_wandb_module() is None: + raise ImportError("wandb is required for sweep runs") + + mode = "offline" if offline else "online" + + def _sweep_trial() -> None: + run = init_run(mode=mode, project=project, group=group, sweep_mode=True) + try: + merged = dict(base_overrides) + merged.update(current_config()) + spec = TrainSpec.from_flat(merged) + if run is not None: + run.name = run_name(spec, kind=kind, scenario=scenario) + run_with_active_sweep_run( + spec, + kind=kind, + scenario=scenario, + group=group, + extra_tags=extra_tags, + ) + finally: + finish_run() + + run_agent( + sweep_id, + _sweep_trial, + count=count if count > 0 else None, + ) diff --git a/engine/orchestrators/train.py b/engine/orchestrators/train.py new file mode 100644 index 0000000..6b0f539 --- /dev/null +++ b/engine/orchestrators/train.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import json +from typing import Any, Sequence + +from ..spec import TrainSpec, run_metadata, run_name +from ..telemetry.wandb import ( + finish_run, + get_wandb_module, + init_run, + log_metrics, + update_run_config, + update_summary, +) +from ..train_core import run_train + + +def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list[str]: + tags = [ + kind, + spec.algorithm.name, + spec.runtime.backend, + "vanilla" if spec.study.no_robust else "robust", + ] + tags.extend([tag for tag in extra_tags if tag]) + return tags + + +def _print_local_metrics(metrics: dict[str, Any]) -> None: + print(json.dumps(metrics, indent=2)) + print("PHANTOM_METRICS:" + json.dumps(metrics)) + + +def _should_print_local(spec: TrainSpec) -> bool: + if not spec.runtime.use_jax: + return True + try: + import jax + + return int(jax.process_index()) == 0 + except Exception: + return True + + +def _is_non_primary_jax_worker(spec: TrainSpec) -> bool: + if not spec.runtime.use_jax: + return False + try: + import jax + + return int(jax.process_count()) > 1 and int(jax.process_index()) != 0 + except Exception: + return False + + +def run_train_once( + spec: TrainSpec, + *, + project: str, + offline: bool, + no_wandb: bool, + kind: str, + scenario: str, + group: str | None, + extra_tags: Sequence[str], +) -> dict[str, Any]: + wandb = get_wandb_module() + if no_wandb or wandb is None or _is_non_primary_jax_worker(spec): + result = run_train(spec) + if _should_print_local(spec): + _print_local_metrics(result.metrics) + return result.metrics + + mode = "offline" if offline else "online" + tags = _tags_for_run(spec, kind, extra_tags) + metadata = run_metadata( + spec, + kind=kind, + scenario=scenario, + group=group, + tags=tags, + ) + config = spec.to_flat_dict() + config.update(metadata) + name = run_name(spec, kind=kind, scenario=scenario) + init_run( + mode=mode, + project=project, + config=config, + name=name, + tags=tags, + group=group, + sweep_mode=False, + ) + + try: + result = run_train(spec) + metrics = result.metrics + step = int(metrics.get("train/global_step", spec.runtime.total_timesteps)) + log_metrics(metrics, step=step) + update_summary(metrics) + return metrics + finally: + finish_run() + + +def run_with_active_sweep_run( + spec: TrainSpec, + *, + kind: str, + scenario: str, + group: str | None, + extra_tags: Sequence[str], +) -> dict[str, Any]: + tags = _tags_for_run(spec, kind, extra_tags) + metadata = run_metadata( + spec, + kind=kind, + scenario=scenario, + group=group, + tags=tags, + ) + update_run_config({**spec.to_flat_dict(), **metadata}) + result = run_train(spec) + metrics = result.metrics + step = int(metrics.get("train/global_step", spec.runtime.total_timesteps)) + log_metrics(metrics, step=step) + update_summary(metrics) + return metrics diff --git a/engine/project.json b/engine/project.json index 10272c3..3cf3571 100644 --- a/engine/project.json +++ b/engine/project.json @@ -31,6 +31,26 @@ "cwd": "." } }, + "benchmark": { + "executor": "nx:run-commands", + "dependsOn": [ + "install" + ], + "options": { + "command": "bash scripts/nx_research.sh benchmark", + "cwd": "." + } + }, + "benchmark-agent": { + "executor": "nx:run-commands", + "dependsOn": [ + "install" + ], + "options": { + "command": "bash scripts/nx_research.sh benchmark-agent", + "cwd": "." + } + }, "train-agent": { "executor": "nx:run-commands", "dependsOn": [ diff --git a/engine/spec.py b/engine/spec.py new file mode 100644 index 0000000..f72fdd0 --- /dev/null +++ b/engine/spec.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import os +from typing import Any, Mapping, Sequence + + +def _truthy(value: str | bool | None) -> bool: + if isinstance(value, bool): + return value + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]: + alias_map = { + "algorithm": "algo", + "algorithm.name": "algo", + "env.n_products": "n_products", + "env.action_levels": "action_levels", + "env.action_scale_low": "action_scale_low", + "env.action_scale_high": "action_scale_high", + "env.price_low": "price_low", + "env.price_high": "price_high", + "env.max_steps": "max_steps", + "env.margin_floor": "margin_floor", + "env.margin_floor_patience": "margin_floor_patience", + "env.n_sessions": "N", + "study.alpha": "alpha", + "study.lambda_coi": "lambda_coi", + "study.robust_radius": "robust_radius", + "study.robust_points": "robust_points", + "study.info_value": "info_value", + "study.revenue_weight": "revenue_weight", + "optimizer.learning_rate": "learning_rate", + "optimizer.gamma": "gamma", + "optimizer.batch_size": "batch_size", + "optimizer.n_steps": "n_steps", + "runtime.backend": "backend", + "runtime.device": "device", + "runtime.seed": "seed", + "runtime.total_timesteps": "total_timesteps", + "runtime.checkpoint_interval": "checkpoint_interval", + "eval.eval_freq": "eval_freq", + "eval.eval_episodes": "eval_episodes", + } + normalized: dict[str, Any] = {} + for key, value in raw.items(): + canonical = alias_map.get(str(key), str(key)) + normalized[canonical] = value + return normalized + + +@dataclass(frozen=True) +class AlgorithmSpec: + name: str = "ppo" + + +@dataclass(frozen=True) +class EnvSpec: + n_products: int = 10 + n_sessions: int = 100 + price_low: float = 10.0 + price_high: float = 150.0 + action_levels: int = 9 + action_scale_low: float = 0.8 + action_scale_high: float = 1.2 + max_steps: int = 100 + margin_floor: float = 0.05 + margin_floor_patience: int = 5 + + +@dataclass(frozen=True) +class StudySpec: + alpha: float = 0.3 + lambda_coi: float = 0.2 + robust_radius: float = 0.15 + robust_points: int = 5 + info_value: float = 1.0 + revenue_weight: float = 0.01 + no_robust: bool = False + + +@dataclass(frozen=True) +class OptimizerSpec: + learning_rate: float = 3e-4 + gamma: float = 0.99 + buffer_size: int = 50_000 + batch_size: int = 256 + tau: float = 0.005 + train_freq: int = 1 + learning_starts: int = 1_000 + target_update_interval: int = 1_000 + exploration_fraction: float = 0.2 + exploration_final_eps: float = 0.05 + n_steps: int = 2_048 + n_epochs: int = 10 + gae_lambda: float = 0.95 + clip_range: float = 0.2 + ent_coef: float = 0.0 + q_lr: float = 0.1 + q_bins: int = 6 + eps_start: float = 1.0 + eps_end: float = 0.05 + eps_decay: float = 0.9995 + arch: str = "small" + activation: str = "relu" + jax_num_envs: int = 16 + jax_num_steps: int = 128 + jax_num_minibatches: int = 4 + jax_update_epochs: int = 4 + jax_anneal_lr: bool = True + vf_coef: float = 0.5 + max_grad_norm: float = 0.5 + + +@dataclass(frozen=True) +class RuntimeSpec: + project: str = "capstone" + backend: str = "sb3" + device: str = "auto" + seed: int = 42 + total_timesteps: int = 50_000 + checkpoint_interval: int = 200_000 + model_dir: str = "engine/models" + log_freq: int = 100 + use_jax: bool = False + + +@dataclass(frozen=True) +class EvalSpec: + eval_freq: int = 1_000 + eval_episodes: int = 5 + robust_eval_enabled: bool = True + + +@dataclass(frozen=True) +class TrainSpec: + algorithm: AlgorithmSpec = field(default_factory=AlgorithmSpec) + env: EnvSpec = field(default_factory=EnvSpec) + study: StudySpec = field(default_factory=StudySpec) + optimizer: OptimizerSpec = field(default_factory=OptimizerSpec) + runtime: RuntimeSpec = field(default_factory=RuntimeSpec) + eval: EvalSpec = field(default_factory=EvalSpec) + + def to_flat_dict(self) -> dict[str, Any]: + return { + "project": self.runtime.project, + "algo": self.algorithm.name, + "seed": self.runtime.seed, + "total_timesteps": self.runtime.total_timesteps, + "eval_episodes": self.eval.eval_episodes, + "eval_freq": self.eval.eval_freq, + "log_freq": self.runtime.log_freq, + "model_dir": self.runtime.model_dir, + "backend": self.runtime.backend, + "device": self.runtime.device, + "use_jax": self.runtime.use_jax, + "checkpoint_interval": self.runtime.checkpoint_interval, + "n_products": self.env.n_products, + "N": self.env.n_sessions, + "price_low": self.env.price_low, + "price_high": self.env.price_high, + "action_levels": self.env.action_levels, + "action_scale_low": self.env.action_scale_low, + "action_scale_high": self.env.action_scale_high, + "max_steps": self.env.max_steps, + "margin_floor": self.env.margin_floor, + "margin_floor_patience": self.env.margin_floor_patience, + "alpha": self.study.alpha, + "lambda_coi": self.study.lambda_coi, + "robust_radius": self.study.robust_radius, + "robust_points": self.study.robust_points, + "info_value": self.study.info_value, + "revenue_weight": self.study.revenue_weight, + "no_robust": self.study.no_robust, + "learning_rate": self.optimizer.learning_rate, + "gamma": self.optimizer.gamma, + "buffer_size": self.optimizer.buffer_size, + "batch_size": self.optimizer.batch_size, + "tau": self.optimizer.tau, + "train_freq": self.optimizer.train_freq, + "learning_starts": self.optimizer.learning_starts, + "target_update_interval": self.optimizer.target_update_interval, + "exploration_fraction": self.optimizer.exploration_fraction, + "exploration_final_eps": self.optimizer.exploration_final_eps, + "n_steps": self.optimizer.n_steps, + "n_epochs": self.optimizer.n_epochs, + "gae_lambda": self.optimizer.gae_lambda, + "clip_range": self.optimizer.clip_range, + "ent_coef": self.optimizer.ent_coef, + "q_lr": self.optimizer.q_lr, + "q_bins": self.optimizer.q_bins, + "eps_start": self.optimizer.eps_start, + "eps_end": self.optimizer.eps_end, + "eps_decay": self.optimizer.eps_decay, + "arch": self.optimizer.arch, + "activation": self.optimizer.activation, + "jax_num_envs": self.optimizer.jax_num_envs, + "jax_num_steps": self.optimizer.jax_num_steps, + "jax_num_minibatches": self.optimizer.jax_num_minibatches, + "jax_update_epochs": self.optimizer.jax_update_epochs, + "jax_anneal_lr": self.optimizer.jax_anneal_lr, + "vf_coef": self.optimizer.vf_coef, + "max_grad_norm": self.optimizer.max_grad_norm, + "robust_eval_enabled": self.eval.robust_eval_enabled, + } + + @classmethod + def from_flat( + cls, + raw: Mapping[str, Any] | None = None, + *, + env_vars: Mapping[str, str] | None = None, + ) -> "TrainSpec": + base = cls().to_flat_dict() + incoming = _normalize_keys(raw or {}) + base.update({k: v for k, v in incoming.items() if v is not None}) + + runtime_env = os.environ if env_vars is None else env_vars + base["device"] = str( + base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto")) + ) + + requested_jax = _truthy(base.get("use_jax")) or _truthy( + runtime_env.get("PHANTOM_USE_JAX") + ) + backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower() + if backend == "auto": + backend = "jax" if requested_jax else "sb3" + if backend == "jax": + requested_jax = True + + no_robust = _truthy(base.get("no_robust")) + if no_robust: + base["lambda_coi"] = 0.0 + base["robust_radius"] = 0.0 + base["robust_points"] = 1 + + return cls( + algorithm=AlgorithmSpec(name=str(base["algo"]).lower().strip()), + env=EnvSpec( + n_products=int(base["n_products"]), + n_sessions=int(base["N"]), + price_low=float(base["price_low"]), + price_high=float(base["price_high"]), + action_levels=int(base["action_levels"]), + action_scale_low=float(base["action_scale_low"]), + action_scale_high=float(base["action_scale_high"]), + max_steps=int(base["max_steps"]), + margin_floor=float(base["margin_floor"]), + margin_floor_patience=int(base["margin_floor_patience"]), + ), + study=StudySpec( + alpha=float(base["alpha"]), + lambda_coi=float(base["lambda_coi"]), + robust_radius=float(base["robust_radius"]), + robust_points=int(base["robust_points"]), + info_value=float(base["info_value"]), + revenue_weight=float(base["revenue_weight"]), + no_robust=no_robust, + ), + optimizer=OptimizerSpec( + learning_rate=float(base["learning_rate"]), + gamma=float(base["gamma"]), + buffer_size=int(base["buffer_size"]), + batch_size=int(base["batch_size"]), + tau=float(base["tau"]), + train_freq=int(base["train_freq"]), + learning_starts=int(base["learning_starts"]), + target_update_interval=int(base["target_update_interval"]), + exploration_fraction=float(base["exploration_fraction"]), + exploration_final_eps=float(base["exploration_final_eps"]), + n_steps=int(base["n_steps"]), + n_epochs=int(base["n_epochs"]), + gae_lambda=float(base["gae_lambda"]), + clip_range=float(base["clip_range"]), + ent_coef=float(base["ent_coef"]), + q_lr=float(base["q_lr"]), + q_bins=int(base["q_bins"]), + eps_start=float(base["eps_start"]), + eps_end=float(base["eps_end"]), + eps_decay=float(base["eps_decay"]), + arch=str(base["arch"]), + activation=str(base["activation"]), + jax_num_envs=int(base["jax_num_envs"]), + jax_num_steps=int(base["jax_num_steps"]), + jax_num_minibatches=int(base["jax_num_minibatches"]), + jax_update_epochs=int(base["jax_update_epochs"]), + jax_anneal_lr=_truthy(base.get("jax_anneal_lr")), + vf_coef=float(base["vf_coef"]), + max_grad_norm=float(base["max_grad_norm"]), + ), + runtime=RuntimeSpec( + project=str(base["project"]), + backend=backend, + device=str(base["device"]), + seed=int(base["seed"]), + total_timesteps=int(base["total_timesteps"]), + checkpoint_interval=int(base["checkpoint_interval"]), + model_dir=str(base["model_dir"]), + log_freq=int(base["log_freq"]), + use_jax=requested_jax, + ), + eval=EvalSpec( + eval_freq=int(base["eval_freq"]), + eval_episodes=int(base["eval_episodes"]), + robust_eval_enabled=_truthy(base.get("robust_eval_enabled", True)), + ), + ) + + +def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str: + return ( + f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/" + f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}" + ) + + +def run_metadata( + spec: TrainSpec, + *, + kind: str, + scenario: str, + group: str | None = None, + tags: Sequence[str] = (), +) -> dict[str, Any]: + metadata: dict[str, Any] = { + "run.kind": str(kind), + "run.algo": spec.algorithm.name, + "run.backend": spec.runtime.backend, + "run.device": spec.runtime.device, + "run.scenario": str(scenario), + "run.seed": spec.runtime.seed, + "run.tags": list(tags), + } + if group: + metadata["run.group"] = group + return metadata diff --git a/engine/studies/local_comparison.py b/engine/studies/local_comparison.py new file mode 100644 index 0000000..1859b97 --- /dev/null +++ b/engine/studies/local_comparison.py @@ -0,0 +1,136 @@ +import sys +import numpy as np +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt + +from gymnasium.wrappers import FlattenObservation +from stable_baselines3 import PPO + +# Add parent directory to path to allow importing engine +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from engine.wrapper import PHANTOM +from engine.lib.wrappers import EconomicMetricsWrapper +from engine.lib.providers import ( + ProviderBenchmark, + BenchmarkConfig, + RandomBaseline, + SurgeBaseline, +) + + +def env_factory(alpha: float): + """Creates a wrapped PHANTOM environment for testing at a specific alpha level.""" + # Action levels=9 matches the trained PPO model + # n_products=8 matches the pretrained model's expectation of Box(16,) + env = PHANTOM( + n_products=8, + alpha=alpha, + N=100, + action_levels=9, + action_scale_low=0.8, + action_scale_high=1.2, + max_steps=20, # Short episodes so simulation goes fast + robust_points=1, # disable expensive adversarial lookaheads + render_mode=None, + ) + env = EconomicMetricsWrapper(env) + return FlattenObservation(env) + + +def main(): + print("Loading pre-trained Robust RL model...") + model_path = Path(__file__).parent.parent / "models" / "phantom_ppo.zip" + if not model_path.exists(): + print(f"Error: Model not found at {model_path}") + print("Please ensure you have a trained model before running this script.") + return + + rl_model = PPO.load(model_path) + + # The action space is Discrete(9). Index 4 is the middle (1.0 scale). + n_actions = 9 + mid_action = n_actions // 2 + + providers = { + "Static (Base)": lambda obs: mid_action, + "Random": RandomBaseline(n_actions), + "Heuristic Surge": SurgeBaseline( + n_actions, high_threshold=60.0, low_threshold=30.0 + ), + "Robust RL (PPO)": lambda obs: rl_model.predict(obs, deterministic=True)[0], + } + + config = BenchmarkConfig( + n_episodes=10, # Lower episodes to run faster + alpha_range=[0.0, 0.5, 1.0], # Fewer alpha levels + baseline_name="Static (Base)", + ) + + print(f"\nStarting benchmark across alpha levels: {config.alpha_range}") + print( + f"Testing {len(providers)} strategies for {config.n_episodes} episodes each...\n" + ) + + benchmark = ProviderBenchmark(env_factory, providers, config) + results = benchmark.run() + + # 1. Print tabular results + df = benchmark.to_dataframe() + summary = benchmark.summary_table() + print("\n--- Benchmark Summary Table ---") + print(summary) + + # 2. Save results to CSV for thesis inclusion + out_dir = Path(__file__).parent / "results" + out_dir.mkdir(exist_ok=True) + csv_path = out_dir / "provider_comparison.csv" + df.to_csv(csv_path, index=False) + print(f"\nSaved raw results to {csv_path}") + + # 3. Plot the degradation of COI / Revenue as alpha increases + plt.figure(figsize=(12, 5)) + + # Plot 1: Revenue vs Alpha + plt.subplot(1, 2, 1) + for name in providers.keys(): + provider_data = df[df["name"] == name] + plt.plot( + provider_data["alpha"], + provider_data["mean_revenue"], + marker="o", + label=name, + linewidth=2, + ) + plt.title("Revenue under Agent Contamination") + plt.xlabel("Contamination Level (α)") + plt.ylabel("Mean Episode Revenue ($)") + plt.grid(True, linestyle="--", alpha=0.7) + plt.legend() + + # Plot 2: COI Preservation vs Alpha + plt.subplot(1, 2, 2) + for name in providers.keys(): + provider_data = df[df["name"] == name] + plt.plot( + provider_data["alpha"], + provider_data["coi_preserved_pct"], + marker="s", + label=name, + linewidth=2, + ) + plt.title("Cost of Information (COI) Preservation") + plt.xlabel("Contamination Level (α)") + plt.ylabel("COI Preserved (%)") + plt.grid(True, linestyle="--", alpha=0.7) + plt.legend() + + plt.tight_layout() + plot_path = out_dir / "alpha_degradation_plot.png" + plt.savefig(plot_path, dpi=300) + print(f"Saved visualization to {plot_path}") + + +if __name__ == "__main__": + main() diff --git a/engine/sweeps/model_mix.yaml b/engine/sweeps/model_mix.yaml index 28a7f38..636eec2 100644 --- a/engine/sweeps/model_mix.yaml +++ b/engine/sweeps/model_mix.yaml @@ -1,6 +1,6 @@ method: random metric: - name: sweep/score + name: objective/score goal: maximize command: - ${env} diff --git a/engine/sweeps/models_only.yaml b/engine/sweeps/models_only.yaml index e0bd708..3e0ca9d 100644 --- a/engine/sweeps/models_only.yaml +++ b/engine/sweeps/models_only.yaml @@ -1,6 +1,6 @@ method: grid metric: - name: sweep/score + name: objective/score goal: maximize run_cap: 4 command: diff --git a/engine/sweeps/sac_tune.yaml b/engine/sweeps/sac_tune.yaml index 97558cf..faf9327 100644 --- a/engine/sweeps/sac_tune.yaml +++ b/engine/sweeps/sac_tune.yaml @@ -1,6 +1,6 @@ method: bayes metric: - name: sweep/score + name: objective/score goal: maximize command: - ${env} diff --git a/engine/sweeps/small_arch_compare.yaml b/engine/sweeps/small_arch_compare.yaml index 2eae9a0..aa1fd7b 100644 --- a/engine/sweeps/small_arch_compare.yaml +++ b/engine/sweeps/small_arch_compare.yaml @@ -1,6 +1,6 @@ method: random metric: - name: sweep/score + name: objective/score goal: maximize command: - ${env} diff --git a/engine/sweeps/tpu_jax.yaml b/engine/sweeps/tpu_jax.yaml index 6b4e001..2e5de08 100644 --- a/engine/sweeps/tpu_jax.yaml +++ b/engine/sweeps/tpu_jax.yaml @@ -1,6 +1,6 @@ method: bayes metric: - name: sweep/score + name: objective/score goal: maximize command: - ${env} diff --git a/engine/sweeps/tpu_pod.yaml b/engine/sweeps/tpu_pod.yaml index 35d8ded..d34dfb1 100644 --- a/engine/sweeps/tpu_pod.yaml +++ b/engine/sweeps/tpu_pod.yaml @@ -1,6 +1,6 @@ method: bayes metric: - name: sweep/score + name: objective/score goal: maximize command: - ${env} diff --git a/engine/telemetry/__init__.py b/engine/telemetry/__init__.py new file mode 100644 index 0000000..ee1bf93 --- /dev/null +++ b/engine/telemetry/__init__.py @@ -0,0 +1,23 @@ +from .metrics import canonicalize_metrics +from .wandb import ( + current_config, + finish_run, + get_wandb_module, + init_run, + log_metrics, + run_agent, + update_run_config, + update_summary, +) + +__all__ = [ + "canonicalize_metrics", + "current_config", + "finish_run", + "get_wandb_module", + "init_run", + "log_metrics", + "run_agent", + "update_run_config", + "update_summary", +] diff --git a/engine/telemetry/metrics.py b/engine/telemetry/metrics.py new file mode 100644 index 0000000..43ee0a4 --- /dev/null +++ b/engine/telemetry/metrics.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import Any, Mapping + +from ..spec import TrainSpec + + +_ALIASES = { + "train/reward": "train/reward_mean", + "train/revenue": "train/revenue_mean", + "train/dqn_loss": "train/loss", + "eval/reward": "eval/reward_mean", + "eval/revenue": "eval/revenue_mean", + "train/steps_per_second": "runtime/steps_per_second", +} + + +def _as_float(value: Any, default: float | None = None) -> float | None: + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, Any]: + metrics: dict[str, Any] = {} + for key, value in raw.items(): + canonical = _ALIASES.get(str(key), str(key)) + if canonical in metrics and canonical != key: + continue + metrics[canonical] = value + + metrics.setdefault("train/global_step", spec.runtime.total_timesteps) + + eval_reward = _as_float(metrics.get("eval/reward_mean"), 0.0) or 0.0 + eval_revenue = _as_float(metrics.get("eval/revenue_mean"), 0.0) or 0.0 + metrics["objective/score"] = eval_reward + spec.study.revenue_weight * eval_revenue + + margin_mean = _as_float(metrics.get("eval/margin_mean"), None) + if margin_mean is not None: + metrics["objective/constraint_margin"] = margin_mean - spec.env.margin_floor + + coi_level = _as_float(metrics.get("eval/coi_level_mean"), None) + metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level + + metrics["study/alpha"] = spec.study.alpha + metrics["study/lambda_coi"] = spec.study.lambda_coi + metrics["study/robust_radius"] = spec.study.robust_radius + metrics["study/info_value"] = spec.study.info_value + + metrics["runtime/backend"] = spec.runtime.backend + metrics["runtime/device"] = spec.runtime.device + metrics["runtime/seed"] = spec.runtime.seed + + return metrics diff --git a/engine/telemetry/wandb.py b/engine/telemetry/wandb.py new file mode 100644 index 0000000..5e6fb85 --- /dev/null +++ b/engine/telemetry/wandb.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any, Callable, Iterable, Mapping + + +def get_wandb_module(): + try: + import wandb + + return wandb + except ImportError: + return None + + +def _require_wandb(): + wandb = get_wandb_module() + if wandb is None: + raise ImportError("wandb is required for this workflow") + return wandb + + +def init_run( + *, + mode: str, + project: str | None = None, + config: Mapping[str, Any] | None = None, + name: str | None = None, + tags: Iterable[str] | None = None, + group: str | None = None, + sweep_mode: bool = False, +): + wandb = _require_wandb() + kwargs: dict[str, Any] = {"mode": mode} + if group: + kwargs["group"] = group + if sweep_mode: + run = wandb.init(**kwargs) + if name and run is not None: + run.name = name + return run + + init_kwargs = dict(kwargs) + init_kwargs["project"] = project + if config is not None: + init_kwargs["config"] = dict(config) + if name: + init_kwargs["name"] = name + if tags: + init_kwargs["tags"] = list(tags) + return wandb.init(**init_kwargs) + + +def finish_run() -> None: + wandb = get_wandb_module() + if wandb is not None and wandb.run is not None: + wandb.finish() + + +def current_config() -> dict[str, Any]: + wandb = get_wandb_module() + if wandb is None or wandb.run is None: + return {} + return {key: wandb.config[key] for key in wandb.config.keys()} + + +def update_run_config(config: Mapping[str, Any]) -> None: + wandb = get_wandb_module() + if wandb is None or wandb.run is None: + return + try: + wandb.config.update(dict(config), allow_val_change=True) + except TypeError: + wandb.config.update(dict(config)) + + +def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None: + wandb = get_wandb_module() + if wandb is None or wandb.run is None: + return + wandb.log(dict(metrics), step=step) + + +def update_summary(metrics: Mapping[str, Any]) -> None: + wandb = get_wandb_module() + if wandb is None or wandb.run is None: + return + for key, value in metrics.items(): + wandb.run.summary[key] = value + + +def run_agent( + sweep_id: str, + fn: Callable[[], None], + *, + count: int | None = None, +) -> None: + wandb = _require_wandb() + wandb.agent(sweep_id, function=fn, count=count) diff --git a/engine/train.py b/engine/train.py index 4d52a50..90ac991 100644 --- a/engine/train.py +++ b/engine/train.py @@ -1,98 +1,10 @@ from __future__ import annotations import argparse -import json -import os -from pathlib import Path -from typing import TYPE_CHECKING -import numpy as np +from typing import Any -if TYPE_CHECKING: - from .lib.discrete import EventQTable - -from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint - -try: - import wandb as _wandb - - if hasattr(_wandb, "init") and callable(_wandb.init): - wandb = _wandb - HAS_WANDB = True - else: - wandb = None - HAS_WANDB = False -except ImportError: - wandb = None - HAS_WANDB = False - -try: - from stable_baselines3 import PPO, A2C, DQN - from stable_baselines3.common.callbacks import EvalCallback - from stable_baselines3.common.monitor import Monitor - - HAS_SB3 = True -except ImportError: - HAS_SB3 = False - -from .jax import JAX_AVAILABLE - - -DEFAULT_CFG = { - "project": "phantom-pricing", - "algo": "ppo", - "seed": 42, - "total_timesteps": 50_000, - "eval_episodes": 5, - "eval_freq": 1_000, - "log_freq": 100, - "revenue_weight": 0.01, - "n_products": 10, - "N": 100, - "alpha": 0.3, - "lambda_coi": 0.2, - "robust_radius": 0.15, - "robust_points": 5, - "no_robust": False, - "info_value": 1.0, - "price_low": 10.0, - "price_high": 150.0, - "action_levels": 9, - "action_scale_low": 0.8, - "action_scale_high": 1.2, - "learning_rate": 3e-4, - "gamma": 0.99, - "buffer_size": 50_000, - "batch_size": 256, - "tau": 0.005, - "train_freq": 1, - "learning_starts": 1_000, - "target_update_interval": 1_000, - "exploration_fraction": 0.2, - "exploration_final_eps": 0.05, - "n_steps": 2_048, - "n_epochs": 10, - "gae_lambda": 0.95, - "clip_range": 0.2, - "ent_coef": 0.0, - "q_lr": 0.1, - "eps_start": 1.0, - "eps_end": 0.05, - "eps_decay": 0.9995, - "model_dir": "engine/models", - "arch": "small", - "activation": "relu", - "q_bins": 6, - "max_steps": 100, - "margin_floor": 0.05, - "margin_floor_patience": 5, - "use_jax": False, - "jax_num_envs": 16, - "jax_num_steps": 128, - "jax_num_minibatches": 4, - "jax_update_epochs": 4, - "jax_anneal_lr": True, - "checkpoint_interval": 200_000, -} +from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once +from .spec import TrainSpec def _truthy(value: str | bool | None) -> bool: @@ -103,423 +15,133 @@ def _truthy(value: str | bool | None) -> bool: return str(value).strip().lower() in {"1", "true", "yes", "on"} -def _cfg(raw: dict | None = None) -> dict: - cfg = dict(DEFAULT_CFG) - if raw: - cfg.update({k: v for k, v in raw.items() if v is not None}) - cfg["algo"] = str(cfg["algo"]).lower() - cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy( - os.environ.get("PHANTOM_USE_JAX") +def _parse_tags(raw: str | None) -> list[str]: + if raw is None: + return [] + return [piece.strip() for piece in str(raw).split(",") if piece.strip()] + + +def _probe_run_kind(argv: list[str]) -> str: + probe = argparse.ArgumentParser(add_help=False) + probe.add_argument("--run-kind", choices=["train", "benchmark"]) + probe.add_argument("--run-mode", choices=["train", "benchmark"]) + args, _ = probe.parse_known_args(argv) + return str(args.run_kind or args.run_mode or "train") + + +def _strip_run_kind(argv: list[str]) -> list[str]: + stripped: list[str] = [] + skip_next = False + for item in argv: + if skip_next: + skip_next = False + continue + if item in {"--run-kind", "--run-mode"}: + skip_next = True + continue + if item.startswith("--run-kind=") or item.startswith("--run-mode="): + continue + stripped.append(item) + return stripped + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="PHANTOM unified training entrypoint") + parser.add_argument("--run-kind", choices=["train", "benchmark"], default="train") + parser.add_argument("--run-mode", choices=["train", "benchmark"]) + + parser.add_argument("--project", default="capstone") + parser.add_argument("--scenario", default="default") + parser.add_argument("--group", type=str) + parser.add_argument("--tags", type=str) + + parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto") + parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"]) + parser.add_argument("--seed", type=int) + parser.add_argument("--total-timesteps", type=int) + parser.add_argument("--model-dir", type=str) + parser.add_argument("--log-freq", type=int) + parser.add_argument("--checkpoint-interval", type=int) + parser.add_argument("--device", type=str) + + parser.add_argument("--alpha", type=float) + parser.add_argument("--N", type=int) + parser.add_argument("--n-products", type=int) + parser.add_argument("--lambda-coi", type=float) + parser.add_argument("--info-value", type=float) + parser.add_argument("--robust-radius", type=float) + parser.add_argument("--robust-points", type=int) + parser.add_argument("--no-robust", action="store_true") + parser.add_argument("--revenue-weight", type=float) + + parser.add_argument("--price-low", type=float) + parser.add_argument("--price-high", type=float) + parser.add_argument("--action-levels", type=int) + parser.add_argument("--action-scale-low", type=float) + parser.add_argument("--action-scale-high", type=float) + parser.add_argument("--max-steps", type=int) + parser.add_argument("--margin-floor", type=float) + parser.add_argument("--margin-floor-patience", type=int) + + parser.add_argument("--learning-rate", type=float) + parser.add_argument("--gamma", type=float) + parser.add_argument("--buffer-size", type=int) + parser.add_argument("--batch-size", type=int) + parser.add_argument("--tau", type=float) + parser.add_argument("--train-freq", type=int) + parser.add_argument("--learning-starts", type=int) + parser.add_argument("--target-update-interval", type=int) + parser.add_argument("--exploration-fraction", type=float) + parser.add_argument("--exploration-final-eps", type=float) + parser.add_argument("--n-steps", type=int) + parser.add_argument("--n-epochs", type=int) + parser.add_argument("--gae-lambda", type=float) + parser.add_argument("--clip-range", type=float) + parser.add_argument("--ent-coef", type=float) + parser.add_argument("--q-lr", type=float) + parser.add_argument("--q-bins", type=int) + parser.add_argument("--eps-start", type=float) + parser.add_argument("--eps-end", type=float) + parser.add_argument("--eps-decay", type=float) + parser.add_argument("--arch", type=str) + parser.add_argument("--activation", type=str) + parser.add_argument("--vf-coef", type=float) + parser.add_argument("--max-grad-norm", type=float) + + parser.add_argument("--eval-freq", type=int) + parser.add_argument("--eval-episodes", type=int) + + parser.add_argument("--jax", action="store_true") + parser.add_argument("--jax-num-envs", type=int) + parser.add_argument("--jax-num-steps", type=int) + parser.add_argument("--jax-num-minibatches", type=int) + parser.add_argument("--jax-update-epochs", type=int) + parser.add_argument("--jax-anneal-lr", type=str) + + parser.add_argument("--sweep-agent", action="store_true") + parser.add_argument("--sweep-id", type=str) + parser.add_argument("--count", type=int, default=0) + parser.add_argument("--offline", action="store_true") + parser.add_argument("--no-wandb", action="store_true") + return parser + + +def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]: + jax_anneal_lr = ( + _truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None ) - cfg["no_robust"] = _truthy(cfg.get("no_robust")) - if cfg["no_robust"]: - cfg["lambda_coi"] = 0.0 - cfg["robust_radius"] = 0.0 - cfg["robust_points"] = 1 - return cfg - - -def _wandb_cfg_dict() -> dict: - return ( - {k: wandb.config[k] for k in wandb.config.keys()} - if HAS_WANDB and wandb.run - else {} - ) - - -def make_env(cfg: dict): - from gymnasium.wrappers import FlattenObservation - - from .wrapper import PHANTOM - from .lib.wrappers import EconomicMetricsWrapper - - env = PHANTOM( - n_products=int(cfg["n_products"]), - alpha=float(cfg["alpha"]), - N=int(cfg["N"]), - price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])), - lambda_coi=float(cfg["lambda_coi"]), - robust_radius=float(cfg["robust_radius"]), - robust_points=int(cfg["robust_points"]), - info_value=float(cfg["info_value"]), - action_levels=int(cfg["action_levels"]), - action_scale_low=float(cfg["action_scale_low"]), - action_scale_high=float(cfg["action_scale_high"]), - max_steps=int(cfg.get("max_steps", 100)), - margin_floor=float(cfg.get("margin_floor", 0.05)), - margin_floor_patience=int(cfg.get("margin_floor_patience", 5)), - render_mode=None, - ) - env = EconomicMetricsWrapper(env) - env = FlattenObservation(env) - return env - - -def _net_arch(name) -> list[int]: - presets = { - "tiny": [32, 32], - "small": [64, 64], - "medium": [128, 128], - "large": [256, 256], - } - if isinstance(name, (list, tuple)): - return [int(v) for v in name] - s = str(name).lower().strip() - if s in presets: - return presets[s] - if "x" in s: - try: - vals = [int(v) for v in s.split("x") if v] - return vals if vals else presets["small"] - except ValueError: - return presets["small"] - return presets["small"] - - -def _activation(name): - try: - import torch.nn as nn - except ImportError: - return None - return { - "relu": nn.ReLU, - "tanh": nn.Tanh, - "elu": nn.ELU, - "leaky_relu": nn.LeakyReLU, - }.get(str(name).lower().strip(), nn.ReLU) - - -def _policy_kwargs(cfg: dict) -> dict: - kw = {"net_arch": _net_arch(cfg.get("arch", "small"))} - act = _activation(cfg.get("activation", "relu")) - if act is not None: - kw["activation_fn"] = act - return kw - - -def _action(agent, obs, deterministic: bool = True): - out = agent.predict(obs, deterministic=deterministic) - a = out[0] if isinstance(out, tuple) else out - if isinstance(a, np.ndarray) and a.size == 1: - return int(a.reshape(-1)[0]) - return a - - -def evaluate(agent, env, episodes: int) -> dict: - rewards, revenues = [], [] - for _ in range(int(episodes)): - obs, _ = env.reset() - done, ep_r, ep_rev = False, 0.0, 0.0 - while not done: - obs, reward, term, trunc, info = env.step(_action(agent, obs, True)) - done = term or trunc - ep_r += float(reward) - ep_rev += float( - info.get("economics", {}).get("revenue", info.get("revenue", 0.0)) - ) - rewards.append(ep_r) - revenues.append(ep_rev) - return { - "eval/reward": float(np.mean(rewards)), - "eval/revenue": float(np.mean(revenues)), - "eval/reward_std": float(np.std(rewards)), - "eval/revenue_std": float(np.std(revenues)), - } - - -def build_model(cfg: dict, env): - algo = cfg["algo"] - policy_kwargs = _policy_kwargs(cfg) - if algo == "sac": - raise ValueError("sac is not supported with the discrete core env") - if algo == "ppo": - return PPO( - "MlpPolicy", - env, - verbose=1, - policy_kwargs=policy_kwargs, - seed=int(cfg["seed"]), - learning_rate=float(cfg["learning_rate"]), - n_steps=int(cfg["n_steps"]), - batch_size=int(cfg["batch_size"]), - n_epochs=int(cfg["n_epochs"]), - gamma=float(cfg["gamma"]), - gae_lambda=float(cfg["gae_lambda"]), - clip_range=float(cfg["clip_range"]), - ent_coef=float(cfg["ent_coef"]), - ) - if algo == "a2c": - return A2C( - "MlpPolicy", - env, - verbose=1, - policy_kwargs=policy_kwargs, - seed=int(cfg["seed"]), - learning_rate=float(cfg["learning_rate"]), - n_steps=max(5, int(cfg["n_steps"]) // 32), - gamma=float(cfg["gamma"]), - gae_lambda=float(cfg["gae_lambda"]), - ent_coef=float(cfg["ent_coef"]), - ) - if algo == "dqn": - return DQN( - "MlpPolicy", - env, - verbose=1, - policy_kwargs=policy_kwargs, - seed=int(cfg["seed"]), - learning_rate=float(cfg["learning_rate"]), - buffer_size=int(cfg["buffer_size"]), - batch_size=int(cfg["batch_size"]), - gamma=float(cfg["gamma"]), - train_freq=int(cfg["train_freq"]), - learning_starts=int(cfg["learning_starts"]), - target_update_interval=int(cfg["target_update_interval"]), - exploration_fraction=float(cfg["exploration_fraction"]), - exploration_final_eps=float(cfg["exploration_final_eps"]), - ) - raise ValueError(f"unsupported algo '{algo}'") - - -def _sb3_model_cls(algo: str): - if algo == "ppo": - return PPO - if algo == "a2c": - return A2C - if algo == "dqn": - return DQN - raise ValueError(f"unsupported algo '{algo}'") - - -def train_qtable(cfg: dict) -> tuple["EventQTable", dict]: - from .lib.discrete import EventQTable - - np.random.seed(int(cfg["seed"])) - env = make_env(cfg) - eval_env = make_env(cfg) - agent = EventQTable( - env.action_space.n, - int(cfg["n_products"]), - (float(cfg["price_low"]), float(cfg["price_high"])), - lr=float(cfg["q_lr"]), - gamma=float(cfg["gamma"]), - n_bins=int(cfg["q_bins"]), - ) - eps = float(cfg["eps_start"]) - obs, _ = env.reset(seed=int(cfg["seed"])) - for t in range(int(cfg["total_timesteps"])): - a, s = agent.act(obs, eps) - nxt, reward, term, trunc, info = env.step(a) - done = term or trunc - agent.update(s, a, float(reward), agent.encode(nxt), done) - eps = max(float(cfg["eps_end"]), eps * float(cfg["eps_decay"])) - if HAS_WANDB and wandb.run and (t + 1) % int(cfg["log_freq"]) == 0: - econ = info.get("economics", {}) - wandb.log( - { - "train/reward": float(reward), - "train/revenue": float(econ.get("revenue", 0.0)), - "train/epsilon": float(eps), - }, - step=t + 1, - ) - obs = env.reset()[0] if done else nxt - metrics = evaluate(agent, eval_env, int(cfg["eval_episodes"])) - metrics["train/global_step"] = int(cfg["total_timesteps"]) - env.close() - eval_env.close() - return agent, metrics - - -def train_sb3(cfg: dict) -> tuple[object, dict]: - if not HAS_SB3: - raise ImportError("stable-baselines3 is required for SB3 models") - from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback - - env = make_env(cfg) - eval_env = make_env(cfg) - env = Monitor(env) - eval_env = Monitor(eval_env) - model = build_model(cfg, env) - resume_step = 0 - if HAS_WANDB and wandb.run is not None: - sweep_id = getattr(wandb.run, "sweep_id", None) - artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id) - checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip" - restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file) - if restored is not None: - checkpoint_path, metadata = restored - model = _sb3_model_cls(cfg["algo"]).load( - checkpoint_path.as_posix(), env=env - ) - resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0))) - model.num_timesteps = max( - int(getattr(model, "num_timesteps", 0)), resume_step - ) - - cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))] - cbs.append( - CheckpointArtifactCallback( - cfg, - interval=int(cfg.get("checkpoint_interval", 10_000)), - ) - ) - cbs.append( - EvalCallback( - eval_env, - eval_freq=int(cfg["eval_freq"]), - n_eval_episodes=int(cfg["eval_episodes"]), - deterministic=True, - verbose=0, - ) - ) - target_steps = int(cfg["total_timesteps"]) - remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0))) - if remaining_steps > 0: - model.learn( - total_timesteps=remaining_steps, - callback=cbs, - reset_num_timesteps=False, - ) - - model_path = Path(cfg["model_dir"]) - model_path.mkdir(parents=True, exist_ok=True) - model.save(str(model_path / f"phantom_{cfg['algo']}")) - metrics = evaluate(model, eval_env, int(cfg["eval_episodes"])) - metrics["train/global_step"] = int(model.num_timesteps) - env.close() - eval_env.close() - return model, metrics - - -def train_once(cfg: dict) -> dict: - algo = cfg["algo"] - if cfg.get("use_jax"): - if not JAX_AVAILABLE: - raise ImportError( - "JAX backend requested but JAX is not installed. " - "Install engine/jax/requirements.txt and jax[tpu] for TPU runs." - ) - try: - from .jax.train import train_jax - except Exception as exc: # pragma: no cover - raise ImportError(f"Failed to import JAX trainer: {exc}") from exc - _, metrics = train_jax(cfg) - elif algo == "qtable": - _, metrics = train_qtable(cfg) - else: - _, metrics = train_sb3(cfg) - metrics["sweep/score"] = float( - metrics["eval/reward"] + float(cfg["revenue_weight"]) * metrics["eval/revenue"] - ) - return metrics - - -def run_wandb( - project: str, overrides: dict, mode: str = "online", sweep_mode: bool = False -) -> dict: - if not HAS_WANDB: - raise ImportError("wandb is required for sweep runs") - if not sweep_mode: - pre_cfg = _cfg(overrides) - if pre_cfg.get("use_jax"): - try: - import jax - - if jax.process_count() > 1 and jax.process_index() != 0: - return train_once(pre_cfg) - except Exception: - pass - init_kwargs = {"mode": mode} - if sweep_mode: - run = wandb.init(**init_kwargs) - else: - run = wandb.init(project=project, config=overrides, **init_kwargs) - - try: - cfg = _cfg(_wandb_cfg_dict()) - if sweep_mode: - for k, v in overrides.items(): - if k not in wandb.config: - cfg[k] = v - - metrics = train_once(cfg) - step = int(metrics.get("train/global_step", cfg["total_timesteps"])) - wandb.log(metrics, step=step) - for k, v in metrics.items(): - run.summary[k] = v - return metrics - finally: - if wandb.run is not None: - wandb.finish() - - -def run_local(overrides: dict) -> dict: - cfg = _cfg(overrides) - metrics = train_once(cfg) - should_print = True - if cfg.get("use_jax"): - try: - import jax - - should_print = jax.process_index() == 0 - except Exception: - should_print = True - if should_print: - print(json.dumps(metrics, indent=2)) - # sentinel line for machine-readable extraction; must stay on one line - print("PHANTOM_METRICS:" + json.dumps(metrics)) - return metrics - - -def main(): - p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps") - p.add_argument("--project", default=DEFAULT_CFG["project"]) - p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"]) - p.add_argument("--seed", type=int) - p.add_argument("--total-timesteps", type=int) - p.add_argument("--alpha", type=float) - p.add_argument("--N", type=int) - p.add_argument("--n-products", type=int) - p.add_argument("--lambda-coi", type=float) - p.add_argument("--info-value", type=float) - p.add_argument("--robust-radius", type=float) - p.add_argument("--robust-points", type=int) - p.add_argument("--no-robust", action="store_true") - p.add_argument("--learning-rate", type=float) - p.add_argument("--gamma", type=float) - p.add_argument("--gae-lambda", type=float) - p.add_argument("--clip-range", type=float) - p.add_argument("--ent-coef", type=float) - p.add_argument("--revenue-weight", type=float) - p.add_argument("--price-low", type=float) - p.add_argument("--price-high", type=float) - p.add_argument("--action-levels", type=int) - p.add_argument("--action-scale-low", type=float) - p.add_argument("--action-scale-high", type=float) - p.add_argument("--max-steps", type=int) - p.add_argument("--margin-floor", type=float) - p.add_argument("--margin-floor-patience", type=int) - p.add_argument("--arch", type=str) - p.add_argument("--activation", type=str) - p.add_argument("--jax", action="store_true") - p.add_argument("--jax-num-envs", type=int) - p.add_argument("--jax-num-steps", type=int) - p.add_argument("--jax-num-minibatches", type=int) - p.add_argument("--jax-update-epochs", type=int) - p.add_argument("--jax-anneal-lr", type=str) - p.add_argument("--checkpoint-interval", type=int) - p.add_argument("--sweep-agent", action="store_true") - p.add_argument("--sweep-id", type=str) - p.add_argument("--count", type=int, default=0) - p.add_argument("--offline", action="store_true") - p.add_argument("--no-wandb", action="store_true") - args = p.parse_args() + backend = None if args.backend == "auto" else args.backend overrides = { + "project": args.project, + "backend": backend, "algo": args.algo, "seed": args.seed, "total_timesteps": args.total_timesteps, + "model_dir": args.model_dir, + "log_freq": args.log_freq, + "checkpoint_interval": args.checkpoint_interval, + "device": args.device, "alpha": args.alpha, "N": args.N, "n_products": args.n_products, @@ -528,11 +150,6 @@ def main(): "robust_radius": args.robust_radius, "robust_points": args.robust_points, "no_robust": args.no_robust, - "learning_rate": args.learning_rate, - "gamma": args.gamma, - "gae_lambda": args.gae_lambda, - "clip_range": args.clip_range, - "ent_coef": args.ent_coef, "revenue_weight": args.revenue_weight, "price_low": args.price_low, "price_high": args.price_high, @@ -542,40 +159,87 @@ def main(): "max_steps": args.max_steps, "margin_floor": args.margin_floor, "margin_floor_patience": args.margin_floor_patience, + "learning_rate": args.learning_rate, + "gamma": args.gamma, + "buffer_size": args.buffer_size, + "batch_size": args.batch_size, + "tau": args.tau, + "train_freq": args.train_freq, + "learning_starts": args.learning_starts, + "target_update_interval": args.target_update_interval, + "exploration_fraction": args.exploration_fraction, + "exploration_final_eps": args.exploration_final_eps, + "n_steps": args.n_steps, + "n_epochs": args.n_epochs, + "gae_lambda": args.gae_lambda, + "clip_range": args.clip_range, + "ent_coef": args.ent_coef, + "q_lr": args.q_lr, + "q_bins": args.q_bins, + "eps_start": args.eps_start, + "eps_end": args.eps_end, + "eps_decay": args.eps_decay, "arch": args.arch, "activation": args.activation, - "use_jax": args.jax, + "vf_coef": args.vf_coef, + "max_grad_norm": args.max_grad_norm, + "eval_freq": args.eval_freq, + "eval_episodes": args.eval_episodes, + "use_jax": args.jax or None, "jax_num_envs": args.jax_num_envs, "jax_num_steps": args.jax_num_steps, "jax_num_minibatches": args.jax_num_minibatches, "jax_update_epochs": args.jax_update_epochs, - "checkpoint_interval": args.checkpoint_interval, - "jax_anneal_lr": _truthy(args.jax_anneal_lr) - if args.jax_anneal_lr is not None - else None, + "jax_anneal_lr": jax_anneal_lr, } - overrides = {k: v for k, v in overrides.items() if v is not None} + return {key: value for key, value in overrides.items() if value is not None} + + +def main(argv: list[str] | None = None) -> None: + import sys + + raw_args = list(sys.argv[1:] if argv is None else argv) + run_kind = _probe_run_kind(raw_args) + if run_kind == "benchmark": + run_benchmark_cli(_strip_run_kind(raw_args)) + return + + parser = _build_parser() + args, unknown = parser.parse_known_args(raw_args) + if unknown: + raise ValueError(f"Unknown arguments for training mode: {' '.join(unknown)}") + + overrides = _overrides_from_args(args) + scenario = str(args.scenario) + group = args.group + extra_tags = tuple(_parse_tags(args.tags)) if args.sweep_agent: - if args.no_wandb: - raise ValueError("sweep agent requires wandb") - if not args.sweep_id: - raise ValueError("--sweep-id is required with --sweep-agent") - mode = "offline" if args.offline else "online" - wandb.agent( - args.sweep_id, - function=lambda: run_wandb( - args.project, overrides, mode=mode, sweep_mode=True - ), - count=args.count if args.count > 0 else None, + run_sweep_agent( + project=args.project, + sweep_id=str(args.sweep_id or ""), + count=int(args.count), + offline=bool(args.offline), + no_wandb=bool(args.no_wandb), + base_overrides=overrides, + kind="sweep", + scenario=scenario, + group=group, + extra_tags=extra_tags, ) return - if args.no_wandb or not HAS_WANDB: - run_local(overrides) - return - - run_wandb(args.project, overrides, mode="offline" if args.offline else "online") + spec = TrainSpec.from_flat(overrides) + run_train_once( + spec, + project=args.project, + offline=bool(args.offline), + no_wandb=bool(args.no_wandb), + kind="train", + scenario=scenario, + group=group, + extra_tags=extra_tags, + ) if __name__ == "__main__": diff --git a/engine/train_core.py b/engine/train_core.py new file mode 100644 index 0000000..8b29f45 --- /dev/null +++ b/engine/train_core.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from .spec import TrainSpec +from .telemetry.metrics import canonicalize_metrics + + +@dataclass(frozen=True) +class TrainResult: + spec: TrainSpec + metrics: dict[str, Any] + artifacts: dict[str, str] + + +def run_train(spec: TrainSpec) -> TrainResult: + cfg = spec.to_flat_dict() + algo = spec.algorithm.name + + if spec.runtime.use_jax or spec.runtime.backend == "jax": + from .backends.jax import train_jax_backend + + _, raw_metrics = train_jax_backend(cfg) + elif algo == "qtable": + from .backends.qtable import train_qtable + + _, raw_metrics = train_qtable(cfg) + else: + from .backends.sb3 import train_sb3 + + _, raw_metrics = train_sb3(cfg) + + metrics = canonicalize_metrics(raw_metrics, spec) + artifacts: dict[str, str] = {} + model_path = raw_metrics.get("model/path") + if isinstance(model_path, str): + artifacts["model/path"] = model_path + + return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts) diff --git a/engine/wrapper.py b/engine/wrapper.py index 751c104..dc97cb1 100644 --- a/engine/wrapper.py +++ b/engine/wrapper.py @@ -381,7 +381,7 @@ if __name__ == "__main__": def predict(self, obs, **kwargs): return self.env.action_space.sample(), None - wandb.init(project="phantom-pricing", config={"policy": "random", "alpha": 0.3}) + wandb.init(project="capstone", config={"policy": "random", "alpha": 0.3}) env = EconomicMetricsWrapper(PHANTOM(n_products=15, alpha=0.3, render_mode=None)) model = RandomPolicy(env) diff --git a/nx.json b/nx.json index a931e25..d286a8f 100644 --- a/nx.json +++ b/nx.json @@ -55,6 +55,9 @@ "train": { "cache": false }, + "benchmark": { + "cache": false + }, "up": { "cache": false }, diff --git a/package.json b/package.json index 6f5d85a..1ee68c6 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,7 @@ "platform:down": "nx run platform:down", "platform:logs": "nx run platform:logs", "research:test": "nx run research:test", + "research:benchmark": "nx run research:benchmark", "e2e:test": "nx run e2e:test" }, "devDependencies": { diff --git a/scripts/nx_research.sh b/scripts/nx_research.sh index 846bd60..5e72e3f 100644 --- a/scripts/nx_research.sh +++ b/scripts/nx_research.sh @@ -30,10 +30,20 @@ case "$cmd" in load_sweep_env require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" WANDB_ENTITY="${WANDB_ENTITY:-}" \ - WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \ + WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \ WANDB_API_KEY="$WANDB_API_KEY" \ .venv/bin/python -m engine.train ${LOCAL_TRAIN_ARGS:---algo ppo --total-timesteps 50000} ;; + benchmark) + load_sweep_env + if [[ " ${LOCAL_BENCHMARK_ARGS:-} " != *" --no-wandb "* ]]; then + require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" + fi + WANDB_ENTITY="${WANDB_ENTITY:-}" \ + WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \ + WANDB_API_KEY="${WANDB_API_KEY:-}" \ + .venv/bin/python -m engine.train --run-kind benchmark ${LOCAL_BENCHMARK_ARGS:---tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu} + ;; train-agent) load_sweep_env require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" @@ -43,10 +53,23 @@ case "$cmd" in args+=(--count "$AGENT_COUNT") fi WANDB_ENTITY="${WANDB_ENTITY:-}" \ - WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \ + WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \ WANDB_API_KEY="$WANDB_API_KEY" \ .venv/bin/python -m engine.train "${args[@]}" ;; + benchmark-agent) + load_sweep_env + require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" + require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" + args=(--sweep-agent --sweep-id "$SWEEP_ID") + if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then + args+=(--count "$AGENT_COUNT") + fi + WANDB_ENTITY="${WANDB_ENTITY:-}" \ + WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \ + WANDB_API_KEY="$WANDB_API_KEY" \ + .venv/bin/python -m engine.train --run-kind benchmark "${args[@]}" ${BENCHMARK_AGENT_ARGS:-} + ;; train-bootstrap) load_sweep_env require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" @@ -55,7 +78,7 @@ case "$cmd" in require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" WANDB_API_KEY="$WANDB_API_KEY" \ WANDB_ENTITY="${WANDB_ENTITY:-}" \ - WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \ + WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \ GITHUB_TOKEN="$GITHUB_TOKEN" \ REPO_URL="$REPO_URL" \ BRANCH="${BRANCH:-main}" \ @@ -115,7 +138,7 @@ PY train-tpu-vm-sweep) load_sweep_env require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong" - require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123" + require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123" require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" args=( --sweep-id "$SWEEP_ID" diff --git a/scripts/tpu_vm_sweep_agent.py b/scripts/tpu_vm_sweep_agent.py index 83c16aa..b051b86 100644 --- a/scripts/tpu_vm_sweep_agent.py +++ b/scripts/tpu_vm_sweep_agent.py @@ -96,7 +96,11 @@ def _extract_metrics(output: str) -> dict: obj = json.loads(block) except Exception: continue - if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj): + if isinstance(obj, dict) and ( + "objective/score" in obj + or "eval/reward_mean" in obj + or "sweep/score" in obj + ): return obj return {}