refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -3,7 +3,7 @@
# Required for wandb runs and sweep agent workers. # Required for wandb runs and sweep agent workers.
WANDB_API_KEY= WANDB_API_KEY=
WANDB_ENTITY= WANDB_ENTITY=
WANDB_PROJECT=phantom-pricing WANDB_PROJECT=capstone
# Required for private repo bootstrap workers. # Required for private repo bootstrap workers.
GITHUB_TOKEN= GITHUB_TOKEN=
@@ -16,3 +16,7 @@ GITHUB_TOKEN=
# AGENT_COUNT=0 # AGENT_COUNT=0
# AGENT_LOOP=1 # AGENT_LOOP=1
# RETRY_SECONDS=20 # RETRY_SECONDS=20
# Optional local benchmark defaults.
# LOCAL_BENCHMARK_ARGS=--tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu
# BENCHMARK_AGENT_ARGS=--tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3,0.6 --episodes 5

1
.gitignore vendored
View File

@@ -68,6 +68,7 @@ sim/case/thesis_simplified/runs*/
# model binaries # model binaries
engine/models/*.zip engine/models/*.zip
engine/studies/results/*
*.zip *.zip
# wandb local state # wandb local state

View File

@@ -13,9 +13,11 @@ NX := npx nx
SWEEP_ENV_FILE ?= .env.sweep SWEEP_ENV_FILE ?= .env.sweep
WANDB_ENTITY ?= WANDB_ENTITY ?=
WANDB_PROJECT ?= phantom-pricing WANDB_PROJECT ?= capstone
SWEEP_ID ?= SWEEP_ID ?=
LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000 LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000
LOCAL_BENCHMARK_ARGS ?= --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu
BENCHMARK_AGENT_ARGS ?=
AGENT_COUNT ?= 0 AGENT_COUNT ?= 0
REPO_URL ?= REPO_URL ?=
@@ -36,7 +38,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
.PHONY: help .PHONY: help
help: help:
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines" @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish" @echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
@echo "" @echo ""
@echo "Build general public version:" @echo "Build general public version:"
@@ -45,6 +47,9 @@ help:
@echo "Local wandb run:" @echo "Local wandb run:"
@echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'" @echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'"
@echo "" @echo ""
@echo "Local benchmark run:"
@echo " make benchmark LOCAL_BENCHMARK_ARGS='--tiers static,surge,linear --alpha-values 0.0,0.3 --episodes 3 --no-wandb'"
@echo ""
@echo "Local sweep agent from this repo:" @echo "Local sweep agent from this repo:"
@echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5" @echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5"
@echo "" @echo ""
@@ -104,6 +109,14 @@ install:
train: train:
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train @WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train
.PHONY: benchmark
benchmark:
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_BENCHMARK_ARGS="$(LOCAL_BENCHMARK_ARGS)" $(NX) run research:benchmark
.PHONY: benchmark.agent
benchmark.agent:
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" BENCHMARK_AGENT_ARGS="$(BENCHMARK_AGENT_ARGS)" $(NX) run research:benchmark-agent
.PHONY: train.agent .PHONY: train.agent
train.agent: train.agent:
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-agent @WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-agent

View File

@@ -0,0 +1 @@
__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"]

81
engine/backends/common.py Normal file
View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from typing import Any, Mapping
import numpy as np
def make_env(cfg: Mapping[str, Any]):
from gymnasium.wrappers import FlattenObservation
from ..lib.wrappers import EconomicMetricsWrapper
from ..wrapper import PHANTOM
env = PHANTOM(
n_products=int(cfg["n_products"]),
alpha=float(cfg["alpha"]),
N=int(cfg["N"]),
price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])),
lambda_coi=float(cfg["lambda_coi"]),
robust_radius=float(cfg["robust_radius"]),
robust_points=int(cfg["robust_points"]),
info_value=float(cfg["info_value"]),
action_levels=int(cfg["action_levels"]),
action_scale_low=float(cfg["action_scale_low"]),
action_scale_high=float(cfg["action_scale_high"]),
max_steps=int(cfg.get("max_steps", 100)),
margin_floor=float(cfg.get("margin_floor", 0.05)),
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
render_mode=None,
)
env = EconomicMetricsWrapper(env)
return FlattenObservation(env)
def _action(agent: Any, obs: Any, deterministic: bool = True):
out = agent.predict(obs, deterministic=deterministic)
action = out[0] if isinstance(out, tuple) else out
if isinstance(action, np.ndarray) and action.size == 1:
return int(action.reshape(-1)[0])
return action
def evaluate(agent: Any, env: Any, episodes: int) -> dict[str, float]:
rewards: list[float] = []
revenues: list[float] = []
margins: list[float] = []
coi_levels: list[float] = []
for _ in range(int(episodes)):
obs, _ = env.reset()
done = False
ep_reward = 0.0
ep_revenue = 0.0
ep_margin = 0.0
ep_coi = 0.0
steps = 0
while not done:
obs, reward, term, trunc, info = env.step(_action(agent, obs, True))
done = bool(term or trunc)
econ = info.get("economics", {})
ep_reward += float(reward)
ep_revenue += float(econ.get("revenue", info.get("revenue", 0.0)))
ep_margin += float(econ.get("margin", 0.0))
ep_coi += float(econ.get("coi_level", 0.0))
steps += 1
rewards.append(ep_reward)
revenues.append(ep_revenue)
denom = max(steps, 1)
margins.append(ep_margin / denom)
coi_levels.append(ep_coi / denom)
return {
"eval/reward_mean": float(np.mean(rewards)) if rewards else 0.0,
"eval/reward_std": float(np.std(rewards)) if rewards else 0.0,
"eval/revenue_mean": float(np.mean(revenues)) if revenues else 0.0,
"eval/revenue_std": float(np.std(revenues)) if revenues else 0.0,
"eval/margin_mean": float(np.mean(margins)) if margins else 0.0,
"eval/coi_level_mean": float(np.mean(coi_levels)) if coi_levels else 0.0,
}

18
engine/backends/jax.py Normal file
View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any, Mapping
from ..jax import JAX_AVAILABLE
def train_jax_backend(
cfg: Mapping[str, Any],
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
from ..jax.train import train_jax
return train_jax(dict(cfg))

53
engine/backends/qtable.py Normal file
View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Any, Mapping
import numpy as np
from .common import evaluate, make_env
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]:
from ..lib.discrete import EventQTable
np.random.seed(int(cfg["seed"]))
env = make_env(cfg)
eval_env = make_env(cfg)
agent = EventQTable(
env.action_space.n,
int(cfg["n_products"]),
(float(cfg["price_low"]), float(cfg["price_high"])),
lr=float(cfg["q_lr"]),
gamma=float(cfg["gamma"]),
n_bins=int(cfg["q_bins"]),
)
total_reward = 0.0
total_revenue = 0.0
steps = 0
epsilon = float(cfg["eps_start"])
obs, _ = env.reset(seed=int(cfg["seed"]))
for _ in range(int(cfg["total_timesteps"])):
action, state = agent.act(obs, epsilon)
nxt, reward, term, trunc, info = env.step(action)
done = bool(term or trunc)
agent.update(state, action, float(reward), agent.encode(nxt), done)
total_reward += float(reward)
total_revenue += float(info.get("economics", {}).get("revenue", 0.0))
steps += 1
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
obs = env.reset()[0] if done else nxt
metrics: dict[str, float | int] = {
"train/reward_mean": total_reward / max(steps, 1),
"train/revenue_mean": total_revenue / max(steps, 1),
"train/epsilon": float(epsilon),
"train/global_step": int(cfg["total_timesteps"]),
}
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
env.close()
eval_env.close()
return agent, metrics

228
engine/backends/sb3.py Normal file
View File

@@ -0,0 +1,228 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Mapping
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback
from ..telemetry.wandb import get_wandb_module
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
from .common import evaluate, make_env
def _net_arch(name: Any) -> list[int]:
presets = {
"tiny": [32, 32],
"small": [64, 64],
"medium": [128, 128],
"large": [256, 256],
}
if isinstance(name, (list, tuple)):
return [int(v) for v in name]
raw = str(name).lower().strip()
if raw in presets:
return presets[raw]
if "x" in raw:
try:
parsed = [int(v) for v in raw.split("x") if v]
return parsed if parsed else presets["small"]
except ValueError:
return presets["small"]
return presets["small"]
def _activation(name: Any):
try:
import torch.nn as nn
except ImportError:
return None
return {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"elu": nn.ELU,
"leaky_relu": nn.LeakyReLU,
}.get(str(name).lower().strip(), nn.ReLU)
def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
kwargs: dict[str, Any] = {"net_arch": _net_arch(cfg.get("arch", "small"))}
activation = _activation(cfg.get("activation", "relu"))
if activation is not None:
kwargs["activation_fn"] = activation
return kwargs
def _sb3_model_cls(algo: str):
try:
from stable_baselines3 import A2C, DQN, PPO
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def build_model(cfg: Mapping[str, Any], env: Any):
try:
from stable_baselines3 import A2C, DQN, PPO
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
algo = str(cfg["algo"])
policy_kwargs = _policy_kwargs(cfg)
device = str(cfg.get("device", "auto"))
seed = int(cfg["seed"])
if algo == "sac":
raise ValueError("sac is not supported with the discrete core env")
if algo == "ppo":
return PPO(
"MlpPolicy",
env,
verbose=1,
device=device,
policy_kwargs=policy_kwargs,
seed=seed,
learning_rate=float(cfg["learning_rate"]),
n_steps=int(cfg["n_steps"]),
batch_size=int(cfg["batch_size"]),
n_epochs=int(cfg["n_epochs"]),
gamma=float(cfg["gamma"]),
gae_lambda=float(cfg["gae_lambda"]),
clip_range=float(cfg["clip_range"]),
ent_coef=float(cfg["ent_coef"]),
)
if algo == "a2c":
return A2C(
"MlpPolicy",
env,
verbose=1,
device=device,
policy_kwargs=policy_kwargs,
seed=seed,
learning_rate=float(cfg["learning_rate"]),
n_steps=max(5, int(cfg["n_steps"]) // 32),
gamma=float(cfg["gamma"]),
gae_lambda=float(cfg["gae_lambda"]),
ent_coef=float(cfg["ent_coef"]),
)
if algo == "dqn":
return DQN(
"MlpPolicy",
env,
verbose=1,
device=device,
policy_kwargs=policy_kwargs,
seed=seed,
learning_rate=float(cfg["learning_rate"]),
buffer_size=int(cfg["buffer_size"]),
batch_size=int(cfg["batch_size"]),
gamma=float(cfg["gamma"]),
train_freq=int(cfg["train_freq"]),
learning_starts=int(cfg["learning_starts"]),
target_update_interval=int(cfg["target_update_interval"]),
exploration_fraction=float(cfg["exploration_fraction"]),
exploration_final_eps=float(cfg["exploration_final_eps"]),
)
raise ValueError(f"unsupported algo '{algo}'")
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any):
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return model
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is None:
return model
checkpoint_path, metadata = restored
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
checkpoint_path.as_posix(),
env=env,
)
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
return resumed
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
try:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 models") from exc
env = Monitor(make_env(cfg))
eval_env = Monitor(make_env(cfg))
model = build_model(cfg, env)
try:
import torch
print(
"PHANTOM_DEVICE: "
+ json.dumps(
{
"requested": str(cfg.get("device", "auto")),
"torch_cuda_available": bool(torch.cuda.is_available()),
"torch_device_count": int(torch.cuda.device_count()),
"sb3_device": str(getattr(model, "device", "unknown")),
}
)
)
except Exception:
pass
model = _maybe_resume_model(cfg, env, model)
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
callbacks.append(
CheckpointArtifactCallback(
dict(cfg),
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
)
callbacks.append(
EvalCallback(
eval_env,
eval_freq=int(cfg["eval_freq"]),
n_eval_episodes=int(cfg["eval_episodes"]),
deterministic=True,
verbose=0,
)
)
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
if remaining_steps > 0:
model.learn(
total_timesteps=remaining_steps,
callback=callbacks,
reset_num_timesteps=False,
)
model_dir = Path(str(cfg["model_dir"]))
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path))
metrics: dict[str, float | int | str] = evaluate(
model,
eval_env,
int(cfg["eval_episodes"]),
)
metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip"))
env.close()
eval_env.close()
return model, metrics

456
engine/benchmark.py Normal file
View File

@@ -0,0 +1,456 @@
from __future__ import annotations
import argparse
import json
import os
from datetime import datetime, UTC
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from .lib.tiers import LinearElasticityPolicy, StaticPolicy, SurgePolicy
from .spec import TrainSpec
from .telemetry.wandb import get_wandb_module
wandb = get_wandb_module()
HAS_WANDB = wandb is not None
def _parse_list(raw: str) -> list[str]:
return [x.strip().lower() for x in str(raw).split(",") if x.strip()]
def _parse_float_list(raw: str) -> list[float]:
return [float(x.strip()) for x in str(raw).split(",") if x.strip()]
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _action(policy, obs: np.ndarray):
out = policy.predict(obs, deterministic=True)
action = out[0] if isinstance(out, tuple) else out
if isinstance(action, np.ndarray) and action.size == 1:
return int(action.reshape(-1)[0])
return int(action)
def _run_eval_episode(env, policy) -> dict:
obs, _ = env.reset()
done = False
total_reward = 0.0
total_revenue = 0.0
total_margin = 0.0
total_coi = 0.0
price_trace: list[float] = []
step_count = 0
while not done:
action = _action(policy, obs)
obs, reward, term, trunc, info = env.step(action)
done = bool(term or trunc)
econ = info.get("economics", {})
total_reward += float(reward)
total_revenue += float(econ.get("revenue", 0.0))
total_margin += float(econ.get("margin", 0.0))
total_coi += float(econ.get("coi_level", 0.0))
prices = np.asarray(info.get("prices", []), dtype=np.float32)
if prices.size > 0:
price_trace.append(float(np.mean(prices)))
step_count += 1
denom = max(step_count, 1)
return {
"reward": total_reward,
"revenue": total_revenue,
"mean_margin": total_margin / denom,
"mean_coi": total_coi / denom,
"price_trace": price_trace,
}
def _build_tier(name: str, cfg: dict, alpha: float):
from .backends.common import make_env
from .backends.qtable import train_qtable
from .backends.sb3 import train_sb3
tier = name.lower().strip()
run_cfg = dict(cfg)
run_cfg["alpha"] = float(alpha)
if tier == "static":
return StaticPolicy(int(run_cfg["action_levels"]))
if tier == "surge":
return SurgePolicy(
n_actions=int(run_cfg["action_levels"]),
n_products=int(run_cfg["n_products"]),
)
if tier == "linear":
warmup_env = make_env(run_cfg)
policy = LinearElasticityPolicy(
n_actions=int(run_cfg["action_levels"]),
n_products=int(run_cfg["n_products"]),
price_low=float(run_cfg["price_low"]),
price_high=float(run_cfg["price_high"]),
)
policy.fit(
warmup_env,
warmup_steps=int(run_cfg.get("linear_warmup_steps", 800)),
seed=int(run_cfg["seed"]),
)
warmup_env.close()
return policy
if tier == "qtable":
agent, _ = train_qtable(run_cfg)
return agent
if tier in {"ppo", "a2c", "dqn"}:
run_cfg["algo"] = tier
agent, _ = train_sb3(run_cfg)
return agent
raise ValueError(f"unsupported tier '{name}'")
def run_benchmark(
cfg: dict, tiers: list[str], alpha_values: list[float], n_episodes: int
):
from .backends.common import make_env
rows: list[dict] = []
traces: list[dict] = []
for alpha in alpha_values:
for tier_name in tiers:
policy = _build_tier(tier_name, cfg, alpha)
env = make_env({**cfg, "alpha": float(alpha)})
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))]
env.close()
row = {
"tier": tier_name,
"alpha": float(alpha),
"episodes": int(n_episodes),
"mean_reward": float(np.mean([e["reward"] for e in eps])),
"mean_revenue": float(np.mean([e["revenue"] for e in eps])),
"mean_margin": float(np.mean([e["mean_margin"] for e in eps])),
"mean_coi": float(np.mean([e["mean_coi"] for e in eps])),
"std_revenue": float(np.std([e["revenue"] for e in eps])),
}
row["objective_score"] = (
row["mean_reward"]
+ float(cfg.get("revenue_weight", 0.01)) * row["mean_revenue"]
)
rows.append(row)
max_len = max((len(e["price_trace"]) for e in eps), default=0)
step_means = []
for step in range(max_len):
vals = [
e["price_trace"][step] for e in eps if step < len(e["price_trace"])
]
step_means.append(float(np.mean(vals)) if vals else np.nan)
traces.append(
{
"tier": tier_name,
"alpha": float(alpha),
"mean_price_trace": step_means,
}
)
if HAS_WANDB and wandb.run is not None:
wandb.log(
{
"study/alpha": float(alpha),
"eval/reward_mean": row["mean_reward"],
"eval/revenue_mean": row["mean_revenue"],
"eval/margin_mean": row["mean_margin"],
"objective/score": row["objective_score"],
"objective/coi_preserved": row["mean_coi"],
}
)
return pd.DataFrame(rows), traces
def _plot_outputs(df: pd.DataFrame, traces: list[dict], out_dir: Path, stamp: str):
fig1 = plt.figure(figsize=(11, 4.5))
if "mode" in df.columns:
groups = sorted(df[["tier", "mode"]].drop_duplicates().values.tolist())
for tier, mode in groups:
sub = df[(df["tier"] == tier) & (df["mode"] == mode)].sort_values("alpha")
plt.plot(
sub["alpha"],
sub["mean_revenue"],
marker="o",
label=f"{tier}:{mode}",
)
else:
for tier in sorted(df["tier"].unique()):
sub = df[df["tier"] == tier].sort_values("alpha")
plt.plot(sub["alpha"], sub["mean_revenue"], marker="o", label=tier)
plt.xlabel("contamination alpha")
plt.ylabel("mean episode revenue")
plt.title("Revenue under contamination")
plt.grid(alpha=0.3)
plt.legend()
fig1.tight_layout()
rev_path = out_dir / f"benchmark_revenue_{stamp}.png"
fig1.savefig(rev_path, dpi=220)
plt.close(fig1)
fig2 = plt.figure(figsize=(11, 4.5))
if "mode" in df.columns:
groups = sorted(df[["tier", "mode"]].drop_duplicates().values.tolist())
for tier, mode in groups:
sub = df[(df["tier"] == tier) & (df["mode"] == mode)].sort_values("alpha")
plt.plot(
sub["alpha"],
sub["mean_coi"],
marker="s",
label=f"{tier}:{mode}",
)
else:
for tier in sorted(df["tier"].unique()):
sub = df[df["tier"] == tier].sort_values("alpha")
plt.plot(sub["alpha"], sub["mean_coi"], marker="s", label=tier)
plt.xlabel("contamination alpha")
plt.ylabel("mean COI level")
plt.title("COI preservation")
plt.grid(alpha=0.3)
plt.legend()
fig2.tight_layout()
coi_path = out_dir / f"benchmark_coi_{stamp}.png"
fig2.savefig(coi_path, dpi=220)
plt.close(fig2)
focus_alpha = float(df["alpha"].min()) if not df.empty else 0.0
alpha_traces = [t for t in traces if abs(float(t["alpha"]) - focus_alpha) < 1e-9]
fig3 = plt.figure(figsize=(11, 4.5))
for item in alpha_traces:
xs = np.arange(len(item["mean_price_trace"]))
ys = np.asarray(item["mean_price_trace"], dtype=np.float32)
mode = item.get("mode")
label = f"{item['tier']}:{mode}" if mode is not None else str(item["tier"])
plt.plot(xs, ys, label=label)
plt.xlabel("step")
plt.ylabel("mean price")
plt.title(f"Price evolution (alpha={focus_alpha:.2f})")
plt.grid(alpha=0.3)
plt.legend()
fig3.tight_layout()
price_path = out_dir / f"benchmark_price_trace_{stamp}.png"
fig3.savefig(price_path, dpi=220)
plt.close(fig3)
return rev_path, coi_path, price_path
def _run_with_args(args):
compare_robust = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
robust_modes = [False, True] if compare_robust else [bool(args.no_robust)]
base_overrides = {
"seed": args.seed,
"total_timesteps": args.total_timesteps,
"n_products": args.n_products,
"N": args.N,
"lambda_coi": args.lambda_coi,
"robust_radius": args.robust_radius,
"robust_points": args.robust_points,
"price_low": args.price_low,
"price_high": args.price_high,
"action_levels": args.action_levels,
"action_scale_low": args.action_scale_low,
"action_scale_high": args.action_scale_high,
"max_steps": args.max_steps,
"learning_rate": args.learning_rate,
"batch_size": args.batch_size,
"n_steps": args.n_steps,
"linear_warmup_steps": args.linear_warmup_steps,
"device": args.device,
}
tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values)
all_frames: list[pd.DataFrame] = []
all_traces: list[dict] = []
for no_robust in robust_modes:
overrides = dict(base_overrides)
overrides["no_robust"] = bool(no_robust)
cfg = TrainSpec.from_flat(
{k: v for k, v in overrides.items() if v is not None}
).to_flat_dict()
cfg["linear_warmup_steps"] = int(args.linear_warmup_steps)
df_mode, traces_mode = run_benchmark(cfg, tiers, alpha_values, args.episodes)
mode_label = "no_robust" if no_robust else "robust"
df_mode["mode"] = mode_label
for trace in traces_mode:
trace["mode"] = mode_label
all_frames.append(df_mode)
all_traces.extend(traces_mode)
df = pd.concat(all_frames, ignore_index=True) if all_frames else pd.DataFrame()
traces = all_traces
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
stamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
csv_path = out_dir / f"benchmark_{stamp}.csv"
trace_path = out_dir / f"benchmark_traces_{stamp}.json"
df.to_csv(csv_path, index=False)
trace_path.write_text(json.dumps(traces, indent=2))
rev_path, coi_path, price_path = _plot_outputs(df, traces, out_dir, stamp)
if not df.empty:
best_idx = int(df["mean_revenue"].idxmax())
best = df.iloc[best_idx]
print(
"BEST_TIER="
+ json.dumps(
{
"tier": best["tier"],
"mode": best.get("mode", "robust"),
"alpha": float(best["alpha"]),
"mean_revenue": float(best["mean_revenue"]),
"mean_coi": float(best["mean_coi"]),
}
)
)
print(f"BENCHMARK_CSV={csv_path}")
print(f"BENCHMARK_TRACES={trace_path}")
print(f"BENCHMARK_PLOT_REVENUE={rev_path}")
print(f"BENCHMARK_PLOT_COI={coi_path}")
print(f"BENCHMARK_PLOT_PRICE={price_path}")
def run_cli(raw_args: list[str] | None = None):
parser = argparse.ArgumentParser(description="PHANTOM benchmark orchestrator")
parser.add_argument("--project", default="capstone")
parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo")
parser.add_argument("--alpha-values", default="0.0,0.3,0.6")
parser.add_argument("--episodes", type=int, default=10)
parser.add_argument("--output-dir", default="engine/studies/results")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--total-timesteps", type=int, default=25_000)
parser.add_argument("--n-products", type=int, default=10)
parser.add_argument("--N", type=int, default=100)
parser.add_argument("--lambda-coi", type=float, default=0.2)
parser.add_argument("--robust-radius", type=float, default=0.15)
parser.add_argument("--robust-points", type=int, default=5)
parser.add_argument("--price-low", type=float, default=10.0)
parser.add_argument("--price-high", type=float, default=150.0)
parser.add_argument("--action-levels", type=int, default=9)
parser.add_argument("--action-scale-low", type=float, default=0.8)
parser.add_argument("--action-scale-high", type=float, default=1.2)
parser.add_argument("--max-steps", type=int, default=100)
parser.add_argument("--learning-rate", type=float, default=3e-4)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--n-steps", type=int, default=2048)
parser.add_argument("--linear-warmup-steps", type=int, default=800)
parser.add_argument("--device", type=str, default="auto")
parser.add_argument("--no-robust", action="store_true")
parser.add_argument("--no-wandb", action="store_true")
parser.add_argument("--offline", action="store_true")
parser.add_argument("--sweep-agent", action="store_true")
parser.add_argument("--sweep-id", type=str)
parser.add_argument("--count", type=int, default=0)
args = parser.parse_args(raw_args)
if args.sweep_agent:
if args.no_wandb or not HAS_WANDB:
raise ValueError("sweep agent requires wandb")
if not args.sweep_id:
raise ValueError("--sweep-id is required with --sweep-agent")
def _sweep_run():
run = wandb.init(mode="offline" if args.offline else "online")
try:
key_to_attr = {
"tiers": "tiers",
"alpha_values": "alpha_values",
"episodes": "episodes",
"total_timesteps": "total_timesteps",
"lambda_coi": "lambda_coi",
"robust_radius": "robust_radius",
"robust_points": "robust_points",
"learning_rate": "learning_rate",
"batch_size": "batch_size",
"n_steps": "n_steps",
"no_robust": "no_robust",
"device": "device",
}
for key in (
"tiers",
"alpha_values",
"episodes",
"total_timesteps",
"lambda_coi",
"robust_radius",
"robust_points",
"learning_rate",
"batch_size",
"n_steps",
"no_robust",
"device",
):
if key in wandb.config:
setattr(args, key_to_attr[key], wandb.config[key])
_run_with_args(args)
finally:
if run is not None:
wandb.finish()
wandb.agent(
args.sweep_id,
function=_sweep_run,
count=args.count if args.count > 0 else None,
)
return
if args.no_wandb or not HAS_WANDB:
_run_with_args(args)
return
run = wandb.init(
project=args.project,
name=f"benchmark-{datetime.now(UTC).strftime('%m%d-%H%M%S')}",
tags=[
"benchmark",
"robust-compare"
if _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
else "single-mode",
],
config={
"run.kind": "benchmark",
"tiers": args.tiers,
"alpha_values": args.alpha_values,
"episodes": args.episodes,
"total_timesteps": args.total_timesteps,
"lambda_coi": args.lambda_coi,
"robust_radius": args.robust_radius,
"robust_points": args.robust_points,
"learning_rate": args.learning_rate,
"device": args.device,
},
mode="offline" if args.offline else "online",
)
try:
_run_with_args(args)
finally:
if run is not None:
wandb.finish()
if __name__ == "__main__":
run_cli()

View File

@@ -624,8 +624,8 @@ def evaluate_policy(
revenues.append(ep_revenue) revenues.append(ep_revenue)
return { return {
"eval/reward": float(np.mean(rewards)), "eval/reward_mean": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)), "eval/revenue_mean": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)), "eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)), "eval/revenue_std": float(np.std(revenues)),
} }
@@ -665,8 +665,8 @@ def _evaluate_q_network(
revenues.append(ep_revenue) revenues.append(ep_revenue)
return { return {
"eval/reward": float(np.mean(rewards)), "eval/reward_mean": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)), "eval/revenue_mean": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)), "eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)), "eval/revenue_std": float(np.std(revenues)),
} }
@@ -713,8 +713,8 @@ def _evaluate_q_table(
revenues.append(ep_revenue) revenues.append(ep_revenue)
return { return {
"eval/reward": float(np.mean(rewards)), "eval/reward_mean": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)), "eval/revenue_mean": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)), "eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)), "eval/revenue_std": float(np.std(revenues)),
} }
@@ -831,8 +831,8 @@ def _train_actor_critic(
if is_primary and HAS_WANDB and wandb.run is not None: if is_primary and HAS_WANDB and wandb.run is not None:
wandb.log( wandb.log(
{ {
"train/reward": float(segment_values["reward"].mean()), "train/reward_mean": float(segment_values["reward"].mean()),
"train/revenue": float(segment_values["revenue"].mean()), "train/revenue_mean": float(segment_values["revenue"].mean()),
"train/agent_prob": float(segment_values["agent_prob"].mean()), "train/agent_prob": float(segment_values["agent_prob"].mean()),
"train/alpha_adv": float(segment_values["alpha_adv"].mean()), "train/alpha_adv": float(segment_values["alpha_adv"].mean()),
"train/coi_leakage": float(segment_values["coi_leakage"].mean()), "train/coi_leakage": float(segment_values["coi_leakage"].mean()),
@@ -873,8 +873,8 @@ def _train_actor_critic(
train_state = final_runner[0] train_state = final_runner[0]
denom = float(metric_count) if metric_count > 0 else 1.0 denom = float(metric_count) if metric_count > 0 else 1.0
metrics = { metrics = {
"train/reward": float(metric_sums["reward"] / denom), "train/reward_mean": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom), "train/revenue_mean": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
@@ -1052,14 +1052,14 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
): ):
wandb.log( wandb.log(
{ {
"train/reward": metric_sums["reward"] / max(metric_count, 1), "train/reward_mean": metric_sums["reward"] / max(metric_count, 1),
"train/revenue": metric_sums["revenue"] / max(metric_count, 1), "train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1),
"train/agent_prob": metric_sums["agent_prob"] "train/agent_prob": metric_sums["agent_prob"]
/ max(metric_count, 1), / max(metric_count, 1),
"train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
"train/coi_leakage": metric_sums["coi_leakage"] "train/coi_leakage": metric_sums["coi_leakage"]
/ max(metric_count, 1), / max(metric_count, 1),
"train/dqn_loss": metric_sums["loss"] / max(loss_count, 1), "train/loss": metric_sums["loss"] / max(loss_count, 1),
"train/epsilon": epsilon_value, "train/epsilon": epsilon_value,
"train/global_step": global_step, "train/global_step": global_step,
}, },
@@ -1090,12 +1090,12 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
denom = float(metric_count) if metric_count > 0 else 1.0 denom = float(metric_count) if metric_count > 0 else 1.0
metrics = { metrics = {
"train/reward": float(metric_sums["reward"] / denom), "train/reward_mean": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom), "train/revenue_mean": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
"train/dqn_loss": float(metric_sums["loss"] / max(loss_count, 1)), "train/loss": float(metric_sums["loss"] / max(loss_count, 1)),
"train/global_step": total_steps, "train/global_step": total_steps,
} }
@@ -1236,8 +1236,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]
): ):
wandb.log( wandb.log(
{ {
"train/reward": metric_sums["reward"] / max(metric_count, 1), "train/reward_mean": metric_sums["reward"] / max(metric_count, 1),
"train/revenue": metric_sums["revenue"] / max(metric_count, 1), "train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1),
"train/agent_prob": metric_sums["agent_prob"] "train/agent_prob": metric_sums["agent_prob"]
/ max(metric_count, 1), / max(metric_count, 1),
"train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
@@ -1269,8 +1269,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]
denom = float(metric_count) if metric_count > 0 else 1.0 denom = float(metric_count) if metric_count > 0 else 1.0
metrics = { metrics = {
"train/reward": float(metric_sums["reward"] / denom), "train/reward_mean": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom), "train/revenue_mean": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom),

View File

@@ -1,38 +1,39 @@
from .demand import estimate_demand, estimate_weighted_demand, generate_demand_for_actor from __future__ import annotations
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
from .render import DashboardRenderer, style_axis
from .wrappers import EconomicMetricsWrapper
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
from .providers import (
ProviderBenchmark,
ProviderResult,
BenchmarkConfig,
RandomBaseline,
SurgeBaseline,
)
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
from .discrete import EventQTable
__all__ = [ from importlib import import_module
"estimate_demand",
"estimate_weighted_demand", _EXPORTS: dict[str, tuple[str, str]] = {
"generate_demand_for_actor", "estimate_demand": (".demand", "estimate_demand"),
"sample_behavior", "estimate_weighted_demand": (".demand", "estimate_weighted_demand"),
"get_transition_models", "generate_demand_for_actor": (".demand", "generate_demand_for_actor"),
"trajectory_to_events", "sample_behavior": (".behavior", "sample_behavior"),
"DashboardRenderer", "get_transition_models": (".behavior", "get_transition_models"),
"style_axis", "trajectory_to_events": (".behavior", "trajectory_to_events"),
"EconomicMetricsWrapper", "DashboardRenderer": (".render", "DashboardRenderer"),
"MetricsCallback", "style_axis": (".render", "style_axis"),
"EvalMetricsCallback", "EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
"CheckpointArtifactCallback", "MetricsCallback": (".callbacks", "MetricsCallback"),
"ProviderBenchmark", "EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
"ProviderResult", "CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"),
"BenchmarkConfig", "ProviderBenchmark": (".providers", "ProviderBenchmark"),
"RandomBaseline", "ProviderResult": (".providers", "ProviderResult"),
"SurgeBaseline", "BenchmarkConfig": (".providers", "BenchmarkConfig"),
"compute_uplift_coi", "RandomBaseline": (".providers", "RandomBaseline"),
"extract_purchases", "SurgeBaseline": (".providers", "SurgeBaseline"),
"compute_agent_probability", "compute_uplift_coi": (".coi", "compute_uplift_coi"),
"EventQTable", "extract_purchases": (".coi", "extract_purchases"),
] "compute_agent_probability": (".coi", "compute_agent_probability"),
"EventQTable": (".discrete", "EventQTable"),
}
__all__ = sorted(_EXPORTS)
def __getattr__(name: str):
if name not in _EXPORTS:
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
module_name, attr_name = _EXPORTS[name]
module = import_module(module_name, package=__name__)
value = getattr(module, attr_name)
globals()[name] = value
return value

View File

@@ -38,19 +38,19 @@ class MetricsCallback(BaseCallback):
t = self.num_timesteps t = self.num_timesteps
payload = { payload = {
"economics/revenue": econ["revenue"], "train/revenue_step": econ["revenue"],
"economics/margin": econ["margin"], "train/margin_step": econ["margin"],
"coi/level": econ["coi_level"], "train/coi_level": econ["coi_level"],
"economics/regret": econ["regret"], "train/regret_step": econ["regret"],
} }
if "coi_mix" in econ: if "coi_mix" in econ:
payload["coi/mix"] = econ["coi_mix"] payload["train/coi_mix"] = econ["coi_mix"]
if "coi_base" in econ: if "coi_base" in econ:
payload["coi/base"] = econ["coi_base"] payload["train/coi_base"] = econ["coi_base"]
if "coi_leakage" in econ: if "coi_leakage" in econ:
payload["coi/leakage"] = econ["coi_leakage"] payload["train/coi_leakage"] = econ["coi_leakage"]
if "coi_penalty" in econ: if "coi_penalty" in econ:
payload["coi/penalty"] = econ["coi_penalty"] payload["train/coi_penalty"] = econ["coi_penalty"]
wandb.log(payload, step=t) wandb.log(payload, step=t)
self._episode_revenues.append(econ["revenue"]) self._episode_revenues.append(econ["revenue"])
@@ -76,8 +76,8 @@ class MetricsCallback(BaseCallback):
return return
wandb.log( wandb.log(
{ {
"episode/mean_revenue": np.mean(self._episode_revenues), "train/revenue_rollout_mean": np.mean(self._episode_revenues),
"episode/total_revenue": np.sum(self._episode_revenues), "train/revenue_rollout_total": np.sum(self._episode_revenues),
}, },
step=self.num_timesteps, step=self.num_timesteps,
) )
@@ -164,8 +164,8 @@ class EvalMetricsCallback(EvalCallback):
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
wandb.log( wandb.log(
{ {
"eval/mean_reward": self.last_mean_reward, "eval/reward_mean": self.last_mean_reward,
"eval/mean_revenue": np.mean(self._eval_revenues) "eval/revenue_mean": np.mean(self._eval_revenues)
if self._eval_revenues if self._eval_revenues
else 0, else 0,
}, },

101
engine/lib/tiers.py Normal file
View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
import numpy as np
class PolicyLike(Protocol):
def predict(self, obs: np.ndarray, deterministic: bool = True): ...
class StaticPolicy:
def __init__(self, n_actions: int):
self._action = int(max(0, n_actions // 2))
def predict(self, obs: np.ndarray, deterministic: bool = True):
return self._action, None
class SurgePolicy:
def __init__(
self,
n_actions: int,
n_products: int,
high_threshold: float = 60.0,
low_threshold: float = 30.0,
):
self.n_actions = int(n_actions)
self.n_products = int(n_products)
self.mid = self.n_actions // 2
self.high_t = float(high_threshold)
self.low_t = float(low_threshold)
def predict(self, obs: np.ndarray, deterministic: bool = True):
obs_arr = np.asarray(obs, dtype=np.float32)
demand = obs_arr[: self.n_products]
demand_mean = float(np.mean(demand)) if demand.size > 0 else 0.0
if demand_mean >= self.high_t:
return min(self.mid + 2, self.n_actions - 1), None
if demand_mean <= self.low_t:
return max(self.mid - 2, 0), None
return self.mid, None
@dataclass
class LinearElasticityPolicy:
n_actions: int
n_products: int
price_low: float
price_high: float
def __post_init__(self):
self.n_actions = int(self.n_actions)
self.n_products = int(self.n_products)
self.price_low = float(self.price_low)
self.price_high = float(self.price_high)
self._target_price = 0.5 * (self.price_low + self.price_high)
self._action_scales = np.linspace(0.8, 1.2, self.n_actions)
def fit(self, env, warmup_steps: int = 800, seed: int = 42):
rng = np.random.default_rng(int(seed))
obs, _ = env.reset(seed=int(seed))
prices: list[float] = []
demands: list[float] = []
for _ in range(int(max(10, warmup_steps))):
action = int(rng.integers(0, self.n_actions))
obs, _, term, trunc, info = env.step(action)
done = bool(term or trunc)
p = np.asarray(info.get("prices", []), dtype=np.float32)
d = np.asarray(info.get("demand", []), dtype=np.float32)
if p.size > 0 and d.size > 0:
prices.append(float(np.mean(p)))
demands.append(float(np.mean(d)))
if done:
obs, _ = env.reset()
if len(prices) < 8:
self._target_price = 0.5 * (self.price_low + self.price_high)
return self
slope, intercept = np.polyfit(np.asarray(prices), np.asarray(demands), 1)
if slope < -1e-6:
p_star = -intercept / (2.0 * slope)
self._target_price = float(np.clip(p_star, self.price_low, self.price_high))
else:
self._target_price = 0.5 * (self.price_low + self.price_high)
return self
def predict(self, obs: np.ndarray, deterministic: bool = True):
obs_arr = np.asarray(obs, dtype=np.float32)
cur_prices = obs_arr[self.n_products : 2 * self.n_products]
cur_mean = (
float(np.mean(cur_prices)) if cur_prices.size > 0 else self._target_price
)
scale = self._target_price / max(cur_mean, 1e-6)
action = int(np.argmin(np.abs(self._action_scales - scale)))
return int(np.clip(action, 0, self.n_actions - 1)), None

View File

@@ -0,0 +1,5 @@
from .benchmark import run_benchmark_cli
from .sweep_agent import run_sweep_agent
from .train import run_train_once
__all__ = ["run_benchmark_cli", "run_sweep_agent", "run_train_once"]

View File

@@ -0,0 +1,7 @@
from __future__ import annotations
def run_benchmark_cli(raw_args: list[str] | None = None) -> None:
from ..benchmark import run_cli
run_cli(raw_args)

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import Any, Mapping, Sequence
from ..spec import TrainSpec, run_name
from ..telemetry.wandb import (
current_config,
finish_run,
get_wandb_module,
init_run,
run_agent,
)
from .train import run_with_active_sweep_run
def run_sweep_agent(
*,
project: str,
sweep_id: str,
count: int,
offline: bool,
no_wandb: bool,
base_overrides: Mapping[str, Any],
kind: str,
scenario: str,
group: str | None,
extra_tags: Sequence[str],
) -> None:
if no_wandb:
raise ValueError("sweep agent requires wandb")
if not sweep_id:
raise ValueError("--sweep-id is required with --sweep-agent")
if get_wandb_module() is None:
raise ImportError("wandb is required for sweep runs")
mode = "offline" if offline else "online"
def _sweep_trial() -> None:
run = init_run(mode=mode, project=project, group=group, sweep_mode=True)
try:
merged = dict(base_overrides)
merged.update(current_config())
spec = TrainSpec.from_flat(merged)
if run is not None:
run.name = run_name(spec, kind=kind, scenario=scenario)
run_with_active_sweep_run(
spec,
kind=kind,
scenario=scenario,
group=group,
extra_tags=extra_tags,
)
finally:
finish_run()
run_agent(
sweep_id,
_sweep_trial,
count=count if count > 0 else None,
)

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
import json
from typing import Any, Sequence
from ..spec import TrainSpec, run_metadata, run_name
from ..telemetry.wandb import (
finish_run,
get_wandb_module,
init_run,
log_metrics,
update_run_config,
update_summary,
)
from ..train_core import run_train
def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list[str]:
tags = [
kind,
spec.algorithm.name,
spec.runtime.backend,
"vanilla" if spec.study.no_robust else "robust",
]
tags.extend([tag for tag in extra_tags if tag])
return tags
def _print_local_metrics(metrics: dict[str, Any]) -> None:
print(json.dumps(metrics, indent=2))
print("PHANTOM_METRICS:" + json.dumps(metrics))
def _should_print_local(spec: TrainSpec) -> bool:
if not spec.runtime.use_jax:
return True
try:
import jax
return int(jax.process_index()) == 0
except Exception:
return True
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool:
if not spec.runtime.use_jax:
return False
try:
import jax
return int(jax.process_count()) > 1 and int(jax.process_index()) != 0
except Exception:
return False
def run_train_once(
spec: TrainSpec,
*,
project: str,
offline: bool,
no_wandb: bool,
kind: str,
scenario: str,
group: str | None,
extra_tags: Sequence[str],
) -> dict[str, Any]:
wandb = get_wandb_module()
if no_wandb or wandb is None or _is_non_primary_jax_worker(spec):
result = run_train(spec)
if _should_print_local(spec):
_print_local_metrics(result.metrics)
return result.metrics
mode = "offline" if offline else "online"
tags = _tags_for_run(spec, kind, extra_tags)
metadata = run_metadata(
spec,
kind=kind,
scenario=scenario,
group=group,
tags=tags,
)
config = spec.to_flat_dict()
config.update(metadata)
name = run_name(spec, kind=kind, scenario=scenario)
init_run(
mode=mode,
project=project,
config=config,
name=name,
tags=tags,
group=group,
sweep_mode=False,
)
try:
result = run_train(spec)
metrics = result.metrics
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
log_metrics(metrics, step=step)
update_summary(metrics)
return metrics
finally:
finish_run()
def run_with_active_sweep_run(
spec: TrainSpec,
*,
kind: str,
scenario: str,
group: str | None,
extra_tags: Sequence[str],
) -> dict[str, Any]:
tags = _tags_for_run(spec, kind, extra_tags)
metadata = run_metadata(
spec,
kind=kind,
scenario=scenario,
group=group,
tags=tags,
)
update_run_config({**spec.to_flat_dict(), **metadata})
result = run_train(spec)
metrics = result.metrics
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
log_metrics(metrics, step=step)
update_summary(metrics)
return metrics

View File

@@ -31,6 +31,26 @@
"cwd": "." "cwd": "."
} }
}, },
"benchmark": {
"executor": "nx:run-commands",
"dependsOn": [
"install"
],
"options": {
"command": "bash scripts/nx_research.sh benchmark",
"cwd": "."
}
},
"benchmark-agent": {
"executor": "nx:run-commands",
"dependsOn": [
"install"
],
"options": {
"command": "bash scripts/nx_research.sh benchmark-agent",
"cwd": "."
}
},
"train-agent": { "train-agent": {
"executor": "nx:run-commands", "executor": "nx:run-commands",
"dependsOn": [ "dependsOn": [

340
engine/spec.py Normal file
View File

@@ -0,0 +1,340 @@
from __future__ import annotations
from dataclasses import dataclass, field
import os
from typing import Any, Mapping, Sequence
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
alias_map = {
"algorithm": "algo",
"algorithm.name": "algo",
"env.n_products": "n_products",
"env.action_levels": "action_levels",
"env.action_scale_low": "action_scale_low",
"env.action_scale_high": "action_scale_high",
"env.price_low": "price_low",
"env.price_high": "price_high",
"env.max_steps": "max_steps",
"env.margin_floor": "margin_floor",
"env.margin_floor_patience": "margin_floor_patience",
"env.n_sessions": "N",
"study.alpha": "alpha",
"study.lambda_coi": "lambda_coi",
"study.robust_radius": "robust_radius",
"study.robust_points": "robust_points",
"study.info_value": "info_value",
"study.revenue_weight": "revenue_weight",
"optimizer.learning_rate": "learning_rate",
"optimizer.gamma": "gamma",
"optimizer.batch_size": "batch_size",
"optimizer.n_steps": "n_steps",
"runtime.backend": "backend",
"runtime.device": "device",
"runtime.seed": "seed",
"runtime.total_timesteps": "total_timesteps",
"runtime.checkpoint_interval": "checkpoint_interval",
"eval.eval_freq": "eval_freq",
"eval.eval_episodes": "eval_episodes",
}
normalized: dict[str, Any] = {}
for key, value in raw.items():
canonical = alias_map.get(str(key), str(key))
normalized[canonical] = value
return normalized
@dataclass(frozen=True)
class AlgorithmSpec:
name: str = "ppo"
@dataclass(frozen=True)
class EnvSpec:
n_products: int = 10
n_sessions: int = 100
price_low: float = 10.0
price_high: float = 150.0
action_levels: int = 9
action_scale_low: float = 0.8
action_scale_high: float = 1.2
max_steps: int = 100
margin_floor: float = 0.05
margin_floor_patience: int = 5
@dataclass(frozen=True)
class StudySpec:
alpha: float = 0.3
lambda_coi: float = 0.2
robust_radius: float = 0.15
robust_points: int = 5
info_value: float = 1.0
revenue_weight: float = 0.01
no_robust: bool = False
@dataclass(frozen=True)
class OptimizerSpec:
learning_rate: float = 3e-4
gamma: float = 0.99
buffer_size: int = 50_000
batch_size: int = 256
tau: float = 0.005
train_freq: int = 1
learning_starts: int = 1_000
target_update_interval: int = 1_000
exploration_fraction: float = 0.2
exploration_final_eps: float = 0.05
n_steps: int = 2_048
n_epochs: int = 10
gae_lambda: float = 0.95
clip_range: float = 0.2
ent_coef: float = 0.0
q_lr: float = 0.1
q_bins: int = 6
eps_start: float = 1.0
eps_end: float = 0.05
eps_decay: float = 0.9995
arch: str = "small"
activation: str = "relu"
jax_num_envs: int = 16
jax_num_steps: int = 128
jax_num_minibatches: int = 4
jax_update_epochs: int = 4
jax_anneal_lr: bool = True
vf_coef: float = 0.5
max_grad_norm: float = 0.5
@dataclass(frozen=True)
class RuntimeSpec:
project: str = "capstone"
backend: str = "sb3"
device: str = "auto"
seed: int = 42
total_timesteps: int = 50_000
checkpoint_interval: int = 200_000
model_dir: str = "engine/models"
log_freq: int = 100
use_jax: bool = False
@dataclass(frozen=True)
class EvalSpec:
eval_freq: int = 1_000
eval_episodes: int = 5
robust_eval_enabled: bool = True
@dataclass(frozen=True)
class TrainSpec:
algorithm: AlgorithmSpec = field(default_factory=AlgorithmSpec)
env: EnvSpec = field(default_factory=EnvSpec)
study: StudySpec = field(default_factory=StudySpec)
optimizer: OptimizerSpec = field(default_factory=OptimizerSpec)
runtime: RuntimeSpec = field(default_factory=RuntimeSpec)
eval: EvalSpec = field(default_factory=EvalSpec)
def to_flat_dict(self) -> dict[str, Any]:
return {
"project": self.runtime.project,
"algo": self.algorithm.name,
"seed": self.runtime.seed,
"total_timesteps": self.runtime.total_timesteps,
"eval_episodes": self.eval.eval_episodes,
"eval_freq": self.eval.eval_freq,
"log_freq": self.runtime.log_freq,
"model_dir": self.runtime.model_dir,
"backend": self.runtime.backend,
"device": self.runtime.device,
"use_jax": self.runtime.use_jax,
"checkpoint_interval": self.runtime.checkpoint_interval,
"n_products": self.env.n_products,
"N": self.env.n_sessions,
"price_low": self.env.price_low,
"price_high": self.env.price_high,
"action_levels": self.env.action_levels,
"action_scale_low": self.env.action_scale_low,
"action_scale_high": self.env.action_scale_high,
"max_steps": self.env.max_steps,
"margin_floor": self.env.margin_floor,
"margin_floor_patience": self.env.margin_floor_patience,
"alpha": self.study.alpha,
"lambda_coi": self.study.lambda_coi,
"robust_radius": self.study.robust_radius,
"robust_points": self.study.robust_points,
"info_value": self.study.info_value,
"revenue_weight": self.study.revenue_weight,
"no_robust": self.study.no_robust,
"learning_rate": self.optimizer.learning_rate,
"gamma": self.optimizer.gamma,
"buffer_size": self.optimizer.buffer_size,
"batch_size": self.optimizer.batch_size,
"tau": self.optimizer.tau,
"train_freq": self.optimizer.train_freq,
"learning_starts": self.optimizer.learning_starts,
"target_update_interval": self.optimizer.target_update_interval,
"exploration_fraction": self.optimizer.exploration_fraction,
"exploration_final_eps": self.optimizer.exploration_final_eps,
"n_steps": self.optimizer.n_steps,
"n_epochs": self.optimizer.n_epochs,
"gae_lambda": self.optimizer.gae_lambda,
"clip_range": self.optimizer.clip_range,
"ent_coef": self.optimizer.ent_coef,
"q_lr": self.optimizer.q_lr,
"q_bins": self.optimizer.q_bins,
"eps_start": self.optimizer.eps_start,
"eps_end": self.optimizer.eps_end,
"eps_decay": self.optimizer.eps_decay,
"arch": self.optimizer.arch,
"activation": self.optimizer.activation,
"jax_num_envs": self.optimizer.jax_num_envs,
"jax_num_steps": self.optimizer.jax_num_steps,
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
"jax_update_epochs": self.optimizer.jax_update_epochs,
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
"vf_coef": self.optimizer.vf_coef,
"max_grad_norm": self.optimizer.max_grad_norm,
"robust_eval_enabled": self.eval.robust_eval_enabled,
}
@classmethod
def from_flat(
cls,
raw: Mapping[str, Any] | None = None,
*,
env_vars: Mapping[str, str] | None = None,
) -> "TrainSpec":
base = cls().to_flat_dict()
incoming = _normalize_keys(raw or {})
base.update({k: v for k, v in incoming.items() if v is not None})
runtime_env = os.environ if env_vars is None else env_vars
base["device"] = str(
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
)
requested_jax = _truthy(base.get("use_jax")) or _truthy(
runtime_env.get("PHANTOM_USE_JAX")
)
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
if backend == "auto":
backend = "jax" if requested_jax else "sb3"
if backend == "jax":
requested_jax = True
no_robust = _truthy(base.get("no_robust"))
if no_robust:
base["lambda_coi"] = 0.0
base["robust_radius"] = 0.0
base["robust_points"] = 1
return cls(
algorithm=AlgorithmSpec(name=str(base["algo"]).lower().strip()),
env=EnvSpec(
n_products=int(base["n_products"]),
n_sessions=int(base["N"]),
price_low=float(base["price_low"]),
price_high=float(base["price_high"]),
action_levels=int(base["action_levels"]),
action_scale_low=float(base["action_scale_low"]),
action_scale_high=float(base["action_scale_high"]),
max_steps=int(base["max_steps"]),
margin_floor=float(base["margin_floor"]),
margin_floor_patience=int(base["margin_floor_patience"]),
),
study=StudySpec(
alpha=float(base["alpha"]),
lambda_coi=float(base["lambda_coi"]),
robust_radius=float(base["robust_radius"]),
robust_points=int(base["robust_points"]),
info_value=float(base["info_value"]),
revenue_weight=float(base["revenue_weight"]),
no_robust=no_robust,
),
optimizer=OptimizerSpec(
learning_rate=float(base["learning_rate"]),
gamma=float(base["gamma"]),
buffer_size=int(base["buffer_size"]),
batch_size=int(base["batch_size"]),
tau=float(base["tau"]),
train_freq=int(base["train_freq"]),
learning_starts=int(base["learning_starts"]),
target_update_interval=int(base["target_update_interval"]),
exploration_fraction=float(base["exploration_fraction"]),
exploration_final_eps=float(base["exploration_final_eps"]),
n_steps=int(base["n_steps"]),
n_epochs=int(base["n_epochs"]),
gae_lambda=float(base["gae_lambda"]),
clip_range=float(base["clip_range"]),
ent_coef=float(base["ent_coef"]),
q_lr=float(base["q_lr"]),
q_bins=int(base["q_bins"]),
eps_start=float(base["eps_start"]),
eps_end=float(base["eps_end"]),
eps_decay=float(base["eps_decay"]),
arch=str(base["arch"]),
activation=str(base["activation"]),
jax_num_envs=int(base["jax_num_envs"]),
jax_num_steps=int(base["jax_num_steps"]),
jax_num_minibatches=int(base["jax_num_minibatches"]),
jax_update_epochs=int(base["jax_update_epochs"]),
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
vf_coef=float(base["vf_coef"]),
max_grad_norm=float(base["max_grad_norm"]),
),
runtime=RuntimeSpec(
project=str(base["project"]),
backend=backend,
device=str(base["device"]),
seed=int(base["seed"]),
total_timesteps=int(base["total_timesteps"]),
checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]),
use_jax=requested_jax,
),
eval=EvalSpec(
eval_freq=int(base["eval_freq"]),
eval_episodes=int(base["eval_episodes"]),
robust_eval_enabled=_truthy(base.get("robust_eval_enabled", True)),
),
)
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
return (
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
)
def run_metadata(
spec: TrainSpec,
*,
kind: str,
scenario: str,
group: str | None = None,
tags: Sequence[str] = (),
) -> dict[str, Any]:
metadata: dict[str, Any] = {
"run.kind": str(kind),
"run.algo": spec.algorithm.name,
"run.backend": spec.runtime.backend,
"run.device": spec.runtime.device,
"run.scenario": str(scenario),
"run.seed": spec.runtime.seed,
"run.tags": list(tags),
}
if group:
metadata["run.group"] = group
return metadata

View File

@@ -0,0 +1,136 @@
import sys
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from gymnasium.wrappers import FlattenObservation
from stable_baselines3 import PPO
# Add parent directory to path to allow importing engine
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from engine.wrapper import PHANTOM
from engine.lib.wrappers import EconomicMetricsWrapper
from engine.lib.providers import (
ProviderBenchmark,
BenchmarkConfig,
RandomBaseline,
SurgeBaseline,
)
def env_factory(alpha: float):
"""Creates a wrapped PHANTOM environment for testing at a specific alpha level."""
# Action levels=9 matches the trained PPO model
# n_products=8 matches the pretrained model's expectation of Box(16,)
env = PHANTOM(
n_products=8,
alpha=alpha,
N=100,
action_levels=9,
action_scale_low=0.8,
action_scale_high=1.2,
max_steps=20, # Short episodes so simulation goes fast
robust_points=1, # disable expensive adversarial lookaheads
render_mode=None,
)
env = EconomicMetricsWrapper(env)
return FlattenObservation(env)
def main():
print("Loading pre-trained Robust RL model...")
model_path = Path(__file__).parent.parent / "models" / "phantom_ppo.zip"
if not model_path.exists():
print(f"Error: Model not found at {model_path}")
print("Please ensure you have a trained model before running this script.")
return
rl_model = PPO.load(model_path)
# The action space is Discrete(9). Index 4 is the middle (1.0 scale).
n_actions = 9
mid_action = n_actions // 2
providers = {
"Static (Base)": lambda obs: mid_action,
"Random": RandomBaseline(n_actions),
"Heuristic Surge": SurgeBaseline(
n_actions, high_threshold=60.0, low_threshold=30.0
),
"Robust RL (PPO)": lambda obs: rl_model.predict(obs, deterministic=True)[0],
}
config = BenchmarkConfig(
n_episodes=10, # Lower episodes to run faster
alpha_range=[0.0, 0.5, 1.0], # Fewer alpha levels
baseline_name="Static (Base)",
)
print(f"\nStarting benchmark across alpha levels: {config.alpha_range}")
print(
f"Testing {len(providers)} strategies for {config.n_episodes} episodes each...\n"
)
benchmark = ProviderBenchmark(env_factory, providers, config)
results = benchmark.run()
# 1. Print tabular results
df = benchmark.to_dataframe()
summary = benchmark.summary_table()
print("\n--- Benchmark Summary Table ---")
print(summary)
# 2. Save results to CSV for thesis inclusion
out_dir = Path(__file__).parent / "results"
out_dir.mkdir(exist_ok=True)
csv_path = out_dir / "provider_comparison.csv"
df.to_csv(csv_path, index=False)
print(f"\nSaved raw results to {csv_path}")
# 3. Plot the degradation of COI / Revenue as alpha increases
plt.figure(figsize=(12, 5))
# Plot 1: Revenue vs Alpha
plt.subplot(1, 2, 1)
for name in providers.keys():
provider_data = df[df["name"] == name]
plt.plot(
provider_data["alpha"],
provider_data["mean_revenue"],
marker="o",
label=name,
linewidth=2,
)
plt.title("Revenue under Agent Contamination")
plt.xlabel("Contamination Level (α)")
plt.ylabel("Mean Episode Revenue ($)")
plt.grid(True, linestyle="--", alpha=0.7)
plt.legend()
# Plot 2: COI Preservation vs Alpha
plt.subplot(1, 2, 2)
for name in providers.keys():
provider_data = df[df["name"] == name]
plt.plot(
provider_data["alpha"],
provider_data["coi_preserved_pct"],
marker="s",
label=name,
linewidth=2,
)
plt.title("Cost of Information (COI) Preservation")
plt.xlabel("Contamination Level (α)")
plt.ylabel("COI Preserved (%)")
plt.grid(True, linestyle="--", alpha=0.7)
plt.legend()
plt.tight_layout()
plot_path = out_dir / "alpha_degradation_plot.png"
plt.savefig(plot_path, dpi=300)
print(f"Saved visualization to {plot_path}")
if __name__ == "__main__":
main()

View File

@@ -1,6 +1,6 @@
method: random method: random
metric: metric:
name: sweep/score name: objective/score
goal: maximize goal: maximize
command: command:
- ${env} - ${env}

View File

@@ -1,6 +1,6 @@
method: grid method: grid
metric: metric:
name: sweep/score name: objective/score
goal: maximize goal: maximize
run_cap: 4 run_cap: 4
command: command:

View File

@@ -1,6 +1,6 @@
method: bayes method: bayes
metric: metric:
name: sweep/score name: objective/score
goal: maximize goal: maximize
command: command:
- ${env} - ${env}

View File

@@ -1,6 +1,6 @@
method: random method: random
metric: metric:
name: sweep/score name: objective/score
goal: maximize goal: maximize
command: command:
- ${env} - ${env}

View File

@@ -1,6 +1,6 @@
method: bayes method: bayes
metric: metric:
name: sweep/score name: objective/score
goal: maximize goal: maximize
command: command:
- ${env} - ${env}

View File

@@ -1,6 +1,6 @@
method: bayes method: bayes
metric: metric:
name: sweep/score name: objective/score
goal: maximize goal: maximize
command: command:
- ${env} - ${env}

View File

@@ -0,0 +1,23 @@
from .metrics import canonicalize_metrics
from .wandb import (
current_config,
finish_run,
get_wandb_module,
init_run,
log_metrics,
run_agent,
update_run_config,
update_summary,
)
__all__ = [
"canonicalize_metrics",
"current_config",
"finish_run",
"get_wandb_module",
"init_run",
"log_metrics",
"run_agent",
"update_run_config",
"update_summary",
]

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from typing import Any, Mapping
from ..spec import TrainSpec
_ALIASES = {
"train/reward": "train/reward_mean",
"train/revenue": "train/revenue_mean",
"train/dqn_loss": "train/loss",
"eval/reward": "eval/reward_mean",
"eval/revenue": "eval/revenue_mean",
"train/steps_per_second": "runtime/steps_per_second",
}
def _as_float(value: Any, default: float | None = None) -> float | None:
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, Any]:
metrics: dict[str, Any] = {}
for key, value in raw.items():
canonical = _ALIASES.get(str(key), str(key))
if canonical in metrics and canonical != key:
continue
metrics[canonical] = value
metrics.setdefault("train/global_step", spec.runtime.total_timesteps)
eval_reward = _as_float(metrics.get("eval/reward_mean"), 0.0) or 0.0
eval_revenue = _as_float(metrics.get("eval/revenue_mean"), 0.0) or 0.0
metrics["objective/score"] = eval_reward + spec.study.revenue_weight * eval_revenue
margin_mean = _as_float(metrics.get("eval/margin_mean"), None)
if margin_mean is not None:
metrics["objective/constraint_margin"] = margin_mean - spec.env.margin_floor
coi_level = _as_float(metrics.get("eval/coi_level_mean"), None)
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
metrics["study/alpha"] = spec.study.alpha
metrics["study/lambda_coi"] = spec.study.lambda_coi
metrics["study/robust_radius"] = spec.study.robust_radius
metrics["study/info_value"] = spec.study.info_value
metrics["runtime/backend"] = spec.runtime.backend
metrics["runtime/device"] = spec.runtime.device
metrics["runtime/seed"] = spec.runtime.seed
return metrics

98
engine/telemetry/wandb.py Normal file
View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any, Callable, Iterable, Mapping
def get_wandb_module():
try:
import wandb
return wandb
except ImportError:
return None
def _require_wandb():
wandb = get_wandb_module()
if wandb is None:
raise ImportError("wandb is required for this workflow")
return wandb
def init_run(
*,
mode: str,
project: str | None = None,
config: Mapping[str, Any] | None = None,
name: str | None = None,
tags: Iterable[str] | None = None,
group: str | None = None,
sweep_mode: bool = False,
):
wandb = _require_wandb()
kwargs: dict[str, Any] = {"mode": mode}
if group:
kwargs["group"] = group
if sweep_mode:
run = wandb.init(**kwargs)
if name and run is not None:
run.name = name
return run
init_kwargs = dict(kwargs)
init_kwargs["project"] = project
if config is not None:
init_kwargs["config"] = dict(config)
if name:
init_kwargs["name"] = name
if tags:
init_kwargs["tags"] = list(tags)
return wandb.init(**init_kwargs)
def finish_run() -> None:
wandb = get_wandb_module()
if wandb is not None and wandb.run is not None:
wandb.finish()
def current_config() -> dict[str, Any]:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return {}
return {key: wandb.config[key] for key in wandb.config.keys()}
def update_run_config(config: Mapping[str, Any]) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
try:
wandb.config.update(dict(config), allow_val_change=True)
except TypeError:
wandb.config.update(dict(config))
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
wandb.log(dict(metrics), step=step)
def update_summary(metrics: Mapping[str, Any]) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
for key, value in metrics.items():
wandb.run.summary[key] = value
def run_agent(
sweep_id: str,
fn: Callable[[], None],
*,
count: int | None = None,
) -> None:
wandb = _require_wandb()
wandb.agent(sweep_id, function=fn, count=count)

View File

@@ -1,98 +1,10 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import json from typing import Any
import os
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING: from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once
from .lib.discrete import EventQTable from .spec import TrainSpec
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
try:
import wandb as _wandb
if hasattr(_wandb, "init") and callable(_wandb.init):
wandb = _wandb
HAS_WANDB = True
else:
wandb = None
HAS_WANDB = False
except ImportError:
wandb = None
HAS_WANDB = False
try:
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
HAS_SB3 = True
except ImportError:
HAS_SB3 = False
from .jax import JAX_AVAILABLE
DEFAULT_CFG = {
"project": "phantom-pricing",
"algo": "ppo",
"seed": 42,
"total_timesteps": 50_000,
"eval_episodes": 5,
"eval_freq": 1_000,
"log_freq": 100,
"revenue_weight": 0.01,
"n_products": 10,
"N": 100,
"alpha": 0.3,
"lambda_coi": 0.2,
"robust_radius": 0.15,
"robust_points": 5,
"no_robust": False,
"info_value": 1.0,
"price_low": 10.0,
"price_high": 150.0,
"action_levels": 9,
"action_scale_low": 0.8,
"action_scale_high": 1.2,
"learning_rate": 3e-4,
"gamma": 0.99,
"buffer_size": 50_000,
"batch_size": 256,
"tau": 0.005,
"train_freq": 1,
"learning_starts": 1_000,
"target_update_interval": 1_000,
"exploration_fraction": 0.2,
"exploration_final_eps": 0.05,
"n_steps": 2_048,
"n_epochs": 10,
"gae_lambda": 0.95,
"clip_range": 0.2,
"ent_coef": 0.0,
"q_lr": 0.1,
"eps_start": 1.0,
"eps_end": 0.05,
"eps_decay": 0.9995,
"model_dir": "engine/models",
"arch": "small",
"activation": "relu",
"q_bins": 6,
"max_steps": 100,
"margin_floor": 0.05,
"margin_floor_patience": 5,
"use_jax": False,
"jax_num_envs": 16,
"jax_num_steps": 128,
"jax_num_minibatches": 4,
"jax_update_epochs": 4,
"jax_anneal_lr": True,
"checkpoint_interval": 200_000,
}
def _truthy(value: str | bool | None) -> bool: def _truthy(value: str | bool | None) -> bool:
@@ -103,423 +15,133 @@ def _truthy(value: str | bool | None) -> bool:
return str(value).strip().lower() in {"1", "true", "yes", "on"} return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _cfg(raw: dict | None = None) -> dict: def _parse_tags(raw: str | None) -> list[str]:
cfg = dict(DEFAULT_CFG) if raw is None:
if raw: return []
cfg.update({k: v for k, v in raw.items() if v is not None}) return [piece.strip() for piece in str(raw).split(",") if piece.strip()]
cfg["algo"] = str(cfg["algo"]).lower()
cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy(
os.environ.get("PHANTOM_USE_JAX") def _probe_run_kind(argv: list[str]) -> str:
probe = argparse.ArgumentParser(add_help=False)
probe.add_argument("--run-kind", choices=["train", "benchmark"])
probe.add_argument("--run-mode", choices=["train", "benchmark"])
args, _ = probe.parse_known_args(argv)
return str(args.run_kind or args.run_mode or "train")
def _strip_run_kind(argv: list[str]) -> list[str]:
stripped: list[str] = []
skip_next = False
for item in argv:
if skip_next:
skip_next = False
continue
if item in {"--run-kind", "--run-mode"}:
skip_next = True
continue
if item.startswith("--run-kind=") or item.startswith("--run-mode="):
continue
stripped.append(item)
return stripped
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="PHANTOM unified training entrypoint")
parser.add_argument("--run-kind", choices=["train", "benchmark"], default="train")
parser.add_argument("--run-mode", choices=["train", "benchmark"])
parser.add_argument("--project", default="capstone")
parser.add_argument("--scenario", default="default")
parser.add_argument("--group", type=str)
parser.add_argument("--tags", type=str)
parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto")
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
parser.add_argument("--seed", type=int)
parser.add_argument("--total-timesteps", type=int)
parser.add_argument("--model-dir", type=str)
parser.add_argument("--log-freq", type=int)
parser.add_argument("--checkpoint-interval", type=int)
parser.add_argument("--device", type=str)
parser.add_argument("--alpha", type=float)
parser.add_argument("--N", type=int)
parser.add_argument("--n-products", type=int)
parser.add_argument("--lambda-coi", type=float)
parser.add_argument("--info-value", type=float)
parser.add_argument("--robust-radius", type=float)
parser.add_argument("--robust-points", type=int)
parser.add_argument("--no-robust", action="store_true")
parser.add_argument("--revenue-weight", type=float)
parser.add_argument("--price-low", type=float)
parser.add_argument("--price-high", type=float)
parser.add_argument("--action-levels", type=int)
parser.add_argument("--action-scale-low", type=float)
parser.add_argument("--action-scale-high", type=float)
parser.add_argument("--max-steps", type=int)
parser.add_argument("--margin-floor", type=float)
parser.add_argument("--margin-floor-patience", type=int)
parser.add_argument("--learning-rate", type=float)
parser.add_argument("--gamma", type=float)
parser.add_argument("--buffer-size", type=int)
parser.add_argument("--batch-size", type=int)
parser.add_argument("--tau", type=float)
parser.add_argument("--train-freq", type=int)
parser.add_argument("--learning-starts", type=int)
parser.add_argument("--target-update-interval", type=int)
parser.add_argument("--exploration-fraction", type=float)
parser.add_argument("--exploration-final-eps", type=float)
parser.add_argument("--n-steps", type=int)
parser.add_argument("--n-epochs", type=int)
parser.add_argument("--gae-lambda", type=float)
parser.add_argument("--clip-range", type=float)
parser.add_argument("--ent-coef", type=float)
parser.add_argument("--q-lr", type=float)
parser.add_argument("--q-bins", type=int)
parser.add_argument("--eps-start", type=float)
parser.add_argument("--eps-end", type=float)
parser.add_argument("--eps-decay", type=float)
parser.add_argument("--arch", type=str)
parser.add_argument("--activation", type=str)
parser.add_argument("--vf-coef", type=float)
parser.add_argument("--max-grad-norm", type=float)
parser.add_argument("--eval-freq", type=int)
parser.add_argument("--eval-episodes", type=int)
parser.add_argument("--jax", action="store_true")
parser.add_argument("--jax-num-envs", type=int)
parser.add_argument("--jax-num-steps", type=int)
parser.add_argument("--jax-num-minibatches", type=int)
parser.add_argument("--jax-update-epochs", type=int)
parser.add_argument("--jax-anneal-lr", type=str)
parser.add_argument("--sweep-agent", action="store_true")
parser.add_argument("--sweep-id", type=str)
parser.add_argument("--count", type=int, default=0)
parser.add_argument("--offline", action="store_true")
parser.add_argument("--no-wandb", action="store_true")
return parser
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
jax_anneal_lr = (
_truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None
) )
cfg["no_robust"] = _truthy(cfg.get("no_robust")) backend = None if args.backend == "auto" else args.backend
if cfg["no_robust"]:
cfg["lambda_coi"] = 0.0
cfg["robust_radius"] = 0.0
cfg["robust_points"] = 1
return cfg
def _wandb_cfg_dict() -> dict:
return (
{k: wandb.config[k] for k in wandb.config.keys()}
if HAS_WANDB and wandb.run
else {}
)
def make_env(cfg: dict):
from gymnasium.wrappers import FlattenObservation
from .wrapper import PHANTOM
from .lib.wrappers import EconomicMetricsWrapper
env = PHANTOM(
n_products=int(cfg["n_products"]),
alpha=float(cfg["alpha"]),
N=int(cfg["N"]),
price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])),
lambda_coi=float(cfg["lambda_coi"]),
robust_radius=float(cfg["robust_radius"]),
robust_points=int(cfg["robust_points"]),
info_value=float(cfg["info_value"]),
action_levels=int(cfg["action_levels"]),
action_scale_low=float(cfg["action_scale_low"]),
action_scale_high=float(cfg["action_scale_high"]),
max_steps=int(cfg.get("max_steps", 100)),
margin_floor=float(cfg.get("margin_floor", 0.05)),
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
render_mode=None,
)
env = EconomicMetricsWrapper(env)
env = FlattenObservation(env)
return env
def _net_arch(name) -> list[int]:
presets = {
"tiny": [32, 32],
"small": [64, 64],
"medium": [128, 128],
"large": [256, 256],
}
if isinstance(name, (list, tuple)):
return [int(v) for v in name]
s = str(name).lower().strip()
if s in presets:
return presets[s]
if "x" in s:
try:
vals = [int(v) for v in s.split("x") if v]
return vals if vals else presets["small"]
except ValueError:
return presets["small"]
return presets["small"]
def _activation(name):
try:
import torch.nn as nn
except ImportError:
return None
return {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"elu": nn.ELU,
"leaky_relu": nn.LeakyReLU,
}.get(str(name).lower().strip(), nn.ReLU)
def _policy_kwargs(cfg: dict) -> dict:
kw = {"net_arch": _net_arch(cfg.get("arch", "small"))}
act = _activation(cfg.get("activation", "relu"))
if act is not None:
kw["activation_fn"] = act
return kw
def _action(agent, obs, deterministic: bool = True):
out = agent.predict(obs, deterministic=deterministic)
a = out[0] if isinstance(out, tuple) else out
if isinstance(a, np.ndarray) and a.size == 1:
return int(a.reshape(-1)[0])
return a
def evaluate(agent, env, episodes: int) -> dict:
rewards, revenues = [], []
for _ in range(int(episodes)):
obs, _ = env.reset()
done, ep_r, ep_rev = False, 0.0, 0.0
while not done:
obs, reward, term, trunc, info = env.step(_action(agent, obs, True))
done = term or trunc
ep_r += float(reward)
ep_rev += float(
info.get("economics", {}).get("revenue", info.get("revenue", 0.0))
)
rewards.append(ep_r)
revenues.append(ep_rev)
return {
"eval/reward": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)),
}
def build_model(cfg: dict, env):
algo = cfg["algo"]
policy_kwargs = _policy_kwargs(cfg)
if algo == "sac":
raise ValueError("sac is not supported with the discrete core env")
if algo == "ppo":
return PPO(
"MlpPolicy",
env,
verbose=1,
policy_kwargs=policy_kwargs,
seed=int(cfg["seed"]),
learning_rate=float(cfg["learning_rate"]),
n_steps=int(cfg["n_steps"]),
batch_size=int(cfg["batch_size"]),
n_epochs=int(cfg["n_epochs"]),
gamma=float(cfg["gamma"]),
gae_lambda=float(cfg["gae_lambda"]),
clip_range=float(cfg["clip_range"]),
ent_coef=float(cfg["ent_coef"]),
)
if algo == "a2c":
return A2C(
"MlpPolicy",
env,
verbose=1,
policy_kwargs=policy_kwargs,
seed=int(cfg["seed"]),
learning_rate=float(cfg["learning_rate"]),
n_steps=max(5, int(cfg["n_steps"]) // 32),
gamma=float(cfg["gamma"]),
gae_lambda=float(cfg["gae_lambda"]),
ent_coef=float(cfg["ent_coef"]),
)
if algo == "dqn":
return DQN(
"MlpPolicy",
env,
verbose=1,
policy_kwargs=policy_kwargs,
seed=int(cfg["seed"]),
learning_rate=float(cfg["learning_rate"]),
buffer_size=int(cfg["buffer_size"]),
batch_size=int(cfg["batch_size"]),
gamma=float(cfg["gamma"]),
train_freq=int(cfg["train_freq"]),
learning_starts=int(cfg["learning_starts"]),
target_update_interval=int(cfg["target_update_interval"]),
exploration_fraction=float(cfg["exploration_fraction"]),
exploration_final_eps=float(cfg["exploration_final_eps"]),
)
raise ValueError(f"unsupported algo '{algo}'")
def _sb3_model_cls(algo: str):
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def train_qtable(cfg: dict) -> tuple["EventQTable", dict]:
from .lib.discrete import EventQTable
np.random.seed(int(cfg["seed"]))
env = make_env(cfg)
eval_env = make_env(cfg)
agent = EventQTable(
env.action_space.n,
int(cfg["n_products"]),
(float(cfg["price_low"]), float(cfg["price_high"])),
lr=float(cfg["q_lr"]),
gamma=float(cfg["gamma"]),
n_bins=int(cfg["q_bins"]),
)
eps = float(cfg["eps_start"])
obs, _ = env.reset(seed=int(cfg["seed"]))
for t in range(int(cfg["total_timesteps"])):
a, s = agent.act(obs, eps)
nxt, reward, term, trunc, info = env.step(a)
done = term or trunc
agent.update(s, a, float(reward), agent.encode(nxt), done)
eps = max(float(cfg["eps_end"]), eps * float(cfg["eps_decay"]))
if HAS_WANDB and wandb.run and (t + 1) % int(cfg["log_freq"]) == 0:
econ = info.get("economics", {})
wandb.log(
{
"train/reward": float(reward),
"train/revenue": float(econ.get("revenue", 0.0)),
"train/epsilon": float(eps),
},
step=t + 1,
)
obs = env.reset()[0] if done else nxt
metrics = evaluate(agent, eval_env, int(cfg["eval_episodes"]))
metrics["train/global_step"] = int(cfg["total_timesteps"])
env.close()
eval_env.close()
return agent, metrics
def train_sb3(cfg: dict) -> tuple[object, dict]:
if not HAS_SB3:
raise ImportError("stable-baselines3 is required for SB3 models")
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
env = make_env(cfg)
eval_env = make_env(cfg)
env = Monitor(env)
eval_env = Monitor(eval_env)
model = build_model(cfg, env)
resume_step = 0
if HAS_WANDB and wandb.run is not None:
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is not None:
checkpoint_path, metadata = restored
model = _sb3_model_cls(cfg["algo"]).load(
checkpoint_path.as_posix(), env=env
)
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
model.num_timesteps = max(
int(getattr(model, "num_timesteps", 0)), resume_step
)
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
cbs.append(
CheckpointArtifactCallback(
cfg,
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
)
cbs.append(
EvalCallback(
eval_env,
eval_freq=int(cfg["eval_freq"]),
n_eval_episodes=int(cfg["eval_episodes"]),
deterministic=True,
verbose=0,
)
)
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
if remaining_steps > 0:
model.learn(
total_timesteps=remaining_steps,
callback=cbs,
reset_num_timesteps=False,
)
model_path = Path(cfg["model_dir"])
model_path.mkdir(parents=True, exist_ok=True)
model.save(str(model_path / f"phantom_{cfg['algo']}"))
metrics = evaluate(model, eval_env, int(cfg["eval_episodes"]))
metrics["train/global_step"] = int(model.num_timesteps)
env.close()
eval_env.close()
return model, metrics
def train_once(cfg: dict) -> dict:
algo = cfg["algo"]
if cfg.get("use_jax"):
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
try:
from .jax.train import train_jax
except Exception as exc: # pragma: no cover
raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
_, metrics = train_jax(cfg)
elif algo == "qtable":
_, metrics = train_qtable(cfg)
else:
_, metrics = train_sb3(cfg)
metrics["sweep/score"] = float(
metrics["eval/reward"] + float(cfg["revenue_weight"]) * metrics["eval/revenue"]
)
return metrics
def run_wandb(
project: str, overrides: dict, mode: str = "online", sweep_mode: bool = False
) -> dict:
if not HAS_WANDB:
raise ImportError("wandb is required for sweep runs")
if not sweep_mode:
pre_cfg = _cfg(overrides)
if pre_cfg.get("use_jax"):
try:
import jax
if jax.process_count() > 1 and jax.process_index() != 0:
return train_once(pre_cfg)
except Exception:
pass
init_kwargs = {"mode": mode}
if sweep_mode:
run = wandb.init(**init_kwargs)
else:
run = wandb.init(project=project, config=overrides, **init_kwargs)
try:
cfg = _cfg(_wandb_cfg_dict())
if sweep_mode:
for k, v in overrides.items():
if k not in wandb.config:
cfg[k] = v
metrics = train_once(cfg)
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
wandb.log(metrics, step=step)
for k, v in metrics.items():
run.summary[k] = v
return metrics
finally:
if wandb.run is not None:
wandb.finish()
def run_local(overrides: dict) -> dict:
cfg = _cfg(overrides)
metrics = train_once(cfg)
should_print = True
if cfg.get("use_jax"):
try:
import jax
should_print = jax.process_index() == 0
except Exception:
should_print = True
if should_print:
print(json.dumps(metrics, indent=2))
# sentinel line for machine-readable extraction; must stay on one line
print("PHANTOM_METRICS:" + json.dumps(metrics))
return metrics
def main():
p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
p.add_argument("--project", default=DEFAULT_CFG["project"])
p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
p.add_argument("--seed", type=int)
p.add_argument("--total-timesteps", type=int)
p.add_argument("--alpha", type=float)
p.add_argument("--N", type=int)
p.add_argument("--n-products", type=int)
p.add_argument("--lambda-coi", type=float)
p.add_argument("--info-value", type=float)
p.add_argument("--robust-radius", type=float)
p.add_argument("--robust-points", type=int)
p.add_argument("--no-robust", action="store_true")
p.add_argument("--learning-rate", type=float)
p.add_argument("--gamma", type=float)
p.add_argument("--gae-lambda", type=float)
p.add_argument("--clip-range", type=float)
p.add_argument("--ent-coef", type=float)
p.add_argument("--revenue-weight", type=float)
p.add_argument("--price-low", type=float)
p.add_argument("--price-high", type=float)
p.add_argument("--action-levels", type=int)
p.add_argument("--action-scale-low", type=float)
p.add_argument("--action-scale-high", type=float)
p.add_argument("--max-steps", type=int)
p.add_argument("--margin-floor", type=float)
p.add_argument("--margin-floor-patience", type=int)
p.add_argument("--arch", type=str)
p.add_argument("--activation", type=str)
p.add_argument("--jax", action="store_true")
p.add_argument("--jax-num-envs", type=int)
p.add_argument("--jax-num-steps", type=int)
p.add_argument("--jax-num-minibatches", type=int)
p.add_argument("--jax-update-epochs", type=int)
p.add_argument("--jax-anneal-lr", type=str)
p.add_argument("--checkpoint-interval", type=int)
p.add_argument("--sweep-agent", action="store_true")
p.add_argument("--sweep-id", type=str)
p.add_argument("--count", type=int, default=0)
p.add_argument("--offline", action="store_true")
p.add_argument("--no-wandb", action="store_true")
args = p.parse_args()
overrides = { overrides = {
"project": args.project,
"backend": backend,
"algo": args.algo, "algo": args.algo,
"seed": args.seed, "seed": args.seed,
"total_timesteps": args.total_timesteps, "total_timesteps": args.total_timesteps,
"model_dir": args.model_dir,
"log_freq": args.log_freq,
"checkpoint_interval": args.checkpoint_interval,
"device": args.device,
"alpha": args.alpha, "alpha": args.alpha,
"N": args.N, "N": args.N,
"n_products": args.n_products, "n_products": args.n_products,
@@ -528,11 +150,6 @@ def main():
"robust_radius": args.robust_radius, "robust_radius": args.robust_radius,
"robust_points": args.robust_points, "robust_points": args.robust_points,
"no_robust": args.no_robust, "no_robust": args.no_robust,
"learning_rate": args.learning_rate,
"gamma": args.gamma,
"gae_lambda": args.gae_lambda,
"clip_range": args.clip_range,
"ent_coef": args.ent_coef,
"revenue_weight": args.revenue_weight, "revenue_weight": args.revenue_weight,
"price_low": args.price_low, "price_low": args.price_low,
"price_high": args.price_high, "price_high": args.price_high,
@@ -542,40 +159,87 @@ def main():
"max_steps": args.max_steps, "max_steps": args.max_steps,
"margin_floor": args.margin_floor, "margin_floor": args.margin_floor,
"margin_floor_patience": args.margin_floor_patience, "margin_floor_patience": args.margin_floor_patience,
"learning_rate": args.learning_rate,
"gamma": args.gamma,
"buffer_size": args.buffer_size,
"batch_size": args.batch_size,
"tau": args.tau,
"train_freq": args.train_freq,
"learning_starts": args.learning_starts,
"target_update_interval": args.target_update_interval,
"exploration_fraction": args.exploration_fraction,
"exploration_final_eps": args.exploration_final_eps,
"n_steps": args.n_steps,
"n_epochs": args.n_epochs,
"gae_lambda": args.gae_lambda,
"clip_range": args.clip_range,
"ent_coef": args.ent_coef,
"q_lr": args.q_lr,
"q_bins": args.q_bins,
"eps_start": args.eps_start,
"eps_end": args.eps_end,
"eps_decay": args.eps_decay,
"arch": args.arch, "arch": args.arch,
"activation": args.activation, "activation": args.activation,
"use_jax": args.jax, "vf_coef": args.vf_coef,
"max_grad_norm": args.max_grad_norm,
"eval_freq": args.eval_freq,
"eval_episodes": args.eval_episodes,
"use_jax": args.jax or None,
"jax_num_envs": args.jax_num_envs, "jax_num_envs": args.jax_num_envs,
"jax_num_steps": args.jax_num_steps, "jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches, "jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs, "jax_update_epochs": args.jax_update_epochs,
"checkpoint_interval": args.checkpoint_interval, "jax_anneal_lr": jax_anneal_lr,
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
if args.jax_anneal_lr is not None
else None,
} }
overrides = {k: v for k, v in overrides.items() if v is not None} return {key: value for key, value in overrides.items() if value is not None}
def main(argv: list[str] | None = None) -> None:
import sys
raw_args = list(sys.argv[1:] if argv is None else argv)
run_kind = _probe_run_kind(raw_args)
if run_kind == "benchmark":
run_benchmark_cli(_strip_run_kind(raw_args))
return
parser = _build_parser()
args, unknown = parser.parse_known_args(raw_args)
if unknown:
raise ValueError(f"Unknown arguments for training mode: {' '.join(unknown)}")
overrides = _overrides_from_args(args)
scenario = str(args.scenario)
group = args.group
extra_tags = tuple(_parse_tags(args.tags))
if args.sweep_agent: if args.sweep_agent:
if args.no_wandb: run_sweep_agent(
raise ValueError("sweep agent requires wandb") project=args.project,
if not args.sweep_id: sweep_id=str(args.sweep_id or ""),
raise ValueError("--sweep-id is required with --sweep-agent") count=int(args.count),
mode = "offline" if args.offline else "online" offline=bool(args.offline),
wandb.agent( no_wandb=bool(args.no_wandb),
args.sweep_id, base_overrides=overrides,
function=lambda: run_wandb( kind="sweep",
args.project, overrides, mode=mode, sweep_mode=True scenario=scenario,
), group=group,
count=args.count if args.count > 0 else None, extra_tags=extra_tags,
) )
return return
if args.no_wandb or not HAS_WANDB: spec = TrainSpec.from_flat(overrides)
run_local(overrides) run_train_once(
return spec,
project=args.project,
run_wandb(args.project, overrides, mode="offline" if args.offline else "online") offline=bool(args.offline),
no_wandb=bool(args.no_wandb),
kind="train",
scenario=scenario,
group=group,
extra_tags=extra_tags,
)
if __name__ == "__main__": if __name__ == "__main__":

40
engine/train_core.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from .spec import TrainSpec
from .telemetry.metrics import canonicalize_metrics
@dataclass(frozen=True)
class TrainResult:
spec: TrainSpec
metrics: dict[str, Any]
artifacts: dict[str, str]
def run_train(spec: TrainSpec) -> TrainResult:
cfg = spec.to_flat_dict()
algo = spec.algorithm.name
if spec.runtime.use_jax or spec.runtime.backend == "jax":
from .backends.jax import train_jax_backend
_, raw_metrics = train_jax_backend(cfg)
elif algo == "qtable":
from .backends.qtable import train_qtable
_, raw_metrics = train_qtable(cfg)
else:
from .backends.sb3 import train_sb3
_, raw_metrics = train_sb3(cfg)
metrics = canonicalize_metrics(raw_metrics, spec)
artifacts: dict[str, str] = {}
model_path = raw_metrics.get("model/path")
if isinstance(model_path, str):
artifacts["model/path"] = model_path
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)

View File

@@ -381,7 +381,7 @@ if __name__ == "__main__":
def predict(self, obs, **kwargs): def predict(self, obs, **kwargs):
return self.env.action_space.sample(), None return self.env.action_space.sample(), None
wandb.init(project="phantom-pricing", config={"policy": "random", "alpha": 0.3}) wandb.init(project="capstone", config={"policy": "random", "alpha": 0.3})
env = EconomicMetricsWrapper(PHANTOM(n_products=15, alpha=0.3, render_mode=None)) env = EconomicMetricsWrapper(PHANTOM(n_products=15, alpha=0.3, render_mode=None))
model = RandomPolicy(env) model = RandomPolicy(env)

View File

@@ -55,6 +55,9 @@
"train": { "train": {
"cache": false "cache": false
}, },
"benchmark": {
"cache": false
},
"up": { "up": {
"cache": false "cache": false
}, },

View File

@@ -19,6 +19,7 @@
"platform:down": "nx run platform:down", "platform:down": "nx run platform:down",
"platform:logs": "nx run platform:logs", "platform:logs": "nx run platform:logs",
"research:test": "nx run research:test", "research:test": "nx run research:test",
"research:benchmark": "nx run research:benchmark",
"e2e:test": "nx run e2e:test" "e2e:test": "nx run e2e:test"
}, },
"devDependencies": { "devDependencies": {

View File

@@ -30,10 +30,20 @@ case "$cmd" in
load_sweep_env load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
WANDB_ENTITY="${WANDB_ENTITY:-}" \ WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \ WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="$WANDB_API_KEY" \ WANDB_API_KEY="$WANDB_API_KEY" \
.venv/bin/python -m engine.train ${LOCAL_TRAIN_ARGS:---algo ppo --total-timesteps 50000} .venv/bin/python -m engine.train ${LOCAL_TRAIN_ARGS:---algo ppo --total-timesteps 50000}
;; ;;
benchmark)
load_sweep_env
if [[ " ${LOCAL_BENCHMARK_ARGS:-} " != *" --no-wandb "* ]]; then
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
fi
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="${WANDB_API_KEY:-}" \
.venv/bin/python -m engine.train --run-kind benchmark ${LOCAL_BENCHMARK_ARGS:---tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu}
;;
train-agent) train-agent)
load_sweep_env load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
@@ -43,10 +53,23 @@ case "$cmd" in
args+=(--count "$AGENT_COUNT") args+=(--count "$AGENT_COUNT")
fi fi
WANDB_ENTITY="${WANDB_ENTITY:-}" \ WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \ WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="$WANDB_API_KEY" \ WANDB_API_KEY="$WANDB_API_KEY" \
.venv/bin/python -m engine.train "${args[@]}" .venv/bin/python -m engine.train "${args[@]}"
;; ;;
benchmark-agent)
load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
args=(--sweep-agent --sweep-id "$SWEEP_ID")
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
args+=(--count "$AGENT_COUNT")
fi
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="$WANDB_API_KEY" \
.venv/bin/python -m engine.train --run-kind benchmark "${args[@]}" ${BENCHMARK_AGENT_ARGS:-}
;;
train-bootstrap) train-bootstrap)
load_sweep_env load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
@@ -55,7 +78,7 @@ case "$cmd" in
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
WANDB_API_KEY="$WANDB_API_KEY" \ WANDB_API_KEY="$WANDB_API_KEY" \
WANDB_ENTITY="${WANDB_ENTITY:-}" \ WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \ WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
GITHUB_TOKEN="$GITHUB_TOKEN" \ GITHUB_TOKEN="$GITHUB_TOKEN" \
REPO_URL="$REPO_URL" \ REPO_URL="$REPO_URL" \
BRANCH="${BRANCH:-main}" \ BRANCH="${BRANCH:-main}" \
@@ -115,7 +138,7 @@ PY
train-tpu-vm-sweep) train-tpu-vm-sweep)
load_sweep_env load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong" require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123" require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
args=( args=(
--sweep-id "$SWEEP_ID" --sweep-id "$SWEEP_ID"

View File

@@ -96,7 +96,11 @@ def _extract_metrics(output: str) -> dict:
obj = json.loads(block) obj = json.loads(block)
except Exception: except Exception:
continue continue
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj): if isinstance(obj, dict) and (
"objective/score" in obj
or "eval/reward_mean" in obj
or "sweep/score" in obj
):
return obj return obj
return {} return {}