mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactoring training spc setup and benchmarking
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Required for wandb runs and sweep agent workers.
|
||||
WANDB_API_KEY=
|
||||
WANDB_ENTITY=
|
||||
WANDB_PROJECT=phantom-pricing
|
||||
WANDB_PROJECT=capstone
|
||||
|
||||
# Required for private repo bootstrap workers.
|
||||
GITHUB_TOKEN=
|
||||
@@ -16,3 +16,7 @@ GITHUB_TOKEN=
|
||||
# AGENT_COUNT=0
|
||||
# AGENT_LOOP=1
|
||||
# RETRY_SECONDS=20
|
||||
|
||||
# Optional local benchmark defaults.
|
||||
# LOCAL_BENCHMARK_ARGS=--tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu
|
||||
# BENCHMARK_AGENT_ARGS=--tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3,0.6 --episodes 5
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -68,6 +68,7 @@ sim/case/thesis_simplified/runs*/
|
||||
|
||||
# model binaries
|
||||
engine/models/*.zip
|
||||
engine/studies/results/*
|
||||
*.zip
|
||||
|
||||
# wandb local state
|
||||
|
||||
17
Makefile
17
Makefile
@@ -13,9 +13,11 @@ NX := npx nx
|
||||
SWEEP_ENV_FILE ?= .env.sweep
|
||||
|
||||
WANDB_ENTITY ?=
|
||||
WANDB_PROJECT ?= phantom-pricing
|
||||
WANDB_PROJECT ?= capstone
|
||||
SWEEP_ID ?=
|
||||
LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000
|
||||
LOCAL_BENCHMARK_ARGS ?= --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu
|
||||
BENCHMARK_AGENT_ARGS ?=
|
||||
AGENT_COUNT ?= 0
|
||||
|
||||
REPO_URL ?=
|
||||
@@ -36,7 +38,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
|
||||
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
|
||||
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
|
||||
@echo ""
|
||||
@echo "Build general public version:"
|
||||
@@ -45,6 +47,9 @@ help:
|
||||
@echo "Local wandb run:"
|
||||
@echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'"
|
||||
@echo ""
|
||||
@echo "Local benchmark run:"
|
||||
@echo " make benchmark LOCAL_BENCHMARK_ARGS='--tiers static,surge,linear --alpha-values 0.0,0.3 --episodes 3 --no-wandb'"
|
||||
@echo ""
|
||||
@echo "Local sweep agent from this repo:"
|
||||
@echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5"
|
||||
@echo ""
|
||||
@@ -104,6 +109,14 @@ install:
|
||||
train:
|
||||
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train
|
||||
|
||||
.PHONY: benchmark
|
||||
benchmark:
|
||||
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_BENCHMARK_ARGS="$(LOCAL_BENCHMARK_ARGS)" $(NX) run research:benchmark
|
||||
|
||||
.PHONY: benchmark.agent
|
||||
benchmark.agent:
|
||||
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" BENCHMARK_AGENT_ARGS="$(BENCHMARK_AGENT_ARGS)" $(NX) run research:benchmark-agent
|
||||
|
||||
.PHONY: train.agent
|
||||
train.agent:
|
||||
@WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-agent
|
||||
|
||||
1
engine/backends/__init__.py
Normal file
1
engine/backends/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"]
|
||||
81
engine/backends/common.py
Normal file
81
engine/backends/common.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_env(cfg: Mapping[str, Any]):
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
|
||||
from ..lib.wrappers import EconomicMetricsWrapper
|
||||
from ..wrapper import PHANTOM
|
||||
|
||||
env = PHANTOM(
|
||||
n_products=int(cfg["n_products"]),
|
||||
alpha=float(cfg["alpha"]),
|
||||
N=int(cfg["N"]),
|
||||
price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])),
|
||||
lambda_coi=float(cfg["lambda_coi"]),
|
||||
robust_radius=float(cfg["robust_radius"]),
|
||||
robust_points=int(cfg["robust_points"]),
|
||||
info_value=float(cfg["info_value"]),
|
||||
action_levels=int(cfg["action_levels"]),
|
||||
action_scale_low=float(cfg["action_scale_low"]),
|
||||
action_scale_high=float(cfg["action_scale_high"]),
|
||||
max_steps=int(cfg.get("max_steps", 100)),
|
||||
margin_floor=float(cfg.get("margin_floor", 0.05)),
|
||||
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
|
||||
render_mode=None,
|
||||
)
|
||||
env = EconomicMetricsWrapper(env)
|
||||
return FlattenObservation(env)
|
||||
|
||||
|
||||
def _action(agent: Any, obs: Any, deterministic: bool = True):
|
||||
out = agent.predict(obs, deterministic=deterministic)
|
||||
action = out[0] if isinstance(out, tuple) else out
|
||||
if isinstance(action, np.ndarray) and action.size == 1:
|
||||
return int(action.reshape(-1)[0])
|
||||
return action
|
||||
|
||||
|
||||
def evaluate(agent: Any, env: Any, episodes: int) -> dict[str, float]:
|
||||
rewards: list[float] = []
|
||||
revenues: list[float] = []
|
||||
margins: list[float] = []
|
||||
coi_levels: list[float] = []
|
||||
|
||||
for _ in range(int(episodes)):
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
ep_reward = 0.0
|
||||
ep_revenue = 0.0
|
||||
ep_margin = 0.0
|
||||
ep_coi = 0.0
|
||||
steps = 0
|
||||
|
||||
while not done:
|
||||
obs, reward, term, trunc, info = env.step(_action(agent, obs, True))
|
||||
done = bool(term or trunc)
|
||||
econ = info.get("economics", {})
|
||||
ep_reward += float(reward)
|
||||
ep_revenue += float(econ.get("revenue", info.get("revenue", 0.0)))
|
||||
ep_margin += float(econ.get("margin", 0.0))
|
||||
ep_coi += float(econ.get("coi_level", 0.0))
|
||||
steps += 1
|
||||
|
||||
rewards.append(ep_reward)
|
||||
revenues.append(ep_revenue)
|
||||
denom = max(steps, 1)
|
||||
margins.append(ep_margin / denom)
|
||||
coi_levels.append(ep_coi / denom)
|
||||
|
||||
return {
|
||||
"eval/reward_mean": float(np.mean(rewards)) if rewards else 0.0,
|
||||
"eval/reward_std": float(np.std(rewards)) if rewards else 0.0,
|
||||
"eval/revenue_mean": float(np.mean(revenues)) if revenues else 0.0,
|
||||
"eval/revenue_std": float(np.std(revenues)) if revenues else 0.0,
|
||||
"eval/margin_mean": float(np.mean(margins)) if margins else 0.0,
|
||||
"eval/coi_level_mean": float(np.mean(coi_levels)) if coi_levels else 0.0,
|
||||
}
|
||||
18
engine/backends/jax.py
Normal file
18
engine/backends/jax.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..jax import JAX_AVAILABLE
|
||||
|
||||
|
||||
def train_jax_backend(
|
||||
cfg: Mapping[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
|
||||
if not JAX_AVAILABLE:
|
||||
raise ImportError(
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
from ..jax.train import train_jax
|
||||
|
||||
return train_jax(dict(cfg))
|
||||
53
engine/backends/qtable.py
Normal file
53
engine/backends/qtable.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .common import evaluate, make_env
|
||||
|
||||
|
||||
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]:
|
||||
from ..lib.discrete import EventQTable
|
||||
|
||||
np.random.seed(int(cfg["seed"]))
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
agent = EventQTable(
|
||||
env.action_space.n,
|
||||
int(cfg["n_products"]),
|
||||
(float(cfg["price_low"]), float(cfg["price_high"])),
|
||||
lr=float(cfg["q_lr"]),
|
||||
gamma=float(cfg["gamma"]),
|
||||
n_bins=int(cfg["q_bins"]),
|
||||
)
|
||||
|
||||
total_reward = 0.0
|
||||
total_revenue = 0.0
|
||||
steps = 0
|
||||
epsilon = float(cfg["eps_start"])
|
||||
obs, _ = env.reset(seed=int(cfg["seed"]))
|
||||
|
||||
for _ in range(int(cfg["total_timesteps"])):
|
||||
action, state = agent.act(obs, epsilon)
|
||||
nxt, reward, term, trunc, info = env.step(action)
|
||||
done = bool(term or trunc)
|
||||
agent.update(state, action, float(reward), agent.encode(nxt), done)
|
||||
|
||||
total_reward += float(reward)
|
||||
total_revenue += float(info.get("economics", {}).get("revenue", 0.0))
|
||||
steps += 1
|
||||
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
|
||||
obs = env.reset()[0] if done else nxt
|
||||
|
||||
metrics: dict[str, float | int] = {
|
||||
"train/reward_mean": total_reward / max(steps, 1),
|
||||
"train/revenue_mean": total_revenue / max(steps, 1),
|
||||
"train/epsilon": float(epsilon),
|
||||
"train/global_step": int(cfg["total_timesteps"]),
|
||||
}
|
||||
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
|
||||
|
||||
env.close()
|
||||
eval_env.close()
|
||||
return agent, metrics
|
||||
228
engine/backends/sb3.py
Normal file
228
engine/backends/sb3.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
||||
from ..telemetry.wandb import get_wandb_module
|
||||
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
from .common import evaluate, make_env
|
||||
|
||||
|
||||
def _net_arch(name: Any) -> list[int]:
|
||||
presets = {
|
||||
"tiny": [32, 32],
|
||||
"small": [64, 64],
|
||||
"medium": [128, 128],
|
||||
"large": [256, 256],
|
||||
}
|
||||
if isinstance(name, (list, tuple)):
|
||||
return [int(v) for v in name]
|
||||
raw = str(name).lower().strip()
|
||||
if raw in presets:
|
||||
return presets[raw]
|
||||
if "x" in raw:
|
||||
try:
|
||||
parsed = [int(v) for v in raw.split("x") if v]
|
||||
return parsed if parsed else presets["small"]
|
||||
except ValueError:
|
||||
return presets["small"]
|
||||
return presets["small"]
|
||||
|
||||
|
||||
def _activation(name: Any):
|
||||
try:
|
||||
import torch.nn as nn
|
||||
except ImportError:
|
||||
return None
|
||||
return {
|
||||
"relu": nn.ReLU,
|
||||
"tanh": nn.Tanh,
|
||||
"elu": nn.ELU,
|
||||
"leaky_relu": nn.LeakyReLU,
|
||||
}.get(str(name).lower().strip(), nn.ReLU)
|
||||
|
||||
|
||||
def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {"net_arch": _net_arch(cfg.get("arch", "small"))}
|
||||
activation = _activation(cfg.get("activation", "relu"))
|
||||
if activation is not None:
|
||||
kwargs["activation_fn"] = activation
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sb3_model_cls(algo: str):
|
||||
try:
|
||||
from stable_baselines3 import A2C, DQN, PPO
|
||||
except ImportError as exc:
|
||||
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
|
||||
|
||||
if algo == "ppo":
|
||||
return PPO
|
||||
if algo == "a2c":
|
||||
return A2C
|
||||
if algo == "dqn":
|
||||
return DQN
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def build_model(cfg: Mapping[str, Any], env: Any):
|
||||
try:
|
||||
from stable_baselines3 import A2C, DQN, PPO
|
||||
except ImportError as exc:
|
||||
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
|
||||
|
||||
algo = str(cfg["algo"])
|
||||
policy_kwargs = _policy_kwargs(cfg)
|
||||
device = str(cfg.get("device", "auto"))
|
||||
seed = int(cfg["seed"])
|
||||
|
||||
if algo == "sac":
|
||||
raise ValueError("sac is not supported with the discrete core env")
|
||||
if algo == "ppo":
|
||||
return PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
device=device,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=seed,
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
n_steps=int(cfg["n_steps"]),
|
||||
batch_size=int(cfg["batch_size"]),
|
||||
n_epochs=int(cfg["n_epochs"]),
|
||||
gamma=float(cfg["gamma"]),
|
||||
gae_lambda=float(cfg["gae_lambda"]),
|
||||
clip_range=float(cfg["clip_range"]),
|
||||
ent_coef=float(cfg["ent_coef"]),
|
||||
)
|
||||
if algo == "a2c":
|
||||
return A2C(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
device=device,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=seed,
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
n_steps=max(5, int(cfg["n_steps"]) // 32),
|
||||
gamma=float(cfg["gamma"]),
|
||||
gae_lambda=float(cfg["gae_lambda"]),
|
||||
ent_coef=float(cfg["ent_coef"]),
|
||||
)
|
||||
if algo == "dqn":
|
||||
return DQN(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
device=device,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=seed,
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
buffer_size=int(cfg["buffer_size"]),
|
||||
batch_size=int(cfg["batch_size"]),
|
||||
gamma=float(cfg["gamma"]),
|
||||
train_freq=int(cfg["train_freq"]),
|
||||
learning_starts=int(cfg["learning_starts"]),
|
||||
target_update_interval=int(cfg["target_update_interval"]),
|
||||
exploration_fraction=float(cfg["exploration_fraction"]),
|
||||
exploration_final_eps=float(cfg["exploration_final_eps"]),
|
||||
)
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any):
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return model
|
||||
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
||||
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
||||
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
||||
if restored is None:
|
||||
return model
|
||||
|
||||
checkpoint_path, metadata = restored
|
||||
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
|
||||
checkpoint_path.as_posix(),
|
||||
env=env,
|
||||
)
|
||||
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
|
||||
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
|
||||
return resumed
|
||||
|
||||
|
||||
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
|
||||
try:
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
except ImportError as exc:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models") from exc
|
||||
|
||||
env = Monitor(make_env(cfg))
|
||||
eval_env = Monitor(make_env(cfg))
|
||||
model = build_model(cfg, env)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
print(
|
||||
"PHANTOM_DEVICE: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"requested": str(cfg.get("device", "auto")),
|
||||
"torch_cuda_available": bool(torch.cuda.is_available()),
|
||||
"torch_device_count": int(torch.cuda.device_count()),
|
||||
"sb3_device": str(getattr(model, "device", "unknown")),
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model = _maybe_resume_model(cfg, env, model)
|
||||
|
||||
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
|
||||
callbacks.append(
|
||||
CheckpointArtifactCallback(
|
||||
dict(cfg),
|
||||
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
||||
)
|
||||
)
|
||||
callbacks.append(
|
||||
EvalCallback(
|
||||
eval_env,
|
||||
eval_freq=int(cfg["eval_freq"]),
|
||||
n_eval_episodes=int(cfg["eval_episodes"]),
|
||||
deterministic=True,
|
||||
verbose=0,
|
||||
)
|
||||
)
|
||||
|
||||
target_steps = int(cfg["total_timesteps"])
|
||||
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||
if remaining_steps > 0:
|
||||
model.learn(
|
||||
total_timesteps=remaining_steps,
|
||||
callback=callbacks,
|
||||
reset_num_timesteps=False,
|
||||
)
|
||||
|
||||
model_dir = Path(str(cfg["model_dir"]))
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path = model_dir / f"phantom_{cfg['algo']}"
|
||||
model.save(str(model_path))
|
||||
|
||||
metrics: dict[str, float | int | str] = evaluate(
|
||||
model,
|
||||
eval_env,
|
||||
int(cfg["eval_episodes"]),
|
||||
)
|
||||
metrics["train/global_step"] = int(model.num_timesteps)
|
||||
metrics["model/path"] = str(model_path.with_suffix(".zip"))
|
||||
|
||||
env.close()
|
||||
eval_env.close()
|
||||
return model, metrics
|
||||
456
engine/benchmark.py
Normal file
456
engine/benchmark.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, UTC
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .lib.tiers import LinearElasticityPolicy, StaticPolicy, SurgePolicy
|
||||
from .spec import TrainSpec
|
||||
from .telemetry.wandb import get_wandb_module
|
||||
|
||||
wandb = get_wandb_module()
|
||||
HAS_WANDB = wandb is not None
|
||||
|
||||
|
||||
def _parse_list(raw: str) -> list[str]:
|
||||
return [x.strip().lower() for x in str(raw).split(",") if x.strip()]
|
||||
|
||||
|
||||
def _parse_float_list(raw: str) -> list[float]:
|
||||
return [float(x.strip()) for x in str(raw).split(",") if x.strip()]
|
||||
|
||||
|
||||
def _truthy(value: str | bool | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _action(policy, obs: np.ndarray):
|
||||
out = policy.predict(obs, deterministic=True)
|
||||
action = out[0] if isinstance(out, tuple) else out
|
||||
if isinstance(action, np.ndarray) and action.size == 1:
|
||||
return int(action.reshape(-1)[0])
|
||||
return int(action)
|
||||
|
||||
|
||||
def _run_eval_episode(env, policy) -> dict:
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
total_reward = 0.0
|
||||
total_revenue = 0.0
|
||||
total_margin = 0.0
|
||||
total_coi = 0.0
|
||||
price_trace: list[float] = []
|
||||
step_count = 0
|
||||
|
||||
while not done:
|
||||
action = _action(policy, obs)
|
||||
obs, reward, term, trunc, info = env.step(action)
|
||||
done = bool(term or trunc)
|
||||
econ = info.get("economics", {})
|
||||
total_reward += float(reward)
|
||||
total_revenue += float(econ.get("revenue", 0.0))
|
||||
total_margin += float(econ.get("margin", 0.0))
|
||||
total_coi += float(econ.get("coi_level", 0.0))
|
||||
prices = np.asarray(info.get("prices", []), dtype=np.float32)
|
||||
if prices.size > 0:
|
||||
price_trace.append(float(np.mean(prices)))
|
||||
step_count += 1
|
||||
|
||||
denom = max(step_count, 1)
|
||||
return {
|
||||
"reward": total_reward,
|
||||
"revenue": total_revenue,
|
||||
"mean_margin": total_margin / denom,
|
||||
"mean_coi": total_coi / denom,
|
||||
"price_trace": price_trace,
|
||||
}
|
||||
|
||||
|
||||
def _build_tier(name: str, cfg: dict, alpha: float):
|
||||
from .backends.common import make_env
|
||||
from .backends.qtable import train_qtable
|
||||
from .backends.sb3 import train_sb3
|
||||
|
||||
tier = name.lower().strip()
|
||||
run_cfg = dict(cfg)
|
||||
run_cfg["alpha"] = float(alpha)
|
||||
|
||||
if tier == "static":
|
||||
return StaticPolicy(int(run_cfg["action_levels"]))
|
||||
|
||||
if tier == "surge":
|
||||
return SurgePolicy(
|
||||
n_actions=int(run_cfg["action_levels"]),
|
||||
n_products=int(run_cfg["n_products"]),
|
||||
)
|
||||
|
||||
if tier == "linear":
|
||||
warmup_env = make_env(run_cfg)
|
||||
policy = LinearElasticityPolicy(
|
||||
n_actions=int(run_cfg["action_levels"]),
|
||||
n_products=int(run_cfg["n_products"]),
|
||||
price_low=float(run_cfg["price_low"]),
|
||||
price_high=float(run_cfg["price_high"]),
|
||||
)
|
||||
policy.fit(
|
||||
warmup_env,
|
||||
warmup_steps=int(run_cfg.get("linear_warmup_steps", 800)),
|
||||
seed=int(run_cfg["seed"]),
|
||||
)
|
||||
warmup_env.close()
|
||||
return policy
|
||||
|
||||
if tier == "qtable":
|
||||
agent, _ = train_qtable(run_cfg)
|
||||
return agent
|
||||
|
||||
if tier in {"ppo", "a2c", "dqn"}:
|
||||
run_cfg["algo"] = tier
|
||||
agent, _ = train_sb3(run_cfg)
|
||||
return agent
|
||||
|
||||
raise ValueError(f"unsupported tier '{name}'")
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
cfg: dict, tiers: list[str], alpha_values: list[float], n_episodes: int
|
||||
):
|
||||
from .backends.common import make_env
|
||||
|
||||
rows: list[dict] = []
|
||||
traces: list[dict] = []
|
||||
|
||||
for alpha in alpha_values:
|
||||
for tier_name in tiers:
|
||||
policy = _build_tier(tier_name, cfg, alpha)
|
||||
env = make_env({**cfg, "alpha": float(alpha)})
|
||||
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))]
|
||||
env.close()
|
||||
|
||||
row = {
|
||||
"tier": tier_name,
|
||||
"alpha": float(alpha),
|
||||
"episodes": int(n_episodes),
|
||||
"mean_reward": float(np.mean([e["reward"] for e in eps])),
|
||||
"mean_revenue": float(np.mean([e["revenue"] for e in eps])),
|
||||
"mean_margin": float(np.mean([e["mean_margin"] for e in eps])),
|
||||
"mean_coi": float(np.mean([e["mean_coi"] for e in eps])),
|
||||
"std_revenue": float(np.std([e["revenue"] for e in eps])),
|
||||
}
|
||||
row["objective_score"] = (
|
||||
row["mean_reward"]
|
||||
+ float(cfg.get("revenue_weight", 0.01)) * row["mean_revenue"]
|
||||
)
|
||||
rows.append(row)
|
||||
|
||||
max_len = max((len(e["price_trace"]) for e in eps), default=0)
|
||||
step_means = []
|
||||
for step in range(max_len):
|
||||
vals = [
|
||||
e["price_trace"][step] for e in eps if step < len(e["price_trace"])
|
||||
]
|
||||
step_means.append(float(np.mean(vals)) if vals else np.nan)
|
||||
traces.append(
|
||||
{
|
||||
"tier": tier_name,
|
||||
"alpha": float(alpha),
|
||||
"mean_price_trace": step_means,
|
||||
}
|
||||
)
|
||||
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"study/alpha": float(alpha),
|
||||
"eval/reward_mean": row["mean_reward"],
|
||||
"eval/revenue_mean": row["mean_revenue"],
|
||||
"eval/margin_mean": row["mean_margin"],
|
||||
"objective/score": row["objective_score"],
|
||||
"objective/coi_preserved": row["mean_coi"],
|
||||
}
|
||||
)
|
||||
|
||||
return pd.DataFrame(rows), traces
|
||||
|
||||
|
||||
def _plot_outputs(df: pd.DataFrame, traces: list[dict], out_dir: Path, stamp: str):
|
||||
fig1 = plt.figure(figsize=(11, 4.5))
|
||||
if "mode" in df.columns:
|
||||
groups = sorted(df[["tier", "mode"]].drop_duplicates().values.tolist())
|
||||
for tier, mode in groups:
|
||||
sub = df[(df["tier"] == tier) & (df["mode"] == mode)].sort_values("alpha")
|
||||
plt.plot(
|
||||
sub["alpha"],
|
||||
sub["mean_revenue"],
|
||||
marker="o",
|
||||
label=f"{tier}:{mode}",
|
||||
)
|
||||
else:
|
||||
for tier in sorted(df["tier"].unique()):
|
||||
sub = df[df["tier"] == tier].sort_values("alpha")
|
||||
plt.plot(sub["alpha"], sub["mean_revenue"], marker="o", label=tier)
|
||||
plt.xlabel("contamination alpha")
|
||||
plt.ylabel("mean episode revenue")
|
||||
plt.title("Revenue under contamination")
|
||||
plt.grid(alpha=0.3)
|
||||
plt.legend()
|
||||
fig1.tight_layout()
|
||||
rev_path = out_dir / f"benchmark_revenue_{stamp}.png"
|
||||
fig1.savefig(rev_path, dpi=220)
|
||||
plt.close(fig1)
|
||||
|
||||
fig2 = plt.figure(figsize=(11, 4.5))
|
||||
if "mode" in df.columns:
|
||||
groups = sorted(df[["tier", "mode"]].drop_duplicates().values.tolist())
|
||||
for tier, mode in groups:
|
||||
sub = df[(df["tier"] == tier) & (df["mode"] == mode)].sort_values("alpha")
|
||||
plt.plot(
|
||||
sub["alpha"],
|
||||
sub["mean_coi"],
|
||||
marker="s",
|
||||
label=f"{tier}:{mode}",
|
||||
)
|
||||
else:
|
||||
for tier in sorted(df["tier"].unique()):
|
||||
sub = df[df["tier"] == tier].sort_values("alpha")
|
||||
plt.plot(sub["alpha"], sub["mean_coi"], marker="s", label=tier)
|
||||
plt.xlabel("contamination alpha")
|
||||
plt.ylabel("mean COI level")
|
||||
plt.title("COI preservation")
|
||||
plt.grid(alpha=0.3)
|
||||
plt.legend()
|
||||
fig2.tight_layout()
|
||||
coi_path = out_dir / f"benchmark_coi_{stamp}.png"
|
||||
fig2.savefig(coi_path, dpi=220)
|
||||
plt.close(fig2)
|
||||
|
||||
focus_alpha = float(df["alpha"].min()) if not df.empty else 0.0
|
||||
alpha_traces = [t for t in traces if abs(float(t["alpha"]) - focus_alpha) < 1e-9]
|
||||
fig3 = plt.figure(figsize=(11, 4.5))
|
||||
for item in alpha_traces:
|
||||
xs = np.arange(len(item["mean_price_trace"]))
|
||||
ys = np.asarray(item["mean_price_trace"], dtype=np.float32)
|
||||
mode = item.get("mode")
|
||||
label = f"{item['tier']}:{mode}" if mode is not None else str(item["tier"])
|
||||
plt.plot(xs, ys, label=label)
|
||||
plt.xlabel("step")
|
||||
plt.ylabel("mean price")
|
||||
plt.title(f"Price evolution (alpha={focus_alpha:.2f})")
|
||||
plt.grid(alpha=0.3)
|
||||
plt.legend()
|
||||
fig3.tight_layout()
|
||||
price_path = out_dir / f"benchmark_price_trace_{stamp}.png"
|
||||
fig3.savefig(price_path, dpi=220)
|
||||
plt.close(fig3)
|
||||
|
||||
return rev_path, coi_path, price_path
|
||||
|
||||
|
||||
def _run_with_args(args):
|
||||
compare_robust = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
|
||||
robust_modes = [False, True] if compare_robust else [bool(args.no_robust)]
|
||||
|
||||
base_overrides = {
|
||||
"seed": args.seed,
|
||||
"total_timesteps": args.total_timesteps,
|
||||
"n_products": args.n_products,
|
||||
"N": args.N,
|
||||
"lambda_coi": args.lambda_coi,
|
||||
"robust_radius": args.robust_radius,
|
||||
"robust_points": args.robust_points,
|
||||
"price_low": args.price_low,
|
||||
"price_high": args.price_high,
|
||||
"action_levels": args.action_levels,
|
||||
"action_scale_low": args.action_scale_low,
|
||||
"action_scale_high": args.action_scale_high,
|
||||
"max_steps": args.max_steps,
|
||||
"learning_rate": args.learning_rate,
|
||||
"batch_size": args.batch_size,
|
||||
"n_steps": args.n_steps,
|
||||
"linear_warmup_steps": args.linear_warmup_steps,
|
||||
"device": args.device,
|
||||
}
|
||||
tiers = _parse_list(args.tiers)
|
||||
alpha_values = _parse_float_list(args.alpha_values)
|
||||
|
||||
all_frames: list[pd.DataFrame] = []
|
||||
all_traces: list[dict] = []
|
||||
for no_robust in robust_modes:
|
||||
overrides = dict(base_overrides)
|
||||
overrides["no_robust"] = bool(no_robust)
|
||||
cfg = TrainSpec.from_flat(
|
||||
{k: v for k, v in overrides.items() if v is not None}
|
||||
).to_flat_dict()
|
||||
cfg["linear_warmup_steps"] = int(args.linear_warmup_steps)
|
||||
df_mode, traces_mode = run_benchmark(cfg, tiers, alpha_values, args.episodes)
|
||||
mode_label = "no_robust" if no_robust else "robust"
|
||||
df_mode["mode"] = mode_label
|
||||
for trace in traces_mode:
|
||||
trace["mode"] = mode_label
|
||||
all_frames.append(df_mode)
|
||||
all_traces.extend(traces_mode)
|
||||
|
||||
df = pd.concat(all_frames, ignore_index=True) if all_frames else pd.DataFrame()
|
||||
traces = all_traces
|
||||
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
stamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||
csv_path = out_dir / f"benchmark_{stamp}.csv"
|
||||
trace_path = out_dir / f"benchmark_traces_{stamp}.json"
|
||||
df.to_csv(csv_path, index=False)
|
||||
trace_path.write_text(json.dumps(traces, indent=2))
|
||||
rev_path, coi_path, price_path = _plot_outputs(df, traces, out_dir, stamp)
|
||||
|
||||
if not df.empty:
|
||||
best_idx = int(df["mean_revenue"].idxmax())
|
||||
best = df.iloc[best_idx]
|
||||
print(
|
||||
"BEST_TIER="
|
||||
+ json.dumps(
|
||||
{
|
||||
"tier": best["tier"],
|
||||
"mode": best.get("mode", "robust"),
|
||||
"alpha": float(best["alpha"]),
|
||||
"mean_revenue": float(best["mean_revenue"]),
|
||||
"mean_coi": float(best["mean_coi"]),
|
||||
}
|
||||
)
|
||||
)
|
||||
print(f"BENCHMARK_CSV={csv_path}")
|
||||
print(f"BENCHMARK_TRACES={trace_path}")
|
||||
print(f"BENCHMARK_PLOT_REVENUE={rev_path}")
|
||||
print(f"BENCHMARK_PLOT_COI={coi_path}")
|
||||
print(f"BENCHMARK_PLOT_PRICE={price_path}")
|
||||
|
||||
|
||||
def run_cli(raw_args: list[str] | None = None):
|
||||
parser = argparse.ArgumentParser(description="PHANTOM benchmark orchestrator")
|
||||
parser.add_argument("--project", default="capstone")
|
||||
parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo")
|
||||
parser.add_argument("--alpha-values", default="0.0,0.3,0.6")
|
||||
parser.add_argument("--episodes", type=int, default=10)
|
||||
parser.add_argument("--output-dir", default="engine/studies/results")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--total-timesteps", type=int, default=25_000)
|
||||
parser.add_argument("--n-products", type=int, default=10)
|
||||
parser.add_argument("--N", type=int, default=100)
|
||||
parser.add_argument("--lambda-coi", type=float, default=0.2)
|
||||
parser.add_argument("--robust-radius", type=float, default=0.15)
|
||||
parser.add_argument("--robust-points", type=int, default=5)
|
||||
parser.add_argument("--price-low", type=float, default=10.0)
|
||||
parser.add_argument("--price-high", type=float, default=150.0)
|
||||
parser.add_argument("--action-levels", type=int, default=9)
|
||||
parser.add_argument("--action-scale-low", type=float, default=0.8)
|
||||
parser.add_argument("--action-scale-high", type=float, default=1.2)
|
||||
parser.add_argument("--max-steps", type=int, default=100)
|
||||
parser.add_argument("--learning-rate", type=float, default=3e-4)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--n-steps", type=int, default=2048)
|
||||
parser.add_argument("--linear-warmup-steps", type=int, default=800)
|
||||
parser.add_argument("--device", type=str, default="auto")
|
||||
parser.add_argument("--no-robust", action="store_true")
|
||||
parser.add_argument("--no-wandb", action="store_true")
|
||||
parser.add_argument("--offline", action="store_true")
|
||||
parser.add_argument("--sweep-agent", action="store_true")
|
||||
parser.add_argument("--sweep-id", type=str)
|
||||
parser.add_argument("--count", type=int, default=0)
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
if args.sweep_agent:
|
||||
if args.no_wandb or not HAS_WANDB:
|
||||
raise ValueError("sweep agent requires wandb")
|
||||
if not args.sweep_id:
|
||||
raise ValueError("--sweep-id is required with --sweep-agent")
|
||||
|
||||
def _sweep_run():
|
||||
run = wandb.init(mode="offline" if args.offline else "online")
|
||||
try:
|
||||
key_to_attr = {
|
||||
"tiers": "tiers",
|
||||
"alpha_values": "alpha_values",
|
||||
"episodes": "episodes",
|
||||
"total_timesteps": "total_timesteps",
|
||||
"lambda_coi": "lambda_coi",
|
||||
"robust_radius": "robust_radius",
|
||||
"robust_points": "robust_points",
|
||||
"learning_rate": "learning_rate",
|
||||
"batch_size": "batch_size",
|
||||
"n_steps": "n_steps",
|
||||
"no_robust": "no_robust",
|
||||
"device": "device",
|
||||
}
|
||||
for key in (
|
||||
"tiers",
|
||||
"alpha_values",
|
||||
"episodes",
|
||||
"total_timesteps",
|
||||
"lambda_coi",
|
||||
"robust_radius",
|
||||
"robust_points",
|
||||
"learning_rate",
|
||||
"batch_size",
|
||||
"n_steps",
|
||||
"no_robust",
|
||||
"device",
|
||||
):
|
||||
if key in wandb.config:
|
||||
setattr(args, key_to_attr[key], wandb.config[key])
|
||||
_run_with_args(args)
|
||||
finally:
|
||||
if run is not None:
|
||||
wandb.finish()
|
||||
|
||||
wandb.agent(
|
||||
args.sweep_id,
|
||||
function=_sweep_run,
|
||||
count=args.count if args.count > 0 else None,
|
||||
)
|
||||
return
|
||||
|
||||
if args.no_wandb or not HAS_WANDB:
|
||||
_run_with_args(args)
|
||||
return
|
||||
|
||||
run = wandb.init(
|
||||
project=args.project,
|
||||
name=f"benchmark-{datetime.now(UTC).strftime('%m%d-%H%M%S')}",
|
||||
tags=[
|
||||
"benchmark",
|
||||
"robust-compare"
|
||||
if _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
|
||||
else "single-mode",
|
||||
],
|
||||
config={
|
||||
"run.kind": "benchmark",
|
||||
"tiers": args.tiers,
|
||||
"alpha_values": args.alpha_values,
|
||||
"episodes": args.episodes,
|
||||
"total_timesteps": args.total_timesteps,
|
||||
"lambda_coi": args.lambda_coi,
|
||||
"robust_radius": args.robust_radius,
|
||||
"robust_points": args.robust_points,
|
||||
"learning_rate": args.learning_rate,
|
||||
"device": args.device,
|
||||
},
|
||||
mode="offline" if args.offline else "online",
|
||||
)
|
||||
try:
|
||||
_run_with_args(args)
|
||||
finally:
|
||||
if run is not None:
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_cli()
|
||||
@@ -624,8 +624,8 @@ def evaluate_policy(
|
||||
revenues.append(ep_revenue)
|
||||
|
||||
return {
|
||||
"eval/reward": float(np.mean(rewards)),
|
||||
"eval/revenue": float(np.mean(revenues)),
|
||||
"eval/reward_mean": float(np.mean(rewards)),
|
||||
"eval/revenue_mean": float(np.mean(revenues)),
|
||||
"eval/reward_std": float(np.std(rewards)),
|
||||
"eval/revenue_std": float(np.std(revenues)),
|
||||
}
|
||||
@@ -665,8 +665,8 @@ def _evaluate_q_network(
|
||||
revenues.append(ep_revenue)
|
||||
|
||||
return {
|
||||
"eval/reward": float(np.mean(rewards)),
|
||||
"eval/revenue": float(np.mean(revenues)),
|
||||
"eval/reward_mean": float(np.mean(rewards)),
|
||||
"eval/revenue_mean": float(np.mean(revenues)),
|
||||
"eval/reward_std": float(np.std(rewards)),
|
||||
"eval/revenue_std": float(np.std(revenues)),
|
||||
}
|
||||
@@ -713,8 +713,8 @@ def _evaluate_q_table(
|
||||
revenues.append(ep_revenue)
|
||||
|
||||
return {
|
||||
"eval/reward": float(np.mean(rewards)),
|
||||
"eval/revenue": float(np.mean(revenues)),
|
||||
"eval/reward_mean": float(np.mean(rewards)),
|
||||
"eval/revenue_mean": float(np.mean(revenues)),
|
||||
"eval/reward_std": float(np.std(rewards)),
|
||||
"eval/revenue_std": float(np.std(revenues)),
|
||||
}
|
||||
@@ -831,8 +831,8 @@ def _train_actor_critic(
|
||||
if is_primary and HAS_WANDB and wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"train/reward": float(segment_values["reward"].mean()),
|
||||
"train/revenue": float(segment_values["revenue"].mean()),
|
||||
"train/reward_mean": float(segment_values["reward"].mean()),
|
||||
"train/revenue_mean": float(segment_values["revenue"].mean()),
|
||||
"train/agent_prob": float(segment_values["agent_prob"].mean()),
|
||||
"train/alpha_adv": float(segment_values["alpha_adv"].mean()),
|
||||
"train/coi_leakage": float(segment_values["coi_leakage"].mean()),
|
||||
@@ -873,8 +873,8 @@ def _train_actor_critic(
|
||||
train_state = final_runner[0]
|
||||
denom = float(metric_count) if metric_count > 0 else 1.0
|
||||
metrics = {
|
||||
"train/reward": float(metric_sums["reward"] / denom),
|
||||
"train/revenue": float(metric_sums["revenue"] / denom),
|
||||
"train/reward_mean": float(metric_sums["reward"] / denom),
|
||||
"train/revenue_mean": float(metric_sums["revenue"] / denom),
|
||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||
@@ -1052,14 +1052,14 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
|
||||
):
|
||||
wandb.log(
|
||||
{
|
||||
"train/reward": metric_sums["reward"] / max(metric_count, 1),
|
||||
"train/revenue": metric_sums["revenue"] / max(metric_count, 1),
|
||||
"train/reward_mean": metric_sums["reward"] / max(metric_count, 1),
|
||||
"train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1),
|
||||
"train/agent_prob": metric_sums["agent_prob"]
|
||||
/ max(metric_count, 1),
|
||||
"train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
|
||||
"train/coi_leakage": metric_sums["coi_leakage"]
|
||||
/ max(metric_count, 1),
|
||||
"train/dqn_loss": metric_sums["loss"] / max(loss_count, 1),
|
||||
"train/loss": metric_sums["loss"] / max(loss_count, 1),
|
||||
"train/epsilon": epsilon_value,
|
||||
"train/global_step": global_step,
|
||||
},
|
||||
@@ -1090,12 +1090,12 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
|
||||
|
||||
denom = float(metric_count) if metric_count > 0 else 1.0
|
||||
metrics = {
|
||||
"train/reward": float(metric_sums["reward"] / denom),
|
||||
"train/revenue": float(metric_sums["revenue"] / denom),
|
||||
"train/reward_mean": float(metric_sums["reward"] / denom),
|
||||
"train/revenue_mean": float(metric_sums["revenue"] / denom),
|
||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||
"train/dqn_loss": float(metric_sums["loss"] / max(loss_count, 1)),
|
||||
"train/loss": float(metric_sums["loss"] / max(loss_count, 1)),
|
||||
"train/global_step": total_steps,
|
||||
}
|
||||
|
||||
@@ -1236,8 +1236,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]
|
||||
):
|
||||
wandb.log(
|
||||
{
|
||||
"train/reward": metric_sums["reward"] / max(metric_count, 1),
|
||||
"train/revenue": metric_sums["revenue"] / max(metric_count, 1),
|
||||
"train/reward_mean": metric_sums["reward"] / max(metric_count, 1),
|
||||
"train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1),
|
||||
"train/agent_prob": metric_sums["agent_prob"]
|
||||
/ max(metric_count, 1),
|
||||
"train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
|
||||
@@ -1269,8 +1269,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]
|
||||
|
||||
denom = float(metric_count) if metric_count > 0 else 1.0
|
||||
metrics = {
|
||||
"train/reward": float(metric_sums["reward"] / denom),
|
||||
"train/revenue": float(metric_sums["revenue"] / denom),
|
||||
"train/reward_mean": float(metric_sums["reward"] / denom),
|
||||
"train/revenue_mean": float(metric_sums["revenue"] / denom),
|
||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||
|
||||
@@ -1,38 +1,39 @@
|
||||
from .demand import estimate_demand, estimate_weighted_demand, generate_demand_for_actor
|
||||
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
|
||||
from .render import DashboardRenderer, style_axis
|
||||
from .wrappers import EconomicMetricsWrapper
|
||||
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
|
||||
from .providers import (
|
||||
ProviderBenchmark,
|
||||
ProviderResult,
|
||||
BenchmarkConfig,
|
||||
RandomBaseline,
|
||||
SurgeBaseline,
|
||||
)
|
||||
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
|
||||
from .discrete import EventQTable
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"estimate_demand",
|
||||
"estimate_weighted_demand",
|
||||
"generate_demand_for_actor",
|
||||
"sample_behavior",
|
||||
"get_transition_models",
|
||||
"trajectory_to_events",
|
||||
"DashboardRenderer",
|
||||
"style_axis",
|
||||
"EconomicMetricsWrapper",
|
||||
"MetricsCallback",
|
||||
"EvalMetricsCallback",
|
||||
"CheckpointArtifactCallback",
|
||||
"ProviderBenchmark",
|
||||
"ProviderResult",
|
||||
"BenchmarkConfig",
|
||||
"RandomBaseline",
|
||||
"SurgeBaseline",
|
||||
"compute_uplift_coi",
|
||||
"extract_purchases",
|
||||
"compute_agent_probability",
|
||||
"EventQTable",
|
||||
]
|
||||
from importlib import import_module
|
||||
|
||||
_EXPORTS: dict[str, tuple[str, str]] = {
|
||||
"estimate_demand": (".demand", "estimate_demand"),
|
||||
"estimate_weighted_demand": (".demand", "estimate_weighted_demand"),
|
||||
"generate_demand_for_actor": (".demand", "generate_demand_for_actor"),
|
||||
"sample_behavior": (".behavior", "sample_behavior"),
|
||||
"get_transition_models": (".behavior", "get_transition_models"),
|
||||
"trajectory_to_events": (".behavior", "trajectory_to_events"),
|
||||
"DashboardRenderer": (".render", "DashboardRenderer"),
|
||||
"style_axis": (".render", "style_axis"),
|
||||
"EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
|
||||
"MetricsCallback": (".callbacks", "MetricsCallback"),
|
||||
"EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
|
||||
"CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"),
|
||||
"ProviderBenchmark": (".providers", "ProviderBenchmark"),
|
||||
"ProviderResult": (".providers", "ProviderResult"),
|
||||
"BenchmarkConfig": (".providers", "BenchmarkConfig"),
|
||||
"RandomBaseline": (".providers", "RandomBaseline"),
|
||||
"SurgeBaseline": (".providers", "SurgeBaseline"),
|
||||
"compute_uplift_coi": (".coi", "compute_uplift_coi"),
|
||||
"extract_purchases": (".coi", "extract_purchases"),
|
||||
"compute_agent_probability": (".coi", "compute_agent_probability"),
|
||||
"EventQTable": (".discrete", "EventQTable"),
|
||||
}
|
||||
|
||||
__all__ = sorted(_EXPORTS)
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name not in _EXPORTS:
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
module_name, attr_name = _EXPORTS[name]
|
||||
module = import_module(module_name, package=__name__)
|
||||
value = getattr(module, attr_name)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
@@ -38,19 +38,19 @@ class MetricsCallback(BaseCallback):
|
||||
t = self.num_timesteps
|
||||
|
||||
payload = {
|
||||
"economics/revenue": econ["revenue"],
|
||||
"economics/margin": econ["margin"],
|
||||
"coi/level": econ["coi_level"],
|
||||
"economics/regret": econ["regret"],
|
||||
"train/revenue_step": econ["revenue"],
|
||||
"train/margin_step": econ["margin"],
|
||||
"train/coi_level": econ["coi_level"],
|
||||
"train/regret_step": econ["regret"],
|
||||
}
|
||||
if "coi_mix" in econ:
|
||||
payload["coi/mix"] = econ["coi_mix"]
|
||||
payload["train/coi_mix"] = econ["coi_mix"]
|
||||
if "coi_base" in econ:
|
||||
payload["coi/base"] = econ["coi_base"]
|
||||
payload["train/coi_base"] = econ["coi_base"]
|
||||
if "coi_leakage" in econ:
|
||||
payload["coi/leakage"] = econ["coi_leakage"]
|
||||
payload["train/coi_leakage"] = econ["coi_leakage"]
|
||||
if "coi_penalty" in econ:
|
||||
payload["coi/penalty"] = econ["coi_penalty"]
|
||||
payload["train/coi_penalty"] = econ["coi_penalty"]
|
||||
wandb.log(payload, step=t)
|
||||
|
||||
self._episode_revenues.append(econ["revenue"])
|
||||
@@ -76,8 +76,8 @@ class MetricsCallback(BaseCallback):
|
||||
return
|
||||
wandb.log(
|
||||
{
|
||||
"episode/mean_revenue": np.mean(self._episode_revenues),
|
||||
"episode/total_revenue": np.sum(self._episode_revenues),
|
||||
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
|
||||
"train/revenue_rollout_total": np.sum(self._episode_revenues),
|
||||
},
|
||||
step=self.num_timesteps,
|
||||
)
|
||||
@@ -164,8 +164,8 @@ class EvalMetricsCallback(EvalCallback):
|
||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||
wandb.log(
|
||||
{
|
||||
"eval/mean_reward": self.last_mean_reward,
|
||||
"eval/mean_revenue": np.mean(self._eval_revenues)
|
||||
"eval/reward_mean": self.last_mean_reward,
|
||||
"eval/revenue_mean": np.mean(self._eval_revenues)
|
||||
if self._eval_revenues
|
||||
else 0,
|
||||
},
|
||||
|
||||
101
engine/lib/tiers.py
Normal file
101
engine/lib/tiers.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PolicyLike(Protocol):
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True): ...
|
||||
|
||||
|
||||
class StaticPolicy:
|
||||
def __init__(self, n_actions: int):
|
||||
self._action = int(max(0, n_actions // 2))
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
return self._action, None
|
||||
|
||||
|
||||
class SurgePolicy:
|
||||
def __init__(
|
||||
self,
|
||||
n_actions: int,
|
||||
n_products: int,
|
||||
high_threshold: float = 60.0,
|
||||
low_threshold: float = 30.0,
|
||||
):
|
||||
self.n_actions = int(n_actions)
|
||||
self.n_products = int(n_products)
|
||||
self.mid = self.n_actions // 2
|
||||
self.high_t = float(high_threshold)
|
||||
self.low_t = float(low_threshold)
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
obs_arr = np.asarray(obs, dtype=np.float32)
|
||||
demand = obs_arr[: self.n_products]
|
||||
demand_mean = float(np.mean(demand)) if demand.size > 0 else 0.0
|
||||
if demand_mean >= self.high_t:
|
||||
return min(self.mid + 2, self.n_actions - 1), None
|
||||
if demand_mean <= self.low_t:
|
||||
return max(self.mid - 2, 0), None
|
||||
return self.mid, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearElasticityPolicy:
|
||||
n_actions: int
|
||||
n_products: int
|
||||
price_low: float
|
||||
price_high: float
|
||||
|
||||
def __post_init__(self):
|
||||
self.n_actions = int(self.n_actions)
|
||||
self.n_products = int(self.n_products)
|
||||
self.price_low = float(self.price_low)
|
||||
self.price_high = float(self.price_high)
|
||||
self._target_price = 0.5 * (self.price_low + self.price_high)
|
||||
self._action_scales = np.linspace(0.8, 1.2, self.n_actions)
|
||||
|
||||
def fit(self, env, warmup_steps: int = 800, seed: int = 42):
|
||||
rng = np.random.default_rng(int(seed))
|
||||
obs, _ = env.reset(seed=int(seed))
|
||||
prices: list[float] = []
|
||||
demands: list[float] = []
|
||||
|
||||
for _ in range(int(max(10, warmup_steps))):
|
||||
action = int(rng.integers(0, self.n_actions))
|
||||
obs, _, term, trunc, info = env.step(action)
|
||||
done = bool(term or trunc)
|
||||
|
||||
p = np.asarray(info.get("prices", []), dtype=np.float32)
|
||||
d = np.asarray(info.get("demand", []), dtype=np.float32)
|
||||
if p.size > 0 and d.size > 0:
|
||||
prices.append(float(np.mean(p)))
|
||||
demands.append(float(np.mean(d)))
|
||||
|
||||
if done:
|
||||
obs, _ = env.reset()
|
||||
|
||||
if len(prices) < 8:
|
||||
self._target_price = 0.5 * (self.price_low + self.price_high)
|
||||
return self
|
||||
|
||||
slope, intercept = np.polyfit(np.asarray(prices), np.asarray(demands), 1)
|
||||
if slope < -1e-6:
|
||||
p_star = -intercept / (2.0 * slope)
|
||||
self._target_price = float(np.clip(p_star, self.price_low, self.price_high))
|
||||
else:
|
||||
self._target_price = 0.5 * (self.price_low + self.price_high)
|
||||
return self
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
obs_arr = np.asarray(obs, dtype=np.float32)
|
||||
cur_prices = obs_arr[self.n_products : 2 * self.n_products]
|
||||
cur_mean = (
|
||||
float(np.mean(cur_prices)) if cur_prices.size > 0 else self._target_price
|
||||
)
|
||||
scale = self._target_price / max(cur_mean, 1e-6)
|
||||
action = int(np.argmin(np.abs(self._action_scales - scale)))
|
||||
return int(np.clip(action, 0, self.n_actions - 1)), None
|
||||
5
engine/orchestrators/__init__.py
Normal file
5
engine/orchestrators/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .benchmark import run_benchmark_cli
|
||||
from .sweep_agent import run_sweep_agent
|
||||
from .train import run_train_once
|
||||
|
||||
__all__ = ["run_benchmark_cli", "run_sweep_agent", "run_train_once"]
|
||||
7
engine/orchestrators/benchmark.py
Normal file
7
engine/orchestrators/benchmark.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def run_benchmark_cli(raw_args: list[str] | None = None) -> None:
|
||||
from ..benchmark import run_cli
|
||||
|
||||
run_cli(raw_args)
|
||||
60
engine/orchestrators/sweep_agent.py
Normal file
60
engine/orchestrators/sweep_agent.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from ..spec import TrainSpec, run_name
|
||||
from ..telemetry.wandb import (
|
||||
current_config,
|
||||
finish_run,
|
||||
get_wandb_module,
|
||||
init_run,
|
||||
run_agent,
|
||||
)
|
||||
from .train import run_with_active_sweep_run
|
||||
|
||||
|
||||
def run_sweep_agent(
|
||||
*,
|
||||
project: str,
|
||||
sweep_id: str,
|
||||
count: int,
|
||||
offline: bool,
|
||||
no_wandb: bool,
|
||||
base_overrides: Mapping[str, Any],
|
||||
kind: str,
|
||||
scenario: str,
|
||||
group: str | None,
|
||||
extra_tags: Sequence[str],
|
||||
) -> None:
|
||||
if no_wandb:
|
||||
raise ValueError("sweep agent requires wandb")
|
||||
if not sweep_id:
|
||||
raise ValueError("--sweep-id is required with --sweep-agent")
|
||||
if get_wandb_module() is None:
|
||||
raise ImportError("wandb is required for sweep runs")
|
||||
|
||||
mode = "offline" if offline else "online"
|
||||
|
||||
def _sweep_trial() -> None:
|
||||
run = init_run(mode=mode, project=project, group=group, sweep_mode=True)
|
||||
try:
|
||||
merged = dict(base_overrides)
|
||||
merged.update(current_config())
|
||||
spec = TrainSpec.from_flat(merged)
|
||||
if run is not None:
|
||||
run.name = run_name(spec, kind=kind, scenario=scenario)
|
||||
run_with_active_sweep_run(
|
||||
spec,
|
||||
kind=kind,
|
||||
scenario=scenario,
|
||||
group=group,
|
||||
extra_tags=extra_tags,
|
||||
)
|
||||
finally:
|
||||
finish_run()
|
||||
|
||||
run_agent(
|
||||
sweep_id,
|
||||
_sweep_trial,
|
||||
count=count if count > 0 else None,
|
||||
)
|
||||
129
engine/orchestrators/train.py
Normal file
129
engine/orchestrators/train.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Sequence
|
||||
|
||||
from ..spec import TrainSpec, run_metadata, run_name
|
||||
from ..telemetry.wandb import (
|
||||
finish_run,
|
||||
get_wandb_module,
|
||||
init_run,
|
||||
log_metrics,
|
||||
update_run_config,
|
||||
update_summary,
|
||||
)
|
||||
from ..train_core import run_train
|
||||
|
||||
|
||||
def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list[str]:
|
||||
tags = [
|
||||
kind,
|
||||
spec.algorithm.name,
|
||||
spec.runtime.backend,
|
||||
"vanilla" if spec.study.no_robust else "robust",
|
||||
]
|
||||
tags.extend([tag for tag in extra_tags if tag])
|
||||
return tags
|
||||
|
||||
|
||||
def _print_local_metrics(metrics: dict[str, Any]) -> None:
|
||||
print(json.dumps(metrics, indent=2))
|
||||
print("PHANTOM_METRICS:" + json.dumps(metrics))
|
||||
|
||||
|
||||
def _should_print_local(spec: TrainSpec) -> bool:
|
||||
if not spec.runtime.use_jax:
|
||||
return True
|
||||
try:
|
||||
import jax
|
||||
|
||||
return int(jax.process_index()) == 0
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool:
|
||||
if not spec.runtime.use_jax:
|
||||
return False
|
||||
try:
|
||||
import jax
|
||||
|
||||
return int(jax.process_count()) > 1 and int(jax.process_index()) != 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def run_train_once(
|
||||
spec: TrainSpec,
|
||||
*,
|
||||
project: str,
|
||||
offline: bool,
|
||||
no_wandb: bool,
|
||||
kind: str,
|
||||
scenario: str,
|
||||
group: str | None,
|
||||
extra_tags: Sequence[str],
|
||||
) -> dict[str, Any]:
|
||||
wandb = get_wandb_module()
|
||||
if no_wandb or wandb is None or _is_non_primary_jax_worker(spec):
|
||||
result = run_train(spec)
|
||||
if _should_print_local(spec):
|
||||
_print_local_metrics(result.metrics)
|
||||
return result.metrics
|
||||
|
||||
mode = "offline" if offline else "online"
|
||||
tags = _tags_for_run(spec, kind, extra_tags)
|
||||
metadata = run_metadata(
|
||||
spec,
|
||||
kind=kind,
|
||||
scenario=scenario,
|
||||
group=group,
|
||||
tags=tags,
|
||||
)
|
||||
config = spec.to_flat_dict()
|
||||
config.update(metadata)
|
||||
name = run_name(spec, kind=kind, scenario=scenario)
|
||||
init_run(
|
||||
mode=mode,
|
||||
project=project,
|
||||
config=config,
|
||||
name=name,
|
||||
tags=tags,
|
||||
group=group,
|
||||
sweep_mode=False,
|
||||
)
|
||||
|
||||
try:
|
||||
result = run_train(spec)
|
||||
metrics = result.metrics
|
||||
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
||||
log_metrics(metrics, step=step)
|
||||
update_summary(metrics)
|
||||
return metrics
|
||||
finally:
|
||||
finish_run()
|
||||
|
||||
|
||||
def run_with_active_sweep_run(
|
||||
spec: TrainSpec,
|
||||
*,
|
||||
kind: str,
|
||||
scenario: str,
|
||||
group: str | None,
|
||||
extra_tags: Sequence[str],
|
||||
) -> dict[str, Any]:
|
||||
tags = _tags_for_run(spec, kind, extra_tags)
|
||||
metadata = run_metadata(
|
||||
spec,
|
||||
kind=kind,
|
||||
scenario=scenario,
|
||||
group=group,
|
||||
tags=tags,
|
||||
)
|
||||
update_run_config({**spec.to_flat_dict(), **metadata})
|
||||
result = run_train(spec)
|
||||
metrics = result.metrics
|
||||
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
||||
log_metrics(metrics, step=step)
|
||||
update_summary(metrics)
|
||||
return metrics
|
||||
@@ -31,6 +31,26 @@
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"benchmark": {
|
||||
"executor": "nx:run-commands",
|
||||
"dependsOn": [
|
||||
"install"
|
||||
],
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh benchmark",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"benchmark-agent": {
|
||||
"executor": "nx:run-commands",
|
||||
"dependsOn": [
|
||||
"install"
|
||||
],
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh benchmark-agent",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"train-agent": {
|
||||
"executor": "nx:run-commands",
|
||||
"dependsOn": [
|
||||
|
||||
340
engine/spec.py
Normal file
340
engine/spec.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
|
||||
def _truthy(value: str | bool | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
|
||||
alias_map = {
|
||||
"algorithm": "algo",
|
||||
"algorithm.name": "algo",
|
||||
"env.n_products": "n_products",
|
||||
"env.action_levels": "action_levels",
|
||||
"env.action_scale_low": "action_scale_low",
|
||||
"env.action_scale_high": "action_scale_high",
|
||||
"env.price_low": "price_low",
|
||||
"env.price_high": "price_high",
|
||||
"env.max_steps": "max_steps",
|
||||
"env.margin_floor": "margin_floor",
|
||||
"env.margin_floor_patience": "margin_floor_patience",
|
||||
"env.n_sessions": "N",
|
||||
"study.alpha": "alpha",
|
||||
"study.lambda_coi": "lambda_coi",
|
||||
"study.robust_radius": "robust_radius",
|
||||
"study.robust_points": "robust_points",
|
||||
"study.info_value": "info_value",
|
||||
"study.revenue_weight": "revenue_weight",
|
||||
"optimizer.learning_rate": "learning_rate",
|
||||
"optimizer.gamma": "gamma",
|
||||
"optimizer.batch_size": "batch_size",
|
||||
"optimizer.n_steps": "n_steps",
|
||||
"runtime.backend": "backend",
|
||||
"runtime.device": "device",
|
||||
"runtime.seed": "seed",
|
||||
"runtime.total_timesteps": "total_timesteps",
|
||||
"runtime.checkpoint_interval": "checkpoint_interval",
|
||||
"eval.eval_freq": "eval_freq",
|
||||
"eval.eval_episodes": "eval_episodes",
|
||||
}
|
||||
normalized: dict[str, Any] = {}
|
||||
for key, value in raw.items():
|
||||
canonical = alias_map.get(str(key), str(key))
|
||||
normalized[canonical] = value
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AlgorithmSpec:
|
||||
name: str = "ppo"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EnvSpec:
|
||||
n_products: int = 10
|
||||
n_sessions: int = 100
|
||||
price_low: float = 10.0
|
||||
price_high: float = 150.0
|
||||
action_levels: int = 9
|
||||
action_scale_low: float = 0.8
|
||||
action_scale_high: float = 1.2
|
||||
max_steps: int = 100
|
||||
margin_floor: float = 0.05
|
||||
margin_floor_patience: int = 5
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StudySpec:
|
||||
alpha: float = 0.3
|
||||
lambda_coi: float = 0.2
|
||||
robust_radius: float = 0.15
|
||||
robust_points: int = 5
|
||||
info_value: float = 1.0
|
||||
revenue_weight: float = 0.01
|
||||
no_robust: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OptimizerSpec:
|
||||
learning_rate: float = 3e-4
|
||||
gamma: float = 0.99
|
||||
buffer_size: int = 50_000
|
||||
batch_size: int = 256
|
||||
tau: float = 0.005
|
||||
train_freq: int = 1
|
||||
learning_starts: int = 1_000
|
||||
target_update_interval: int = 1_000
|
||||
exploration_fraction: float = 0.2
|
||||
exploration_final_eps: float = 0.05
|
||||
n_steps: int = 2_048
|
||||
n_epochs: int = 10
|
||||
gae_lambda: float = 0.95
|
||||
clip_range: float = 0.2
|
||||
ent_coef: float = 0.0
|
||||
q_lr: float = 0.1
|
||||
q_bins: int = 6
|
||||
eps_start: float = 1.0
|
||||
eps_end: float = 0.05
|
||||
eps_decay: float = 0.9995
|
||||
arch: str = "small"
|
||||
activation: str = "relu"
|
||||
jax_num_envs: int = 16
|
||||
jax_num_steps: int = 128
|
||||
jax_num_minibatches: int = 4
|
||||
jax_update_epochs: int = 4
|
||||
jax_anneal_lr: bool = True
|
||||
vf_coef: float = 0.5
|
||||
max_grad_norm: float = 0.5
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RuntimeSpec:
|
||||
project: str = "capstone"
|
||||
backend: str = "sb3"
|
||||
device: str = "auto"
|
||||
seed: int = 42
|
||||
total_timesteps: int = 50_000
|
||||
checkpoint_interval: int = 200_000
|
||||
model_dir: str = "engine/models"
|
||||
log_freq: int = 100
|
||||
use_jax: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EvalSpec:
|
||||
eval_freq: int = 1_000
|
||||
eval_episodes: int = 5
|
||||
robust_eval_enabled: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainSpec:
|
||||
algorithm: AlgorithmSpec = field(default_factory=AlgorithmSpec)
|
||||
env: EnvSpec = field(default_factory=EnvSpec)
|
||||
study: StudySpec = field(default_factory=StudySpec)
|
||||
optimizer: OptimizerSpec = field(default_factory=OptimizerSpec)
|
||||
runtime: RuntimeSpec = field(default_factory=RuntimeSpec)
|
||||
eval: EvalSpec = field(default_factory=EvalSpec)
|
||||
|
||||
def to_flat_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"project": self.runtime.project,
|
||||
"algo": self.algorithm.name,
|
||||
"seed": self.runtime.seed,
|
||||
"total_timesteps": self.runtime.total_timesteps,
|
||||
"eval_episodes": self.eval.eval_episodes,
|
||||
"eval_freq": self.eval.eval_freq,
|
||||
"log_freq": self.runtime.log_freq,
|
||||
"model_dir": self.runtime.model_dir,
|
||||
"backend": self.runtime.backend,
|
||||
"device": self.runtime.device,
|
||||
"use_jax": self.runtime.use_jax,
|
||||
"checkpoint_interval": self.runtime.checkpoint_interval,
|
||||
"n_products": self.env.n_products,
|
||||
"N": self.env.n_sessions,
|
||||
"price_low": self.env.price_low,
|
||||
"price_high": self.env.price_high,
|
||||
"action_levels": self.env.action_levels,
|
||||
"action_scale_low": self.env.action_scale_low,
|
||||
"action_scale_high": self.env.action_scale_high,
|
||||
"max_steps": self.env.max_steps,
|
||||
"margin_floor": self.env.margin_floor,
|
||||
"margin_floor_patience": self.env.margin_floor_patience,
|
||||
"alpha": self.study.alpha,
|
||||
"lambda_coi": self.study.lambda_coi,
|
||||
"robust_radius": self.study.robust_radius,
|
||||
"robust_points": self.study.robust_points,
|
||||
"info_value": self.study.info_value,
|
||||
"revenue_weight": self.study.revenue_weight,
|
||||
"no_robust": self.study.no_robust,
|
||||
"learning_rate": self.optimizer.learning_rate,
|
||||
"gamma": self.optimizer.gamma,
|
||||
"buffer_size": self.optimizer.buffer_size,
|
||||
"batch_size": self.optimizer.batch_size,
|
||||
"tau": self.optimizer.tau,
|
||||
"train_freq": self.optimizer.train_freq,
|
||||
"learning_starts": self.optimizer.learning_starts,
|
||||
"target_update_interval": self.optimizer.target_update_interval,
|
||||
"exploration_fraction": self.optimizer.exploration_fraction,
|
||||
"exploration_final_eps": self.optimizer.exploration_final_eps,
|
||||
"n_steps": self.optimizer.n_steps,
|
||||
"n_epochs": self.optimizer.n_epochs,
|
||||
"gae_lambda": self.optimizer.gae_lambda,
|
||||
"clip_range": self.optimizer.clip_range,
|
||||
"ent_coef": self.optimizer.ent_coef,
|
||||
"q_lr": self.optimizer.q_lr,
|
||||
"q_bins": self.optimizer.q_bins,
|
||||
"eps_start": self.optimizer.eps_start,
|
||||
"eps_end": self.optimizer.eps_end,
|
||||
"eps_decay": self.optimizer.eps_decay,
|
||||
"arch": self.optimizer.arch,
|
||||
"activation": self.optimizer.activation,
|
||||
"jax_num_envs": self.optimizer.jax_num_envs,
|
||||
"jax_num_steps": self.optimizer.jax_num_steps,
|
||||
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
|
||||
"jax_update_epochs": self.optimizer.jax_update_epochs,
|
||||
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
|
||||
"vf_coef": self.optimizer.vf_coef,
|
||||
"max_grad_norm": self.optimizer.max_grad_norm,
|
||||
"robust_eval_enabled": self.eval.robust_eval_enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_flat(
|
||||
cls,
|
||||
raw: Mapping[str, Any] | None = None,
|
||||
*,
|
||||
env_vars: Mapping[str, str] | None = None,
|
||||
) -> "TrainSpec":
|
||||
base = cls().to_flat_dict()
|
||||
incoming = _normalize_keys(raw or {})
|
||||
base.update({k: v for k, v in incoming.items() if v is not None})
|
||||
|
||||
runtime_env = os.environ if env_vars is None else env_vars
|
||||
base["device"] = str(
|
||||
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
|
||||
)
|
||||
|
||||
requested_jax = _truthy(base.get("use_jax")) or _truthy(
|
||||
runtime_env.get("PHANTOM_USE_JAX")
|
||||
)
|
||||
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
|
||||
if backend == "auto":
|
||||
backend = "jax" if requested_jax else "sb3"
|
||||
if backend == "jax":
|
||||
requested_jax = True
|
||||
|
||||
no_robust = _truthy(base.get("no_robust"))
|
||||
if no_robust:
|
||||
base["lambda_coi"] = 0.0
|
||||
base["robust_radius"] = 0.0
|
||||
base["robust_points"] = 1
|
||||
|
||||
return cls(
|
||||
algorithm=AlgorithmSpec(name=str(base["algo"]).lower().strip()),
|
||||
env=EnvSpec(
|
||||
n_products=int(base["n_products"]),
|
||||
n_sessions=int(base["N"]),
|
||||
price_low=float(base["price_low"]),
|
||||
price_high=float(base["price_high"]),
|
||||
action_levels=int(base["action_levels"]),
|
||||
action_scale_low=float(base["action_scale_low"]),
|
||||
action_scale_high=float(base["action_scale_high"]),
|
||||
max_steps=int(base["max_steps"]),
|
||||
margin_floor=float(base["margin_floor"]),
|
||||
margin_floor_patience=int(base["margin_floor_patience"]),
|
||||
),
|
||||
study=StudySpec(
|
||||
alpha=float(base["alpha"]),
|
||||
lambda_coi=float(base["lambda_coi"]),
|
||||
robust_radius=float(base["robust_radius"]),
|
||||
robust_points=int(base["robust_points"]),
|
||||
info_value=float(base["info_value"]),
|
||||
revenue_weight=float(base["revenue_weight"]),
|
||||
no_robust=no_robust,
|
||||
),
|
||||
optimizer=OptimizerSpec(
|
||||
learning_rate=float(base["learning_rate"]),
|
||||
gamma=float(base["gamma"]),
|
||||
buffer_size=int(base["buffer_size"]),
|
||||
batch_size=int(base["batch_size"]),
|
||||
tau=float(base["tau"]),
|
||||
train_freq=int(base["train_freq"]),
|
||||
learning_starts=int(base["learning_starts"]),
|
||||
target_update_interval=int(base["target_update_interval"]),
|
||||
exploration_fraction=float(base["exploration_fraction"]),
|
||||
exploration_final_eps=float(base["exploration_final_eps"]),
|
||||
n_steps=int(base["n_steps"]),
|
||||
n_epochs=int(base["n_epochs"]),
|
||||
gae_lambda=float(base["gae_lambda"]),
|
||||
clip_range=float(base["clip_range"]),
|
||||
ent_coef=float(base["ent_coef"]),
|
||||
q_lr=float(base["q_lr"]),
|
||||
q_bins=int(base["q_bins"]),
|
||||
eps_start=float(base["eps_start"]),
|
||||
eps_end=float(base["eps_end"]),
|
||||
eps_decay=float(base["eps_decay"]),
|
||||
arch=str(base["arch"]),
|
||||
activation=str(base["activation"]),
|
||||
jax_num_envs=int(base["jax_num_envs"]),
|
||||
jax_num_steps=int(base["jax_num_steps"]),
|
||||
jax_num_minibatches=int(base["jax_num_minibatches"]),
|
||||
jax_update_epochs=int(base["jax_update_epochs"]),
|
||||
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
|
||||
vf_coef=float(base["vf_coef"]),
|
||||
max_grad_norm=float(base["max_grad_norm"]),
|
||||
),
|
||||
runtime=RuntimeSpec(
|
||||
project=str(base["project"]),
|
||||
backend=backend,
|
||||
device=str(base["device"]),
|
||||
seed=int(base["seed"]),
|
||||
total_timesteps=int(base["total_timesteps"]),
|
||||
checkpoint_interval=int(base["checkpoint_interval"]),
|
||||
model_dir=str(base["model_dir"]),
|
||||
log_freq=int(base["log_freq"]),
|
||||
use_jax=requested_jax,
|
||||
),
|
||||
eval=EvalSpec(
|
||||
eval_freq=int(base["eval_freq"]),
|
||||
eval_episodes=int(base["eval_episodes"]),
|
||||
robust_eval_enabled=_truthy(base.get("robust_eval_enabled", True)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
|
||||
return (
|
||||
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
|
||||
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
|
||||
)
|
||||
|
||||
|
||||
def run_metadata(
|
||||
spec: TrainSpec,
|
||||
*,
|
||||
kind: str,
|
||||
scenario: str,
|
||||
group: str | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
) -> dict[str, Any]:
|
||||
metadata: dict[str, Any] = {
|
||||
"run.kind": str(kind),
|
||||
"run.algo": spec.algorithm.name,
|
||||
"run.backend": spec.runtime.backend,
|
||||
"run.device": spec.runtime.device,
|
||||
"run.scenario": str(scenario),
|
||||
"run.seed": spec.runtime.seed,
|
||||
"run.tags": list(tags),
|
||||
}
|
||||
if group:
|
||||
metadata["run.group"] = group
|
||||
return metadata
|
||||
136
engine/studies/local_comparison.py
Normal file
136
engine/studies/local_comparison.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
# Add parent directory to path to allow importing engine
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from engine.wrapper import PHANTOM
|
||||
from engine.lib.wrappers import EconomicMetricsWrapper
|
||||
from engine.lib.providers import (
|
||||
ProviderBenchmark,
|
||||
BenchmarkConfig,
|
||||
RandomBaseline,
|
||||
SurgeBaseline,
|
||||
)
|
||||
|
||||
|
||||
def env_factory(alpha: float):
|
||||
"""Creates a wrapped PHANTOM environment for testing at a specific alpha level."""
|
||||
# Action levels=9 matches the trained PPO model
|
||||
# n_products=8 matches the pretrained model's expectation of Box(16,)
|
||||
env = PHANTOM(
|
||||
n_products=8,
|
||||
alpha=alpha,
|
||||
N=100,
|
||||
action_levels=9,
|
||||
action_scale_low=0.8,
|
||||
action_scale_high=1.2,
|
||||
max_steps=20, # Short episodes so simulation goes fast
|
||||
robust_points=1, # disable expensive adversarial lookaheads
|
||||
render_mode=None,
|
||||
)
|
||||
env = EconomicMetricsWrapper(env)
|
||||
return FlattenObservation(env)
|
||||
|
||||
|
||||
def main():
|
||||
print("Loading pre-trained Robust RL model...")
|
||||
model_path = Path(__file__).parent.parent / "models" / "phantom_ppo.zip"
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model not found at {model_path}")
|
||||
print("Please ensure you have a trained model before running this script.")
|
||||
return
|
||||
|
||||
rl_model = PPO.load(model_path)
|
||||
|
||||
# The action space is Discrete(9). Index 4 is the middle (1.0 scale).
|
||||
n_actions = 9
|
||||
mid_action = n_actions // 2
|
||||
|
||||
providers = {
|
||||
"Static (Base)": lambda obs: mid_action,
|
||||
"Random": RandomBaseline(n_actions),
|
||||
"Heuristic Surge": SurgeBaseline(
|
||||
n_actions, high_threshold=60.0, low_threshold=30.0
|
||||
),
|
||||
"Robust RL (PPO)": lambda obs: rl_model.predict(obs, deterministic=True)[0],
|
||||
}
|
||||
|
||||
config = BenchmarkConfig(
|
||||
n_episodes=10, # Lower episodes to run faster
|
||||
alpha_range=[0.0, 0.5, 1.0], # Fewer alpha levels
|
||||
baseline_name="Static (Base)",
|
||||
)
|
||||
|
||||
print(f"\nStarting benchmark across alpha levels: {config.alpha_range}")
|
||||
print(
|
||||
f"Testing {len(providers)} strategies for {config.n_episodes} episodes each...\n"
|
||||
)
|
||||
|
||||
benchmark = ProviderBenchmark(env_factory, providers, config)
|
||||
results = benchmark.run()
|
||||
|
||||
# 1. Print tabular results
|
||||
df = benchmark.to_dataframe()
|
||||
summary = benchmark.summary_table()
|
||||
print("\n--- Benchmark Summary Table ---")
|
||||
print(summary)
|
||||
|
||||
# 2. Save results to CSV for thesis inclusion
|
||||
out_dir = Path(__file__).parent / "results"
|
||||
out_dir.mkdir(exist_ok=True)
|
||||
csv_path = out_dir / "provider_comparison.csv"
|
||||
df.to_csv(csv_path, index=False)
|
||||
print(f"\nSaved raw results to {csv_path}")
|
||||
|
||||
# 3. Plot the degradation of COI / Revenue as alpha increases
|
||||
plt.figure(figsize=(12, 5))
|
||||
|
||||
# Plot 1: Revenue vs Alpha
|
||||
plt.subplot(1, 2, 1)
|
||||
for name in providers.keys():
|
||||
provider_data = df[df["name"] == name]
|
||||
plt.plot(
|
||||
provider_data["alpha"],
|
||||
provider_data["mean_revenue"],
|
||||
marker="o",
|
||||
label=name,
|
||||
linewidth=2,
|
||||
)
|
||||
plt.title("Revenue under Agent Contamination")
|
||||
plt.xlabel("Contamination Level (α)")
|
||||
plt.ylabel("Mean Episode Revenue ($)")
|
||||
plt.grid(True, linestyle="--", alpha=0.7)
|
||||
plt.legend()
|
||||
|
||||
# Plot 2: COI Preservation vs Alpha
|
||||
plt.subplot(1, 2, 2)
|
||||
for name in providers.keys():
|
||||
provider_data = df[df["name"] == name]
|
||||
plt.plot(
|
||||
provider_data["alpha"],
|
||||
provider_data["coi_preserved_pct"],
|
||||
marker="s",
|
||||
label=name,
|
||||
linewidth=2,
|
||||
)
|
||||
plt.title("Cost of Information (COI) Preservation")
|
||||
plt.xlabel("Contamination Level (α)")
|
||||
plt.ylabel("COI Preserved (%)")
|
||||
plt.grid(True, linestyle="--", alpha=0.7)
|
||||
plt.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plot_path = out_dir / "alpha_degradation_plot.png"
|
||||
plt.savefig(plot_path, dpi=300)
|
||||
print(f"Saved visualization to {plot_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,6 +1,6 @@
|
||||
method: random
|
||||
metric:
|
||||
name: sweep/score
|
||||
name: objective/score
|
||||
goal: maximize
|
||||
command:
|
||||
- ${env}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
method: grid
|
||||
metric:
|
||||
name: sweep/score
|
||||
name: objective/score
|
||||
goal: maximize
|
||||
run_cap: 4
|
||||
command:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
method: bayes
|
||||
metric:
|
||||
name: sweep/score
|
||||
name: objective/score
|
||||
goal: maximize
|
||||
command:
|
||||
- ${env}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
method: random
|
||||
metric:
|
||||
name: sweep/score
|
||||
name: objective/score
|
||||
goal: maximize
|
||||
command:
|
||||
- ${env}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
method: bayes
|
||||
metric:
|
||||
name: sweep/score
|
||||
name: objective/score
|
||||
goal: maximize
|
||||
command:
|
||||
- ${env}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
method: bayes
|
||||
metric:
|
||||
name: sweep/score
|
||||
name: objective/score
|
||||
goal: maximize
|
||||
command:
|
||||
- ${env}
|
||||
|
||||
23
engine/telemetry/__init__.py
Normal file
23
engine/telemetry/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .metrics import canonicalize_metrics
|
||||
from .wandb import (
|
||||
current_config,
|
||||
finish_run,
|
||||
get_wandb_module,
|
||||
init_run,
|
||||
log_metrics,
|
||||
run_agent,
|
||||
update_run_config,
|
||||
update_summary,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"canonicalize_metrics",
|
||||
"current_config",
|
||||
"finish_run",
|
||||
"get_wandb_module",
|
||||
"init_run",
|
||||
"log_metrics",
|
||||
"run_agent",
|
||||
"update_run_config",
|
||||
"update_summary",
|
||||
]
|
||||
57
engine/telemetry/metrics.py
Normal file
57
engine/telemetry/metrics.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..spec import TrainSpec
|
||||
|
||||
|
||||
_ALIASES = {
|
||||
"train/reward": "train/reward_mean",
|
||||
"train/revenue": "train/revenue_mean",
|
||||
"train/dqn_loss": "train/loss",
|
||||
"eval/reward": "eval/reward_mean",
|
||||
"eval/revenue": "eval/revenue_mean",
|
||||
"train/steps_per_second": "runtime/steps_per_second",
|
||||
}
|
||||
|
||||
|
||||
def _as_float(value: Any, default: float | None = None) -> float | None:
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, Any]:
|
||||
metrics: dict[str, Any] = {}
|
||||
for key, value in raw.items():
|
||||
canonical = _ALIASES.get(str(key), str(key))
|
||||
if canonical in metrics and canonical != key:
|
||||
continue
|
||||
metrics[canonical] = value
|
||||
|
||||
metrics.setdefault("train/global_step", spec.runtime.total_timesteps)
|
||||
|
||||
eval_reward = _as_float(metrics.get("eval/reward_mean"), 0.0) or 0.0
|
||||
eval_revenue = _as_float(metrics.get("eval/revenue_mean"), 0.0) or 0.0
|
||||
metrics["objective/score"] = eval_reward + spec.study.revenue_weight * eval_revenue
|
||||
|
||||
margin_mean = _as_float(metrics.get("eval/margin_mean"), None)
|
||||
if margin_mean is not None:
|
||||
metrics["objective/constraint_margin"] = margin_mean - spec.env.margin_floor
|
||||
|
||||
coi_level = _as_float(metrics.get("eval/coi_level_mean"), None)
|
||||
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
|
||||
|
||||
metrics["study/alpha"] = spec.study.alpha
|
||||
metrics["study/lambda_coi"] = spec.study.lambda_coi
|
||||
metrics["study/robust_radius"] = spec.study.robust_radius
|
||||
metrics["study/info_value"] = spec.study.info_value
|
||||
|
||||
metrics["runtime/backend"] = spec.runtime.backend
|
||||
metrics["runtime/device"] = spec.runtime.device
|
||||
metrics["runtime/seed"] = spec.runtime.seed
|
||||
|
||||
return metrics
|
||||
98
engine/telemetry/wandb.py
Normal file
98
engine/telemetry/wandb.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Iterable, Mapping
|
||||
|
||||
|
||||
def get_wandb_module():
|
||||
try:
|
||||
import wandb
|
||||
|
||||
return wandb
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _require_wandb():
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None:
|
||||
raise ImportError("wandb is required for this workflow")
|
||||
return wandb
|
||||
|
||||
|
||||
def init_run(
|
||||
*,
|
||||
mode: str,
|
||||
project: str | None = None,
|
||||
config: Mapping[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
tags: Iterable[str] | None = None,
|
||||
group: str | None = None,
|
||||
sweep_mode: bool = False,
|
||||
):
|
||||
wandb = _require_wandb()
|
||||
kwargs: dict[str, Any] = {"mode": mode}
|
||||
if group:
|
||||
kwargs["group"] = group
|
||||
if sweep_mode:
|
||||
run = wandb.init(**kwargs)
|
||||
if name and run is not None:
|
||||
run.name = name
|
||||
return run
|
||||
|
||||
init_kwargs = dict(kwargs)
|
||||
init_kwargs["project"] = project
|
||||
if config is not None:
|
||||
init_kwargs["config"] = dict(config)
|
||||
if name:
|
||||
init_kwargs["name"] = name
|
||||
if tags:
|
||||
init_kwargs["tags"] = list(tags)
|
||||
return wandb.init(**init_kwargs)
|
||||
|
||||
|
||||
def finish_run() -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
|
||||
|
||||
def current_config() -> dict[str, Any]:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return {}
|
||||
return {key: wandb.config[key] for key in wandb.config.keys()}
|
||||
|
||||
|
||||
def update_run_config(config: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
try:
|
||||
wandb.config.update(dict(config), allow_val_change=True)
|
||||
except TypeError:
|
||||
wandb.config.update(dict(config))
|
||||
|
||||
|
||||
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
wandb.log(dict(metrics), step=step)
|
||||
|
||||
|
||||
def update_summary(metrics: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
for key, value in metrics.items():
|
||||
wandb.run.summary[key] = value
|
||||
|
||||
|
||||
def run_agent(
|
||||
sweep_id: str,
|
||||
fn: Callable[[], None],
|
||||
*,
|
||||
count: int | None = None,
|
||||
) -> None:
|
||||
wandb = _require_wandb()
|
||||
wandb.agent(sweep_id, function=fn, count=count)
|
||||
722
engine/train.py
722
engine/train.py
@@ -1,98 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lib.discrete import EventQTable
|
||||
|
||||
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
|
||||
try:
|
||||
import wandb as _wandb
|
||||
|
||||
if hasattr(_wandb, "init") and callable(_wandb.init):
|
||||
wandb = _wandb
|
||||
HAS_WANDB = True
|
||||
else:
|
||||
wandb = None
|
||||
HAS_WANDB = False
|
||||
except ImportError:
|
||||
wandb = None
|
||||
HAS_WANDB = False
|
||||
|
||||
try:
|
||||
from stable_baselines3 import PPO, A2C, DQN
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
||||
HAS_SB3 = True
|
||||
except ImportError:
|
||||
HAS_SB3 = False
|
||||
|
||||
from .jax import JAX_AVAILABLE
|
||||
|
||||
|
||||
DEFAULT_CFG = {
|
||||
"project": "phantom-pricing",
|
||||
"algo": "ppo",
|
||||
"seed": 42,
|
||||
"total_timesteps": 50_000,
|
||||
"eval_episodes": 5,
|
||||
"eval_freq": 1_000,
|
||||
"log_freq": 100,
|
||||
"revenue_weight": 0.01,
|
||||
"n_products": 10,
|
||||
"N": 100,
|
||||
"alpha": 0.3,
|
||||
"lambda_coi": 0.2,
|
||||
"robust_radius": 0.15,
|
||||
"robust_points": 5,
|
||||
"no_robust": False,
|
||||
"info_value": 1.0,
|
||||
"price_low": 10.0,
|
||||
"price_high": 150.0,
|
||||
"action_levels": 9,
|
||||
"action_scale_low": 0.8,
|
||||
"action_scale_high": 1.2,
|
||||
"learning_rate": 3e-4,
|
||||
"gamma": 0.99,
|
||||
"buffer_size": 50_000,
|
||||
"batch_size": 256,
|
||||
"tau": 0.005,
|
||||
"train_freq": 1,
|
||||
"learning_starts": 1_000,
|
||||
"target_update_interval": 1_000,
|
||||
"exploration_fraction": 0.2,
|
||||
"exploration_final_eps": 0.05,
|
||||
"n_steps": 2_048,
|
||||
"n_epochs": 10,
|
||||
"gae_lambda": 0.95,
|
||||
"clip_range": 0.2,
|
||||
"ent_coef": 0.0,
|
||||
"q_lr": 0.1,
|
||||
"eps_start": 1.0,
|
||||
"eps_end": 0.05,
|
||||
"eps_decay": 0.9995,
|
||||
"model_dir": "engine/models",
|
||||
"arch": "small",
|
||||
"activation": "relu",
|
||||
"q_bins": 6,
|
||||
"max_steps": 100,
|
||||
"margin_floor": 0.05,
|
||||
"margin_floor_patience": 5,
|
||||
"use_jax": False,
|
||||
"jax_num_envs": 16,
|
||||
"jax_num_steps": 128,
|
||||
"jax_num_minibatches": 4,
|
||||
"jax_update_epochs": 4,
|
||||
"jax_anneal_lr": True,
|
||||
"checkpoint_interval": 200_000,
|
||||
}
|
||||
from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once
|
||||
from .spec import TrainSpec
|
||||
|
||||
|
||||
def _truthy(value: str | bool | None) -> bool:
|
||||
@@ -103,423 +15,133 @@ def _truthy(value: str | bool | None) -> bool:
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _cfg(raw: dict | None = None) -> dict:
|
||||
cfg = dict(DEFAULT_CFG)
|
||||
if raw:
|
||||
cfg.update({k: v for k, v in raw.items() if v is not None})
|
||||
cfg["algo"] = str(cfg["algo"]).lower()
|
||||
cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy(
|
||||
os.environ.get("PHANTOM_USE_JAX")
|
||||
def _parse_tags(raw: str | None) -> list[str]:
|
||||
if raw is None:
|
||||
return []
|
||||
return [piece.strip() for piece in str(raw).split(",") if piece.strip()]
|
||||
|
||||
|
||||
def _probe_run_kind(argv: list[str]) -> str:
|
||||
probe = argparse.ArgumentParser(add_help=False)
|
||||
probe.add_argument("--run-kind", choices=["train", "benchmark"])
|
||||
probe.add_argument("--run-mode", choices=["train", "benchmark"])
|
||||
args, _ = probe.parse_known_args(argv)
|
||||
return str(args.run_kind or args.run_mode or "train")
|
||||
|
||||
|
||||
def _strip_run_kind(argv: list[str]) -> list[str]:
|
||||
stripped: list[str] = []
|
||||
skip_next = False
|
||||
for item in argv:
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if item in {"--run-kind", "--run-mode"}:
|
||||
skip_next = True
|
||||
continue
|
||||
if item.startswith("--run-kind=") or item.startswith("--run-mode="):
|
||||
continue
|
||||
stripped.append(item)
|
||||
return stripped
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="PHANTOM unified training entrypoint")
|
||||
parser.add_argument("--run-kind", choices=["train", "benchmark"], default="train")
|
||||
parser.add_argument("--run-mode", choices=["train", "benchmark"])
|
||||
|
||||
parser.add_argument("--project", default="capstone")
|
||||
parser.add_argument("--scenario", default="default")
|
||||
parser.add_argument("--group", type=str)
|
||||
parser.add_argument("--tags", type=str)
|
||||
|
||||
parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto")
|
||||
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
|
||||
parser.add_argument("--seed", type=int)
|
||||
parser.add_argument("--total-timesteps", type=int)
|
||||
parser.add_argument("--model-dir", type=str)
|
||||
parser.add_argument("--log-freq", type=int)
|
||||
parser.add_argument("--checkpoint-interval", type=int)
|
||||
parser.add_argument("--device", type=str)
|
||||
|
||||
parser.add_argument("--alpha", type=float)
|
||||
parser.add_argument("--N", type=int)
|
||||
parser.add_argument("--n-products", type=int)
|
||||
parser.add_argument("--lambda-coi", type=float)
|
||||
parser.add_argument("--info-value", type=float)
|
||||
parser.add_argument("--robust-radius", type=float)
|
||||
parser.add_argument("--robust-points", type=int)
|
||||
parser.add_argument("--no-robust", action="store_true")
|
||||
parser.add_argument("--revenue-weight", type=float)
|
||||
|
||||
parser.add_argument("--price-low", type=float)
|
||||
parser.add_argument("--price-high", type=float)
|
||||
parser.add_argument("--action-levels", type=int)
|
||||
parser.add_argument("--action-scale-low", type=float)
|
||||
parser.add_argument("--action-scale-high", type=float)
|
||||
parser.add_argument("--max-steps", type=int)
|
||||
parser.add_argument("--margin-floor", type=float)
|
||||
parser.add_argument("--margin-floor-patience", type=int)
|
||||
|
||||
parser.add_argument("--learning-rate", type=float)
|
||||
parser.add_argument("--gamma", type=float)
|
||||
parser.add_argument("--buffer-size", type=int)
|
||||
parser.add_argument("--batch-size", type=int)
|
||||
parser.add_argument("--tau", type=float)
|
||||
parser.add_argument("--train-freq", type=int)
|
||||
parser.add_argument("--learning-starts", type=int)
|
||||
parser.add_argument("--target-update-interval", type=int)
|
||||
parser.add_argument("--exploration-fraction", type=float)
|
||||
parser.add_argument("--exploration-final-eps", type=float)
|
||||
parser.add_argument("--n-steps", type=int)
|
||||
parser.add_argument("--n-epochs", type=int)
|
||||
parser.add_argument("--gae-lambda", type=float)
|
||||
parser.add_argument("--clip-range", type=float)
|
||||
parser.add_argument("--ent-coef", type=float)
|
||||
parser.add_argument("--q-lr", type=float)
|
||||
parser.add_argument("--q-bins", type=int)
|
||||
parser.add_argument("--eps-start", type=float)
|
||||
parser.add_argument("--eps-end", type=float)
|
||||
parser.add_argument("--eps-decay", type=float)
|
||||
parser.add_argument("--arch", type=str)
|
||||
parser.add_argument("--activation", type=str)
|
||||
parser.add_argument("--vf-coef", type=float)
|
||||
parser.add_argument("--max-grad-norm", type=float)
|
||||
|
||||
parser.add_argument("--eval-freq", type=int)
|
||||
parser.add_argument("--eval-episodes", type=int)
|
||||
|
||||
parser.add_argument("--jax", action="store_true")
|
||||
parser.add_argument("--jax-num-envs", type=int)
|
||||
parser.add_argument("--jax-num-steps", type=int)
|
||||
parser.add_argument("--jax-num-minibatches", type=int)
|
||||
parser.add_argument("--jax-update-epochs", type=int)
|
||||
parser.add_argument("--jax-anneal-lr", type=str)
|
||||
|
||||
parser.add_argument("--sweep-agent", action="store_true")
|
||||
parser.add_argument("--sweep-id", type=str)
|
||||
parser.add_argument("--count", type=int, default=0)
|
||||
parser.add_argument("--offline", action="store_true")
|
||||
parser.add_argument("--no-wandb", action="store_true")
|
||||
return parser
|
||||
|
||||
|
||||
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
||||
jax_anneal_lr = (
|
||||
_truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None
|
||||
)
|
||||
cfg["no_robust"] = _truthy(cfg.get("no_robust"))
|
||||
if cfg["no_robust"]:
|
||||
cfg["lambda_coi"] = 0.0
|
||||
cfg["robust_radius"] = 0.0
|
||||
cfg["robust_points"] = 1
|
||||
return cfg
|
||||
|
||||
|
||||
def _wandb_cfg_dict() -> dict:
|
||||
return (
|
||||
{k: wandb.config[k] for k in wandb.config.keys()}
|
||||
if HAS_WANDB and wandb.run
|
||||
else {}
|
||||
)
|
||||
|
||||
|
||||
def make_env(cfg: dict):
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
|
||||
from .wrapper import PHANTOM
|
||||
from .lib.wrappers import EconomicMetricsWrapper
|
||||
|
||||
env = PHANTOM(
|
||||
n_products=int(cfg["n_products"]),
|
||||
alpha=float(cfg["alpha"]),
|
||||
N=int(cfg["N"]),
|
||||
price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])),
|
||||
lambda_coi=float(cfg["lambda_coi"]),
|
||||
robust_radius=float(cfg["robust_radius"]),
|
||||
robust_points=int(cfg["robust_points"]),
|
||||
info_value=float(cfg["info_value"]),
|
||||
action_levels=int(cfg["action_levels"]),
|
||||
action_scale_low=float(cfg["action_scale_low"]),
|
||||
action_scale_high=float(cfg["action_scale_high"]),
|
||||
max_steps=int(cfg.get("max_steps", 100)),
|
||||
margin_floor=float(cfg.get("margin_floor", 0.05)),
|
||||
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
|
||||
render_mode=None,
|
||||
)
|
||||
env = EconomicMetricsWrapper(env)
|
||||
env = FlattenObservation(env)
|
||||
return env
|
||||
|
||||
|
||||
def _net_arch(name) -> list[int]:
|
||||
presets = {
|
||||
"tiny": [32, 32],
|
||||
"small": [64, 64],
|
||||
"medium": [128, 128],
|
||||
"large": [256, 256],
|
||||
}
|
||||
if isinstance(name, (list, tuple)):
|
||||
return [int(v) for v in name]
|
||||
s = str(name).lower().strip()
|
||||
if s in presets:
|
||||
return presets[s]
|
||||
if "x" in s:
|
||||
try:
|
||||
vals = [int(v) for v in s.split("x") if v]
|
||||
return vals if vals else presets["small"]
|
||||
except ValueError:
|
||||
return presets["small"]
|
||||
return presets["small"]
|
||||
|
||||
|
||||
def _activation(name):
|
||||
try:
|
||||
import torch.nn as nn
|
||||
except ImportError:
|
||||
return None
|
||||
return {
|
||||
"relu": nn.ReLU,
|
||||
"tanh": nn.Tanh,
|
||||
"elu": nn.ELU,
|
||||
"leaky_relu": nn.LeakyReLU,
|
||||
}.get(str(name).lower().strip(), nn.ReLU)
|
||||
|
||||
|
||||
def _policy_kwargs(cfg: dict) -> dict:
|
||||
kw = {"net_arch": _net_arch(cfg.get("arch", "small"))}
|
||||
act = _activation(cfg.get("activation", "relu"))
|
||||
if act is not None:
|
||||
kw["activation_fn"] = act
|
||||
return kw
|
||||
|
||||
|
||||
def _action(agent, obs, deterministic: bool = True):
|
||||
out = agent.predict(obs, deterministic=deterministic)
|
||||
a = out[0] if isinstance(out, tuple) else out
|
||||
if isinstance(a, np.ndarray) and a.size == 1:
|
||||
return int(a.reshape(-1)[0])
|
||||
return a
|
||||
|
||||
|
||||
def evaluate(agent, env, episodes: int) -> dict:
|
||||
rewards, revenues = [], []
|
||||
for _ in range(int(episodes)):
|
||||
obs, _ = env.reset()
|
||||
done, ep_r, ep_rev = False, 0.0, 0.0
|
||||
while not done:
|
||||
obs, reward, term, trunc, info = env.step(_action(agent, obs, True))
|
||||
done = term or trunc
|
||||
ep_r += float(reward)
|
||||
ep_rev += float(
|
||||
info.get("economics", {}).get("revenue", info.get("revenue", 0.0))
|
||||
)
|
||||
rewards.append(ep_r)
|
||||
revenues.append(ep_rev)
|
||||
return {
|
||||
"eval/reward": float(np.mean(rewards)),
|
||||
"eval/revenue": float(np.mean(revenues)),
|
||||
"eval/reward_std": float(np.std(rewards)),
|
||||
"eval/revenue_std": float(np.std(revenues)),
|
||||
}
|
||||
|
||||
|
||||
def build_model(cfg: dict, env):
|
||||
algo = cfg["algo"]
|
||||
policy_kwargs = _policy_kwargs(cfg)
|
||||
if algo == "sac":
|
||||
raise ValueError("sac is not supported with the discrete core env")
|
||||
if algo == "ppo":
|
||||
return PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=int(cfg["seed"]),
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
n_steps=int(cfg["n_steps"]),
|
||||
batch_size=int(cfg["batch_size"]),
|
||||
n_epochs=int(cfg["n_epochs"]),
|
||||
gamma=float(cfg["gamma"]),
|
||||
gae_lambda=float(cfg["gae_lambda"]),
|
||||
clip_range=float(cfg["clip_range"]),
|
||||
ent_coef=float(cfg["ent_coef"]),
|
||||
)
|
||||
if algo == "a2c":
|
||||
return A2C(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=int(cfg["seed"]),
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
n_steps=max(5, int(cfg["n_steps"]) // 32),
|
||||
gamma=float(cfg["gamma"]),
|
||||
gae_lambda=float(cfg["gae_lambda"]),
|
||||
ent_coef=float(cfg["ent_coef"]),
|
||||
)
|
||||
if algo == "dqn":
|
||||
return DQN(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
policy_kwargs=policy_kwargs,
|
||||
seed=int(cfg["seed"]),
|
||||
learning_rate=float(cfg["learning_rate"]),
|
||||
buffer_size=int(cfg["buffer_size"]),
|
||||
batch_size=int(cfg["batch_size"]),
|
||||
gamma=float(cfg["gamma"]),
|
||||
train_freq=int(cfg["train_freq"]),
|
||||
learning_starts=int(cfg["learning_starts"]),
|
||||
target_update_interval=int(cfg["target_update_interval"]),
|
||||
exploration_fraction=float(cfg["exploration_fraction"]),
|
||||
exploration_final_eps=float(cfg["exploration_final_eps"]),
|
||||
)
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def _sb3_model_cls(algo: str):
|
||||
if algo == "ppo":
|
||||
return PPO
|
||||
if algo == "a2c":
|
||||
return A2C
|
||||
if algo == "dqn":
|
||||
return DQN
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def train_qtable(cfg: dict) -> tuple["EventQTable", dict]:
|
||||
from .lib.discrete import EventQTable
|
||||
|
||||
np.random.seed(int(cfg["seed"]))
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
agent = EventQTable(
|
||||
env.action_space.n,
|
||||
int(cfg["n_products"]),
|
||||
(float(cfg["price_low"]), float(cfg["price_high"])),
|
||||
lr=float(cfg["q_lr"]),
|
||||
gamma=float(cfg["gamma"]),
|
||||
n_bins=int(cfg["q_bins"]),
|
||||
)
|
||||
eps = float(cfg["eps_start"])
|
||||
obs, _ = env.reset(seed=int(cfg["seed"]))
|
||||
for t in range(int(cfg["total_timesteps"])):
|
||||
a, s = agent.act(obs, eps)
|
||||
nxt, reward, term, trunc, info = env.step(a)
|
||||
done = term or trunc
|
||||
agent.update(s, a, float(reward), agent.encode(nxt), done)
|
||||
eps = max(float(cfg["eps_end"]), eps * float(cfg["eps_decay"]))
|
||||
if HAS_WANDB and wandb.run and (t + 1) % int(cfg["log_freq"]) == 0:
|
||||
econ = info.get("economics", {})
|
||||
wandb.log(
|
||||
{
|
||||
"train/reward": float(reward),
|
||||
"train/revenue": float(econ.get("revenue", 0.0)),
|
||||
"train/epsilon": float(eps),
|
||||
},
|
||||
step=t + 1,
|
||||
)
|
||||
obs = env.reset()[0] if done else nxt
|
||||
metrics = evaluate(agent, eval_env, int(cfg["eval_episodes"]))
|
||||
metrics["train/global_step"] = int(cfg["total_timesteps"])
|
||||
env.close()
|
||||
eval_env.close()
|
||||
return agent, metrics
|
||||
|
||||
|
||||
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
if not HAS_SB3:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models")
|
||||
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
||||
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
env = Monitor(env)
|
||||
eval_env = Monitor(eval_env)
|
||||
model = build_model(cfg, env)
|
||||
resume_step = 0
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
||||
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
||||
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
||||
if restored is not None:
|
||||
checkpoint_path, metadata = restored
|
||||
model = _sb3_model_cls(cfg["algo"]).load(
|
||||
checkpoint_path.as_posix(), env=env
|
||||
)
|
||||
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
|
||||
model.num_timesteps = max(
|
||||
int(getattr(model, "num_timesteps", 0)), resume_step
|
||||
)
|
||||
|
||||
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
|
||||
cbs.append(
|
||||
CheckpointArtifactCallback(
|
||||
cfg,
|
||||
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
||||
)
|
||||
)
|
||||
cbs.append(
|
||||
EvalCallback(
|
||||
eval_env,
|
||||
eval_freq=int(cfg["eval_freq"]),
|
||||
n_eval_episodes=int(cfg["eval_episodes"]),
|
||||
deterministic=True,
|
||||
verbose=0,
|
||||
)
|
||||
)
|
||||
target_steps = int(cfg["total_timesteps"])
|
||||
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||
if remaining_steps > 0:
|
||||
model.learn(
|
||||
total_timesteps=remaining_steps,
|
||||
callback=cbs,
|
||||
reset_num_timesteps=False,
|
||||
)
|
||||
|
||||
model_path = Path(cfg["model_dir"])
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
model.save(str(model_path / f"phantom_{cfg['algo']}"))
|
||||
metrics = evaluate(model, eval_env, int(cfg["eval_episodes"]))
|
||||
metrics["train/global_step"] = int(model.num_timesteps)
|
||||
env.close()
|
||||
eval_env.close()
|
||||
return model, metrics
|
||||
|
||||
|
||||
def train_once(cfg: dict) -> dict:
|
||||
algo = cfg["algo"]
|
||||
if cfg.get("use_jax"):
|
||||
if not JAX_AVAILABLE:
|
||||
raise ImportError(
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
try:
|
||||
from .jax.train import train_jax
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
|
||||
_, metrics = train_jax(cfg)
|
||||
elif algo == "qtable":
|
||||
_, metrics = train_qtable(cfg)
|
||||
else:
|
||||
_, metrics = train_sb3(cfg)
|
||||
metrics["sweep/score"] = float(
|
||||
metrics["eval/reward"] + float(cfg["revenue_weight"]) * metrics["eval/revenue"]
|
||||
)
|
||||
return metrics
|
||||
|
||||
|
||||
def run_wandb(
|
||||
project: str, overrides: dict, mode: str = "online", sweep_mode: bool = False
|
||||
) -> dict:
|
||||
if not HAS_WANDB:
|
||||
raise ImportError("wandb is required for sweep runs")
|
||||
if not sweep_mode:
|
||||
pre_cfg = _cfg(overrides)
|
||||
if pre_cfg.get("use_jax"):
|
||||
try:
|
||||
import jax
|
||||
|
||||
if jax.process_count() > 1 and jax.process_index() != 0:
|
||||
return train_once(pre_cfg)
|
||||
except Exception:
|
||||
pass
|
||||
init_kwargs = {"mode": mode}
|
||||
if sweep_mode:
|
||||
run = wandb.init(**init_kwargs)
|
||||
else:
|
||||
run = wandb.init(project=project, config=overrides, **init_kwargs)
|
||||
|
||||
try:
|
||||
cfg = _cfg(_wandb_cfg_dict())
|
||||
if sweep_mode:
|
||||
for k, v in overrides.items():
|
||||
if k not in wandb.config:
|
||||
cfg[k] = v
|
||||
|
||||
metrics = train_once(cfg)
|
||||
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
||||
wandb.log(metrics, step=step)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
return metrics
|
||||
finally:
|
||||
if wandb.run is not None:
|
||||
wandb.finish()
|
||||
|
||||
|
||||
def run_local(overrides: dict) -> dict:
|
||||
cfg = _cfg(overrides)
|
||||
metrics = train_once(cfg)
|
||||
should_print = True
|
||||
if cfg.get("use_jax"):
|
||||
try:
|
||||
import jax
|
||||
|
||||
should_print = jax.process_index() == 0
|
||||
except Exception:
|
||||
should_print = True
|
||||
if should_print:
|
||||
print(json.dumps(metrics, indent=2))
|
||||
# sentinel line for machine-readable extraction; must stay on one line
|
||||
print("PHANTOM_METRICS:" + json.dumps(metrics))
|
||||
return metrics
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
|
||||
p.add_argument("--project", default=DEFAULT_CFG["project"])
|
||||
p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
|
||||
p.add_argument("--seed", type=int)
|
||||
p.add_argument("--total-timesteps", type=int)
|
||||
p.add_argument("--alpha", type=float)
|
||||
p.add_argument("--N", type=int)
|
||||
p.add_argument("--n-products", type=int)
|
||||
p.add_argument("--lambda-coi", type=float)
|
||||
p.add_argument("--info-value", type=float)
|
||||
p.add_argument("--robust-radius", type=float)
|
||||
p.add_argument("--robust-points", type=int)
|
||||
p.add_argument("--no-robust", action="store_true")
|
||||
p.add_argument("--learning-rate", type=float)
|
||||
p.add_argument("--gamma", type=float)
|
||||
p.add_argument("--gae-lambda", type=float)
|
||||
p.add_argument("--clip-range", type=float)
|
||||
p.add_argument("--ent-coef", type=float)
|
||||
p.add_argument("--revenue-weight", type=float)
|
||||
p.add_argument("--price-low", type=float)
|
||||
p.add_argument("--price-high", type=float)
|
||||
p.add_argument("--action-levels", type=int)
|
||||
p.add_argument("--action-scale-low", type=float)
|
||||
p.add_argument("--action-scale-high", type=float)
|
||||
p.add_argument("--max-steps", type=int)
|
||||
p.add_argument("--margin-floor", type=float)
|
||||
p.add_argument("--margin-floor-patience", type=int)
|
||||
p.add_argument("--arch", type=str)
|
||||
p.add_argument("--activation", type=str)
|
||||
p.add_argument("--jax", action="store_true")
|
||||
p.add_argument("--jax-num-envs", type=int)
|
||||
p.add_argument("--jax-num-steps", type=int)
|
||||
p.add_argument("--jax-num-minibatches", type=int)
|
||||
p.add_argument("--jax-update-epochs", type=int)
|
||||
p.add_argument("--jax-anneal-lr", type=str)
|
||||
p.add_argument("--checkpoint-interval", type=int)
|
||||
p.add_argument("--sweep-agent", action="store_true")
|
||||
p.add_argument("--sweep-id", type=str)
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
p.add_argument("--offline", action="store_true")
|
||||
p.add_argument("--no-wandb", action="store_true")
|
||||
args = p.parse_args()
|
||||
backend = None if args.backend == "auto" else args.backend
|
||||
|
||||
overrides = {
|
||||
"project": args.project,
|
||||
"backend": backend,
|
||||
"algo": args.algo,
|
||||
"seed": args.seed,
|
||||
"total_timesteps": args.total_timesteps,
|
||||
"model_dir": args.model_dir,
|
||||
"log_freq": args.log_freq,
|
||||
"checkpoint_interval": args.checkpoint_interval,
|
||||
"device": args.device,
|
||||
"alpha": args.alpha,
|
||||
"N": args.N,
|
||||
"n_products": args.n_products,
|
||||
@@ -528,11 +150,6 @@ def main():
|
||||
"robust_radius": args.robust_radius,
|
||||
"robust_points": args.robust_points,
|
||||
"no_robust": args.no_robust,
|
||||
"learning_rate": args.learning_rate,
|
||||
"gamma": args.gamma,
|
||||
"gae_lambda": args.gae_lambda,
|
||||
"clip_range": args.clip_range,
|
||||
"ent_coef": args.ent_coef,
|
||||
"revenue_weight": args.revenue_weight,
|
||||
"price_low": args.price_low,
|
||||
"price_high": args.price_high,
|
||||
@@ -542,40 +159,87 @@ def main():
|
||||
"max_steps": args.max_steps,
|
||||
"margin_floor": args.margin_floor,
|
||||
"margin_floor_patience": args.margin_floor_patience,
|
||||
"learning_rate": args.learning_rate,
|
||||
"gamma": args.gamma,
|
||||
"buffer_size": args.buffer_size,
|
||||
"batch_size": args.batch_size,
|
||||
"tau": args.tau,
|
||||
"train_freq": args.train_freq,
|
||||
"learning_starts": args.learning_starts,
|
||||
"target_update_interval": args.target_update_interval,
|
||||
"exploration_fraction": args.exploration_fraction,
|
||||
"exploration_final_eps": args.exploration_final_eps,
|
||||
"n_steps": args.n_steps,
|
||||
"n_epochs": args.n_epochs,
|
||||
"gae_lambda": args.gae_lambda,
|
||||
"clip_range": args.clip_range,
|
||||
"ent_coef": args.ent_coef,
|
||||
"q_lr": args.q_lr,
|
||||
"q_bins": args.q_bins,
|
||||
"eps_start": args.eps_start,
|
||||
"eps_end": args.eps_end,
|
||||
"eps_decay": args.eps_decay,
|
||||
"arch": args.arch,
|
||||
"activation": args.activation,
|
||||
"use_jax": args.jax,
|
||||
"vf_coef": args.vf_coef,
|
||||
"max_grad_norm": args.max_grad_norm,
|
||||
"eval_freq": args.eval_freq,
|
||||
"eval_episodes": args.eval_episodes,
|
||||
"use_jax": args.jax or None,
|
||||
"jax_num_envs": args.jax_num_envs,
|
||||
"jax_num_steps": args.jax_num_steps,
|
||||
"jax_num_minibatches": args.jax_num_minibatches,
|
||||
"jax_update_epochs": args.jax_update_epochs,
|
||||
"checkpoint_interval": args.checkpoint_interval,
|
||||
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
||||
if args.jax_anneal_lr is not None
|
||||
else None,
|
||||
"jax_anneal_lr": jax_anneal_lr,
|
||||
}
|
||||
overrides = {k: v for k, v in overrides.items() if v is not None}
|
||||
return {key: value for key, value in overrides.items() if value is not None}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> None:
|
||||
import sys
|
||||
|
||||
raw_args = list(sys.argv[1:] if argv is None else argv)
|
||||
run_kind = _probe_run_kind(raw_args)
|
||||
if run_kind == "benchmark":
|
||||
run_benchmark_cli(_strip_run_kind(raw_args))
|
||||
return
|
||||
|
||||
parser = _build_parser()
|
||||
args, unknown = parser.parse_known_args(raw_args)
|
||||
if unknown:
|
||||
raise ValueError(f"Unknown arguments for training mode: {' '.join(unknown)}")
|
||||
|
||||
overrides = _overrides_from_args(args)
|
||||
scenario = str(args.scenario)
|
||||
group = args.group
|
||||
extra_tags = tuple(_parse_tags(args.tags))
|
||||
|
||||
if args.sweep_agent:
|
||||
if args.no_wandb:
|
||||
raise ValueError("sweep agent requires wandb")
|
||||
if not args.sweep_id:
|
||||
raise ValueError("--sweep-id is required with --sweep-agent")
|
||||
mode = "offline" if args.offline else "online"
|
||||
wandb.agent(
|
||||
args.sweep_id,
|
||||
function=lambda: run_wandb(
|
||||
args.project, overrides, mode=mode, sweep_mode=True
|
||||
),
|
||||
count=args.count if args.count > 0 else None,
|
||||
run_sweep_agent(
|
||||
project=args.project,
|
||||
sweep_id=str(args.sweep_id or ""),
|
||||
count=int(args.count),
|
||||
offline=bool(args.offline),
|
||||
no_wandb=bool(args.no_wandb),
|
||||
base_overrides=overrides,
|
||||
kind="sweep",
|
||||
scenario=scenario,
|
||||
group=group,
|
||||
extra_tags=extra_tags,
|
||||
)
|
||||
return
|
||||
|
||||
if args.no_wandb or not HAS_WANDB:
|
||||
run_local(overrides)
|
||||
return
|
||||
|
||||
run_wandb(args.project, overrides, mode="offline" if args.offline else "online")
|
||||
spec = TrainSpec.from_flat(overrides)
|
||||
run_train_once(
|
||||
spec,
|
||||
project=args.project,
|
||||
offline=bool(args.offline),
|
||||
no_wandb=bool(args.no_wandb),
|
||||
kind="train",
|
||||
scenario=scenario,
|
||||
group=group,
|
||||
extra_tags=extra_tags,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
40
engine/train_core.py
Normal file
40
engine/train_core.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .spec import TrainSpec
|
||||
from .telemetry.metrics import canonicalize_metrics
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainResult:
|
||||
spec: TrainSpec
|
||||
metrics: dict[str, Any]
|
||||
artifacts: dict[str, str]
|
||||
|
||||
|
||||
def run_train(spec: TrainSpec) -> TrainResult:
|
||||
cfg = spec.to_flat_dict()
|
||||
algo = spec.algorithm.name
|
||||
|
||||
if spec.runtime.use_jax or spec.runtime.backend == "jax":
|
||||
from .backends.jax import train_jax_backend
|
||||
|
||||
_, raw_metrics = train_jax_backend(cfg)
|
||||
elif algo == "qtable":
|
||||
from .backends.qtable import train_qtable
|
||||
|
||||
_, raw_metrics = train_qtable(cfg)
|
||||
else:
|
||||
from .backends.sb3 import train_sb3
|
||||
|
||||
_, raw_metrics = train_sb3(cfg)
|
||||
|
||||
metrics = canonicalize_metrics(raw_metrics, spec)
|
||||
artifacts: dict[str, str] = {}
|
||||
model_path = raw_metrics.get("model/path")
|
||||
if isinstance(model_path, str):
|
||||
artifacts["model/path"] = model_path
|
||||
|
||||
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)
|
||||
@@ -381,7 +381,7 @@ if __name__ == "__main__":
|
||||
def predict(self, obs, **kwargs):
|
||||
return self.env.action_space.sample(), None
|
||||
|
||||
wandb.init(project="phantom-pricing", config={"policy": "random", "alpha": 0.3})
|
||||
wandb.init(project="capstone", config={"policy": "random", "alpha": 0.3})
|
||||
env = EconomicMetricsWrapper(PHANTOM(n_products=15, alpha=0.3, render_mode=None))
|
||||
|
||||
model = RandomPolicy(env)
|
||||
|
||||
3
nx.json
3
nx.json
@@ -55,6 +55,9 @@
|
||||
"train": {
|
||||
"cache": false
|
||||
},
|
||||
"benchmark": {
|
||||
"cache": false
|
||||
},
|
||||
"up": {
|
||||
"cache": false
|
||||
},
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
"platform:down": "nx run platform:down",
|
||||
"platform:logs": "nx run platform:logs",
|
||||
"research:test": "nx run research:test",
|
||||
"research:benchmark": "nx run research:benchmark",
|
||||
"e2e:test": "nx run e2e:test"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -30,10 +30,20 @@ case "$cmd" in
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
.venv/bin/python -m engine.train ${LOCAL_TRAIN_ARGS:---algo ppo --total-timesteps 50000}
|
||||
;;
|
||||
benchmark)
|
||||
load_sweep_env
|
||||
if [[ " ${LOCAL_BENCHMARK_ARGS:-} " != *" --no-wandb "* ]]; then
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
fi
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="${WANDB_API_KEY:-}" \
|
||||
.venv/bin/python -m engine.train --run-kind benchmark ${LOCAL_BENCHMARK_ARGS:---tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu}
|
||||
;;
|
||||
train-agent)
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
@@ -43,10 +53,23 @@ case "$cmd" in
|
||||
args+=(--count "$AGENT_COUNT")
|
||||
fi
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
.venv/bin/python -m engine.train "${args[@]}"
|
||||
;;
|
||||
benchmark-agent)
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
|
||||
args=(--sweep-agent --sweep-id "$SWEEP_ID")
|
||||
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
|
||||
args+=(--count "$AGENT_COUNT")
|
||||
fi
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
.venv/bin/python -m engine.train --run-kind benchmark "${args[@]}" ${BENCHMARK_AGENT_ARGS:-}
|
||||
;;
|
||||
train-bootstrap)
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
@@ -55,7 +78,7 @@ case "$cmd" in
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
GITHUB_TOKEN="$GITHUB_TOKEN" \
|
||||
REPO_URL="$REPO_URL" \
|
||||
BRANCH="${BRANCH:-main}" \
|
||||
@@ -115,7 +138,7 @@ PY
|
||||
train-tpu-vm-sweep)
|
||||
load_sweep_env
|
||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
args=(
|
||||
--sweep-id "$SWEEP_ID"
|
||||
|
||||
@@ -96,7 +96,11 @@ def _extract_metrics(output: str) -> dict:
|
||||
obj = json.loads(block)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
|
||||
if isinstance(obj, dict) and (
|
||||
"objective/score" in obj
|
||||
or "eval/reward_mean" in obj
|
||||
or "sweep/score" in obj
|
||||
):
|
||||
return obj
|
||||
return {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user