updating engine training for training

This commit is contained in:
2026-03-15 21:14:11 +01:00
parent 19b47aa699
commit 52b4dcdce3
13 changed files with 544 additions and 160 deletions

View File

@@ -132,15 +132,15 @@ def evaluate(
shifted_env.close()
shifted_rows.append((tag, alpha, shifted_metrics))
metrics["eval/robust_alpha_low"] = low_alpha
metrics["eval/robust_alpha_high"] = high_alpha
metrics["eval/robust_reward_worst"] = float(
metrics["eval/stress_alpha_low"] = low_alpha
metrics["eval/stress_alpha_high"] = high_alpha
metrics["eval/stress_reward_worst"] = float(
min(row[2]["eval/reward_mean"] for row in shifted_rows)
)
metrics["eval/robust_revenue_worst"] = float(
metrics["eval/stress_revenue_worst"] = float(
min(row[2]["eval/revenue_mean"] for row in shifted_rows)
)
metrics["eval/robust_coi_leakage_worst"] = float(
metrics["eval/stress_coi_leakage_worst"] = float(
max(row[2]["eval/coi_leakage_mean"] for row in shifted_rows)
)
for tag, alpha, shifted_metrics in shifted_rows:

View File

@@ -80,7 +80,11 @@ def train_qtable(
"train/global_step": int(steps),
}
if wandb_live:
try:
wandb.log(dict(event), step=step_offset + int(steps))
except Exception:
wandb_live = False
train_events.append(event)
else:
train_events.append(event)
if console_progress:
@@ -113,7 +117,11 @@ def train_qtable(
"train/global_step": int(steps),
}
if wandb_live:
try:
wandb.log(dict(tail_event), step=step_offset + int(steps))
except Exception:
wandb_live = False
train_events.append(tail_event)
else:
train_events.append(tail_event)

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Mapping
from ..lib.callbacks import MetricsCallback
from ..lib.callbacks import EvalMetricsCallback, MetricsCallback
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
from .common import evaluate, make_env
@@ -117,7 +119,6 @@ def build_model(cfg: Mapping[str, Any], env: Any):
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
try:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 models") from exc
@@ -144,20 +145,20 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
pass
metrics_callback = MetricsCallback(
log_histograms=False,
log_histograms=True,
log_freq=int(cfg["log_freq"]),
hist_freq=int(cfg.get("hist_freq", 500)),
step_offset=int(cfg.get("wandb_step_offset", 0)),
)
callbacks = [metrics_callback]
callbacks.append(
EvalCallback(
eval_callback = EvalMetricsCallback(
eval_env,
eval_freq=int(cfg["eval_freq"]),
n_eval_episodes=int(cfg["eval_episodes"]),
step_offset=int(cfg.get("wandb_step_offset", 0)),
deterministic=True,
verbose=0,
)
)
callbacks = [metrics_callback, eval_callback]
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
@@ -173,6 +174,29 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path))
artifact_name = checkpoint_artifact_name(
cfg,
backend="sb3",
sweep_id=os.getenv("WANDB_SWEEP_ID"),
)
artifact_logged = False
try:
artifact_logged = bool(
log_checkpoint_file(
artifact_name,
file_path=model_path.with_suffix(".zip"),
artifact_file_name="model.zip",
metadata={
"algo": str(cfg.get("algo", "ppo")),
"backend": "sb3",
"seed": int(cfg.get("seed", 0)),
"step": int(getattr(model, "num_timesteps", 0)),
},
)
)
except Exception:
artifact_logged = False
metrics: dict[str, Any] = evaluate(
model,
eval_env,
@@ -181,7 +205,12 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
)
metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip"))
metrics["_train_events"] = list(metrics_callback.events)
metrics["model/artifact_name"] = str(artifact_name)
metrics["model/artifact_logged"] = float(artifact_logged)
metrics["_train_events"] = sorted(
[*metrics_callback.events, *eval_callback.events],
key=lambda event: int(event.get("train/global_step", 0)),
)
env.close()
eval_env.close()

View File

@@ -45,6 +45,10 @@ def _log(message: str) -> None:
logger.info(message)
def _wandb_run_active() -> bool:
return bool(HAS_WANDB and getattr(wandb, "run", None) is not None)
def _parse_list(raw: str) -> list[str]:
return [x.strip().lower() for x in str(raw).split(",") if x.strip()]
@@ -61,6 +65,10 @@ def _truthy(value: str | bool | None) -> bool:
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _mode_label_from_baseline(is_baseline: bool) -> str:
return "baseline" if bool(is_baseline) else "defended"
def _action(policy, obs: np.ndarray):
out = policy.predict(obs, deterministic=True)
action = out[0] if isinstance(out, tuple) else out
@@ -166,7 +174,7 @@ def _log_train_events(
alpha: float,
step_offset: int,
) -> int:
if not (HAS_WANDB and wandb.run is not None):
if not _wandb_run_active():
return int(step_offset)
if not events:
return int(step_offset)
@@ -187,11 +195,14 @@ def _log_train_events(
"run.kind": "benchmark",
"runtime/backend": tier_name,
"study/mode": mode_label,
"study/no_robust": float(mode_label == "no_robust"),
"study/baseline_mode": float(mode_label == "baseline"),
"study/alpha": float(alpha),
}
)
try:
wandb.log(payload, step=cursor + rel_step)
except Exception:
return int(step_offset)
max_rel = max(max(1, int(evt.get("train/global_step", 0))) for evt in ordered)
return cursor + max_rel + 1
@@ -203,6 +214,7 @@ def run_benchmark(
n_episodes: int,
mode_label: str,
step_cursor_start: int = 0,
eval_alpha_values: list[float] | None = None,
):
from .backends.common import make_env
@@ -239,14 +251,22 @@ def run_benchmark(
"dqn",
}:
wandb_step_cursor += max(1, int(cfg.get("total_timesteps", 1))) + 1
env = make_env({**cfg, "alpha": float(alpha)})
eval_targets = (
[float(value) for value in eval_alpha_values]
if eval_alpha_values
else [float(alpha)]
)
for eval_alpha in eval_targets:
env = make_env({**cfg, "alpha": float(eval_alpha)})
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))]
env.close()
row = {
"tier": tier_name,
"mode": mode_label,
"alpha": float(alpha),
"alpha": float(eval_alpha),
"train_alpha": float(alpha),
"eval_alpha": float(eval_alpha),
"episodes": int(n_episodes),
"mean_reward": float(np.mean([e["reward"] for e in eps])),
"mean_revenue": float(np.mean([e["revenue"] for e in eps])),
@@ -257,7 +277,8 @@ def run_benchmark(
row["objective_score"] = row["mean_reward"]
rows.append(row)
_log(
f"[{run_index}/{total_runs}] alpha={float(alpha):.2f} tier={tier_name}: "
f"[{run_index}/{total_runs}] train_alpha={float(alpha):.2f} "
f"eval_alpha={float(eval_alpha):.2f} tier={tier_name}: "
f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} "
f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}"
)
@@ -266,25 +287,32 @@ def run_benchmark(
step_means = []
for step in range(max_len):
vals = [
e["price_trace"][step] for e in eps if step < len(e["price_trace"])
e["price_trace"][step]
for e in eps
if step < len(e["price_trace"])
]
step_means.append(float(np.mean(vals)) if vals else np.nan)
traces.append(
{
"tier": tier_name,
"alpha": float(alpha),
"alpha": float(eval_alpha),
"train_alpha": float(alpha),
"eval_alpha": float(eval_alpha),
"mean_price_trace": step_means,
}
)
if HAS_WANDB and wandb.run is not None:
if _wandb_run_active():
try:
wandb.log(
{
"run.kind": "benchmark",
"runtime/backend": tier_name,
"study/mode": mode_label,
"study/no_robust": float(mode_label == "no_robust"),
"study/alpha": float(alpha),
"study/baseline_mode": float(mode_label == "baseline"),
"study/alpha": float(eval_alpha),
"study/train_alpha": float(alpha),
"study/eval_alpha": float(eval_alpha),
"eval/reward_mean": row["mean_reward"],
"eval/revenue_mean": row["mean_revenue"],
"eval/margin_mean": row["mean_margin"],
@@ -294,6 +322,8 @@ def run_benchmark(
},
step=wandb_step_cursor,
)
except Exception:
pass
wandb_step_cursor += 1
return pd.DataFrame(rows), traces, int(wandb_step_cursor)
@@ -378,7 +408,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
if compare_robust_override is not None
else _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
)
robust_modes = [False, True] if compare_robust else [bool(args.no_robust)]
baseline_modes = [False, True] if compare_robust else [bool(args.no_robust)]
base_overrides = {
"seed": args.seed,
@@ -389,6 +419,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
"robust_radius": args.robust_radius,
"robust_points": args.robust_points,
"robust_rollouts": args.robust_rollouts,
"margin_floor": args.margin_floor,
"eta_ux": args.eta_ux,
"reward_profit_weight": args.reward_profit_weight,
"price_low": args.price_low,
@@ -405,12 +436,20 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
}
tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values)
eval_alpha_values = (
_parse_float_list(args.eval_alpha_values)
if str(getattr(args, "eval_alpha_values", "")).strip()
else []
)
_log(
"starting run "
+ json.dumps(
{
"tiers": tiers,
"alpha_values": alpha_values,
"eval_alpha_values": (
eval_alpha_values if eval_alpha_values else alpha_values
),
"episodes": int(args.episodes),
"total_timesteps": int(args.total_timesteps),
"device": str(args.device),
@@ -421,14 +460,14 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
all_frames: list[pd.DataFrame] = []
all_traces: list[dict] = []
wandb_step_cursor = 0
for no_robust in robust_modes:
for baseline_mode in baseline_modes:
overrides = dict(base_overrides)
overrides["no_robust"] = bool(no_robust)
overrides["baseline_mode"] = bool(baseline_mode)
cfg = TrainSpec.from_flat(
{k: v for k, v in overrides.items() if v is not None}
).to_flat_dict()
cfg["linear_warmup_steps"] = int(args.linear_warmup_steps)
mode_label = "no_robust" if no_robust else "robust"
mode_label = _mode_label_from_baseline(bool(baseline_mode))
_log(f"mode={mode_label}: begin")
df_mode, traces_mode, wandb_step_cursor = run_benchmark(
cfg,
@@ -437,6 +476,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
args.episodes,
mode_label=mode_label,
step_cursor_start=wandb_step_cursor,
eval_alpha_values=eval_alpha_values,
)
_log(f"mode={mode_label}: complete ({len(df_mode)} rows)")
for trace in traces_mode:
@@ -465,7 +505,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
+ json.dumps(
{
"tier": best["tier"],
"mode": best.get("mode", "robust"),
"mode": best.get("mode", "defended"),
"alpha": float(best["alpha"]),
"objective_score": float(best["objective_score"]),
"mean_revenue": float(best["mean_revenue"]),
@@ -486,6 +526,7 @@ def run_cli(raw_args: list[str] | None = None):
parser.add_argument("--project", default="capstone")
parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo")
parser.add_argument("--alpha-values", default="0.0,0.3,0.6")
parser.add_argument("--eval-alpha-values", default="")
parser.add_argument("--episodes", type=int, default=10)
parser.add_argument("--output-dir", default="engine/studies/results")
parser.add_argument("--seed", type=int, default=42)
@@ -496,6 +537,7 @@ def run_cli(raw_args: list[str] | None = None):
parser.add_argument("--robust-radius", type=float, default=0.15)
parser.add_argument("--robust-points", type=int, default=5)
parser.add_argument("--robust-rollouts", type=int, default=1)
parser.add_argument("--margin-floor", type=float, default=0.85)
parser.add_argument("--eta-ux", type=float, default=0.5)
parser.add_argument("--reward-profit-weight", type=float, default=1.0)
parser.add_argument("--price-low", type=float, default=10.0)
@@ -529,35 +571,47 @@ def run_cli(raw_args: list[str] | None = None):
key_to_attr = {
"tiers": "tiers",
"alpha_values": "alpha_values",
"eval_alpha_values": "eval_alpha_values",
"episodes": "episodes",
"total_timesteps": "total_timesteps",
"lambda_coi": "lambda_coi",
"robust_radius": "robust_radius",
"robust_points": "robust_points",
"robust_rollouts": "robust_rollouts",
"ambiguity_radius": "robust_radius",
"ambiguity_points": "robust_points",
"ambiguity_rollouts": "robust_rollouts",
"eta_ux": "eta_ux",
"reward_profit_weight": "reward_profit_weight",
"learning_rate": "learning_rate",
"batch_size": "batch_size",
"n_steps": "n_steps",
"baseline_mode": "no_robust",
"no_robust": "no_robust",
"margin_floor": "margin_floor",
"device": "device",
}
for key in (
"tiers",
"alpha_values",
"eval_alpha_values",
"episodes",
"total_timesteps",
"lambda_coi",
"robust_radius",
"robust_points",
"robust_rollouts",
"ambiguity_radius",
"ambiguity_points",
"ambiguity_rollouts",
"eta_ux",
"reward_profit_weight",
"learning_rate",
"batch_size",
"n_steps",
"baseline_mode",
"no_robust",
"margin_floor",
"device",
):
if key in wandb.config:
@@ -582,16 +636,16 @@ def run_cli(raw_args: list[str] | None = None):
alpha_values = _parse_float_list(args.alpha_values)
run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S")
compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
compare_tag = "robust-compare" if compare_enabled else "single-mode"
compare_tag = "defended-compare" if compare_enabled else "single-mode"
modes = (
[("no_robust", True), ("robust", False)]
[("baseline", True), ("defended", False)]
if compare_enabled
else [("no_robust" if bool(args.no_robust) else "robust", bool(args.no_robust))]
else [(_mode_label_from_baseline(bool(args.no_robust)), bool(args.no_robust))]
)
run_idx = 0
for tier in tiers:
for mode_label, no_robust in modes:
for mode_label, baseline_mode in modes:
for alpha in alpha_values:
run_idx += 1
alpha_token = (
@@ -600,7 +654,7 @@ def run_cli(raw_args: list[str] | None = None):
tier_args = argparse.Namespace(**vars(args))
tier_args.tiers = tier
tier_args.alpha_values = str(float(alpha))
tier_args.no_robust = bool(no_robust)
tier_args.no_robust = bool(baseline_mode)
run = wandb.init(
project=args.project,
name=(
@@ -617,16 +671,19 @@ def run_cli(raw_args: list[str] | None = None):
"run.kind": "benchmark",
"runtime/backend": tier,
"study/mode": mode_label,
"study/no_robust": float(no_robust),
"study/baseline_mode": float(baseline_mode),
"study/alpha": float(alpha),
"tiers": tier,
"alpha_values": str(float(alpha)),
"eval_alpha_values": args.eval_alpha_values,
"episodes": args.episodes,
"total_timesteps": args.total_timesteps,
"lambda_coi": args.lambda_coi,
"robust_radius": args.robust_radius,
"robust_points": args.robust_points,
"robust_rollouts": args.robust_rollouts,
"ambiguity_radius": args.robust_radius,
"ambiguity_points": args.robust_points,
"ambiguity_rollouts": args.robust_rollouts,
"margin_floor": args.margin_floor,
"baseline_mode": float(baseline_mode),
"eta_ux": args.eta_ux,
"reward_profit_weight": args.reward_profit_weight,
"learning_rate": args.learning_rate,

View File

@@ -15,15 +15,19 @@ class MetricsCallback(BaseCallback):
self,
log_histograms: bool = False,
log_freq: int = 100,
hist_freq: int = 500,
step_offset: int = 0,
verbose: int = 0,
):
super().__init__(verbose)
self.log_histograms = log_histograms
self.log_freq = max(1, int(log_freq))
self.hist_freq = max(1, int(hist_freq))
self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._price_samples: list[float] = []
self._demand_samples: list[float] = []
self._window_sums = {
"train/revenue_mean": 0.0,
"train/margin_mean": 0.0,
@@ -74,9 +78,67 @@ class MetricsCallback(BaseCallback):
)
self._window_count += 1
def _flush(self, step: int) -> None:
if self._window_count <= 0:
def _accumulate_histograms(self, info: dict[str, Any]) -> None:
if not self.log_histograms:
return
for key in ("effective_prices", "prices"):
if key not in info:
continue
try:
values = np.asarray(info.get(key), dtype=float).reshape(-1)
except Exception:
continue
if values.size <= 0:
continue
finite_values = values[np.isfinite(values)]
if finite_values.size > 0:
self._price_samples.extend(finite_values.tolist())
break
if "demand" in info:
try:
demand_values = np.asarray(info.get("demand"), dtype=float).reshape(-1)
except Exception:
demand_values = np.array([], dtype=float)
if demand_values.size > 0:
finite_demand = demand_values[np.isfinite(demand_values)]
if finite_demand.size > 0:
self._demand_samples.extend(finite_demand.tolist())
def _flush_histograms(self, step: int, force: bool = False) -> None:
if not self.log_histograms:
return
if not force and step % self.hist_freq != 0:
return
if not self._price_samples and not self._demand_samples:
return
if self._wandb is None:
self._price_samples.clear()
self._demand_samples.clear()
return
payload: dict[str, Any] = {}
if self._price_samples:
payload["train/price_dist"] = self._wandb.Histogram(
np.asarray(self._price_samples, dtype=np.float32)
)
if self._demand_samples:
payload["train/demand_dist"] = self._wandb.Histogram(
np.asarray(self._demand_samples, dtype=np.float32)
)
if payload and self._wandb_live:
try:
self._wandb.log(payload, step=self.step_offset + int(step))
except Exception:
self._wandb_live = False
self._price_samples.clear()
self._demand_samples.clear()
def _flush(self, step: int, *, force_hist: bool = False) -> None:
if self._window_count > 0:
denom = float(self._window_count)
payload = {
key: (value / denom)
@@ -92,17 +154,24 @@ class MetricsCallback(BaseCallback):
}
payload["train/global_step"] = int(step)
if self._wandb_live:
try:
self._wandb.log(dict(payload), step=self.step_offset + int(step))
except Exception:
self._wandb_live = False
self.events.append(payload)
else:
self.events.append(payload)
for key in self._window_sums:
self._window_sums[key] = 0.0
self._window_count = 0
self._flush_histograms(step=step, force=force_hist)
def _on_step(self) -> bool:
for info in self.locals.get("infos", []):
if isinstance(info, dict):
self._accumulate(info)
self._accumulate_histograms(info)
if self.num_timesteps % self.log_freq == 0:
self._flush(step=self.num_timesteps)
@@ -110,39 +179,81 @@ class MetricsCallback(BaseCallback):
return True
def _on_training_end(self) -> None:
self._flush(step=self.num_timesteps)
self._flush(step=self.num_timesteps, force_hist=True)
class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation collector detached from logging backends."""
def __init__(
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
self,
eval_env,
eval_freq: int = 1000,
n_eval_episodes: int = 5,
step_offset: int = 0,
**kwargs,
):
super().__init__(
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
)
self._eval_revenues: list[float] = []
self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._eval_stats: dict[str, list[float]] = {
"eval/revenue_mean": [],
"eval/margin_mean": [],
"eval/coi_level_mean": [],
"eval/coi_leakage_mean": [],
"eval/volatility_mean": [],
"eval/agent_prob_mean": [],
}
self.events: list[dict[str, float | int]] = []
def _on_step(self) -> bool:
result = super()._on_step()
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
self.events.append(
{
payload: dict[str, float | int] = {
"eval/reward_mean": float(self.last_mean_reward),
"eval/revenue_mean": float(np.mean(self._eval_revenues))
if self._eval_revenues
else 0.0,
"train/global_step": int(self.num_timesteps),
}
for key, values in self._eval_stats.items():
payload[key] = float(np.mean(values)) if values else 0.0
if self._wandb_live:
try:
self._wandb.log(
dict(payload),
step=self.step_offset + int(self.num_timesteps),
)
self._eval_revenues = []
except Exception:
self._wandb_live = False
self.events.append(payload)
else:
self.events.append(payload)
for values in self._eval_stats.values():
values.clear()
return result
def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
# called after each eval episode
info = locals_.get("info", {})
if "economics" in info:
self._eval_revenues.append(info["economics"]["revenue"])
econ = info.get("economics") if isinstance(info, dict) else None
if not isinstance(econ, dict):
return
self._eval_stats["eval/revenue_mean"].append(float(econ.get("revenue", 0.0)))
self._eval_stats["eval/margin_mean"].append(float(econ.get("margin", 0.0)))
self._eval_stats["eval/coi_level_mean"].append(
float(econ.get("coi_level", 0.0))
)
self._eval_stats["eval/coi_leakage_mean"].append(
float(econ.get("coi_leakage", 0.0))
)
self._eval_stats["eval/volatility_mean"].append(
float(econ.get("volatility", 0.0))
)
self._eval_stats["eval/agent_prob_mean"].append(
float(econ.get("agent_prob", 0.0))
)

View File

@@ -156,6 +156,7 @@ class ProviderBenchmark:
# log to wandb if available
if HAS_WANDB and wandb.run is not None:
try:
wandb.log(
{
f"benchmark/{name}/revenue": result.mean_revenue,
@@ -164,6 +165,8 @@ class ProviderBenchmark:
"benchmark/alpha": alpha,
}
)
except Exception:
pass
return self.results

View File

@@ -9,6 +9,7 @@ from ..telemetry.wandb import (
get_wandb_module,
init_run,
run_agent,
update_summary,
)
from .train import run_with_active_sweep_run
@@ -43,6 +44,7 @@ def run_sweep_agent(
spec = TrainSpec.from_flat(merged)
if run is not None:
run.name = run_name(spec, kind=kind, scenario=scenario)
try:
run_with_active_sweep_run(
spec,
kind=kind,
@@ -50,6 +52,15 @@ def run_sweep_agent(
group=group,
extra_tags=extra_tags,
)
update_summary({"run/status": "finished"})
except Exception as exc:
update_summary(
{
"run/status": "crashed",
"run/error": str(exc),
}
)
raise
finally:
finish_run()

View File

@@ -20,7 +20,7 @@ def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list
kind,
spec.algorithm.name,
spec.runtime.backend,
"vanilla" if spec.study.no_robust else "robust",
"baseline" if spec.study.no_robust else "defended",
]
tags.extend([tag for tag in extra_tags if tag])
return tags

View File

@@ -91,6 +91,44 @@
"command": "bash scripts/nx_research.sh docker-train-publish",
"cwd": "."
}
},
"whoclicked-publish": {
"executor": "nx:run-commands",
"dependsOn": [
"install"
],
"options": {
"command": "bash scripts/nx_research.sh whoclicked-publish",
"cwd": "."
}
},
"tpu-ray-bootstrap": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-bootstrap",
"cwd": "."
}
},
"tpu-ray-deps": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-deps",
"cwd": "."
}
},
"tpu-ray-verify": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-verify",
"cwd": "."
}
},
"tpu-ray-teardown": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-teardown",
"cwd": "."
}
}
},
"tags": [

View File

@@ -32,10 +32,17 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
"study.robust_radius": "robust_radius",
"study.robust_points": "robust_points",
"study.robust_rollouts": "robust_rollouts",
"study.ambiguity_radius": "robust_radius",
"study.ambiguity_points": "robust_points",
"study.ambiguity_rollouts": "robust_rollouts",
"study.info_value": "info_value",
"study.eta_ux": "eta_ux",
"study.reward_profit_weight": "reward_profit_weight",
"study.revenue_weight": "revenue_weight",
"ambiguity_radius": "robust_radius",
"ambiguity_points": "robust_points",
"ambiguity_rollouts": "robust_rollouts",
"baseline_mode": "no_robust",
"stress_eval_enabled": "robust_eval_enabled",
"optimizer.learning_rate": "learning_rate",
"optimizer.gamma": "gamma",
"optimizer.batch_size": "batch_size",
@@ -45,6 +52,7 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
"runtime.seed": "seed",
"runtime.total_timesteps": "total_timesteps",
"runtime.checkpoint_interval": "checkpoint_interval",
"runtime.hist_freq": "hist_freq",
"eval.eval_freq": "eval_freq",
"eval.eval_episodes": "eval_episodes",
}
@@ -86,7 +94,6 @@ class StudySpec:
info_value: float = 1.0
eta_ux: float = 0.5
reward_profit_weight: float = 1.0
revenue_weight: float = 0.01
no_robust: bool = False
@@ -128,6 +135,7 @@ class RuntimeSpec:
checkpoint_interval: int = 200_000
model_dir: str = "engine/models"
log_freq: int = 100
hist_freq: int = 500
@dataclass(frozen=True)
@@ -159,6 +167,7 @@ class TrainSpec:
"backend": self.runtime.backend,
"device": self.runtime.device,
"checkpoint_interval": self.runtime.checkpoint_interval,
"hist_freq": self.runtime.hist_freq,
"n_products": self.env.n_products,
"N": self.env.n_sessions,
"price_low": self.env.price_low,
@@ -179,7 +188,6 @@ class TrainSpec:
"info_value": self.study.info_value,
"eta_ux": self.study.eta_ux,
"reward_profit_weight": self.study.reward_profit_weight,
"revenue_weight": self.study.revenue_weight,
"no_robust": self.study.no_robust,
"learning_rate": self.optimizer.learning_rate,
"gamma": self.optimizer.gamma,
@@ -262,7 +270,6 @@ class TrainSpec:
info_value=float(base["info_value"]),
eta_ux=float(base["eta_ux"]),
reward_profit_weight=float(base["reward_profit_weight"]),
revenue_weight=float(base["revenue_weight"]),
no_robust=no_robust,
),
optimizer=OptimizerSpec(
@@ -300,6 +307,7 @@ class TrainSpec:
checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]),
hist_freq=int(base["hist_freq"]),
),
eval=EvalSpec(
eval_freq=int(base["eval_freq"]),
@@ -310,9 +318,11 @@ class TrainSpec:
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
alpha_token = f"{float(spec.study.alpha):.2f}".rstrip("0").rstrip(".")
mode = "baseline" if bool(spec.study.no_robust) else "defended"
return (
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
f"{spec.runtime.device}/{scenario}/a{alpha_token}/{mode}/s{spec.runtime.seed}"
)
@@ -324,6 +334,7 @@ def run_metadata(
group: str | None = None,
tags: Sequence[str] = (),
) -> dict[str, Any]:
mode = "baseline" if bool(spec.study.no_robust) else "defended"
metadata: dict[str, Any] = {
"run.kind": str(kind),
"run.algo": spec.algorithm.name,
@@ -332,6 +343,10 @@ def run_metadata(
"run.scenario": str(scenario),
"run.seed": spec.runtime.seed,
"run.tags": list(tags),
"study/alpha": float(spec.study.alpha),
"study/mode": mode,
"study/baseline_mode": float(bool(spec.study.no_robust)),
"tiers": spec.algorithm.name,
}
if group:
metadata["run.group"] = group

View File

@@ -36,7 +36,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
eval_reward = (
_as_float(
metrics.get("eval/robust_reward_worst", metrics.get("eval/reward_mean")),
metrics.get(
"eval/stress_reward_worst",
metrics.get(
"eval/robust_reward_worst", metrics.get("eval/reward_mean")
),
),
0.0,
)
or 0.0
@@ -51,9 +56,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
metrics["study/alpha"] = spec.study.alpha
metrics["study/mode"] = "baseline" if bool(spec.study.no_robust) else "defended"
metrics["study/baseline_mode"] = float(bool(spec.study.no_robust))
metrics["study/lambda_coi"] = spec.study.lambda_coi
metrics["study/robust_radius"] = spec.study.robust_radius
metrics["study/ambiguity_radius"] = spec.study.robust_radius
metrics["study/info_value"] = spec.study.info_value
metrics["tiers"] = spec.algorithm.name
metrics["runtime/backend"] = spec.runtime.backend
metrics["runtime/device"] = spec.runtime.device

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import os
import time
from typing import Any, Callable, Iterable, Mapping
@@ -19,6 +21,42 @@ def _require_wandb():
return wandb
def _warn(message: str) -> None:
print(f"PHANTOM_WANDB_WARNING: {message}")
def _sanitize_key(raw_key: str) -> str | None:
key = str(raw_key)
replacements = {
"no_robust": "baseline_mode",
"study/no_robust": "study/baseline_mode",
"study/robust_radius": "study/ambiguity_radius",
"robust_radius": "ambiguity_radius",
"robust_points": "ambiguity_points",
"robust_rollouts": "ambiguity_rollouts",
"robust_eval_enabled": "stress_eval_enabled",
"eval/robust_alpha_high": "eval/stress_alpha_high",
"eval/robust_alpha_low": "eval/stress_alpha_low",
"eval/robust_reward_worst": "eval/stress_reward_worst",
"eval/robust_revenue_worst": "eval/stress_revenue_worst",
"eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst",
}
key = replacements.get(key, key)
if "robust" in key.lower():
return None
return key
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
sanitized: dict[str, Any] = {}
for key, value in payload.items():
clean_key = _sanitize_key(str(key))
if clean_key is None:
continue
sanitized[clean_key] = value
return sanitized
def init_run(
*,
mode: str,
@@ -34,7 +72,11 @@ def init_run(
if group:
kwargs["group"] = group
if sweep_mode:
try:
run = wandb.init(**kwargs)
except Exception as exc:
_warn(f"init failed in sweep mode ({exc})")
return None
if name and run is not None:
run.name = name
return run
@@ -42,18 +84,25 @@ def init_run(
init_kwargs = dict(kwargs)
init_kwargs["project"] = project
if config is not None:
init_kwargs["config"] = dict(config)
init_kwargs["config"] = _sanitize_payload(dict(config))
if name:
init_kwargs["name"] = name
if tags:
init_kwargs["tags"] = list(tags)
try:
return wandb.init(**init_kwargs)
except Exception as exc:
_warn(f"init failed ({exc})")
return None
def finish_run() -> None:
wandb = get_wandb_module()
if wandb is not None and wandb.run is not None:
try:
wandb.finish()
except Exception as exc:
_warn(f"finish failed ({exc})")
def current_config() -> dict[str, Any]:
@@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
payload = _sanitize_payload(dict(config))
if not payload:
return
try:
wandb.config.update(dict(config), allow_val_change=True)
wandb.config.update(payload, allow_val_change=True)
except TypeError:
wandb.config.update(dict(config))
try:
wandb.config.update(payload)
except Exception as exc:
_warn(f"config update failed ({exc})")
except Exception as exc:
_warn(f"config update failed ({exc})")
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
wandb.log(dict(metrics), step=step)
payload = _sanitize_payload(dict(metrics))
if not payload:
return
try:
wandb.log(payload, step=step)
except Exception as exc:
_warn(f"log failed at step {step} ({exc})")
def update_summary(metrics: Mapping[str, Any]) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
for key, value in metrics.items():
payload = _sanitize_payload(dict(metrics))
if not payload:
return
try:
for key, value in payload.items():
wandb.run.summary[key] = value
except Exception as exc:
_warn(f"summary update failed ({exc})")
def run_agent(
@@ -95,4 +164,39 @@ def run_agent(
count: int | None = None,
) -> None:
wandb = _require_wandb()
wandb.agent(sweep_id, function=fn, count=count)
retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8")))
retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5")))
retry_backoff = max(
1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5"))
)
retry_max_delay = max(
retry_delay,
float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")),
)
target = None if count is None else max(0, int(count))
completed = 0
def _wrapped() -> None:
nonlocal completed
fn()
completed += 1
attempt = 0
while True:
remaining = None if target is None else max(0, int(target - completed))
if target is not None and remaining == 0:
return
try:
wandb.agent(sweep_id, function=_wrapped, count=remaining)
return
except Exception as exc:
attempt += 1
if attempt > retry_max:
raise
wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1)))
_warn(
f"agent disconnected (attempt {attempt}/{retry_max}, "
f"completed={completed}, remaining={remaining}): {exc}"
)
time.sleep(wait)

View File

@@ -54,6 +54,7 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--total-timesteps", type=int)
parser.add_argument("--model-dir", type=str)
parser.add_argument("--log-freq", type=int)
parser.add_argument("--hist-freq", type=int)
parser.add_argument("--checkpoint-interval", type=int)
parser.add_argument("--device", type=str)
@@ -68,7 +69,6 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--no-robust", action="store_true")
parser.add_argument("--eta-ux", type=float)
parser.add_argument("--reward-profit-weight", type=float)
parser.add_argument("--revenue-weight", type=float)
parser.add_argument("--price-low", type=float)
parser.add_argument("--price-high", type=float)
@@ -126,6 +126,7 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
"total_timesteps": args.total_timesteps,
"model_dir": args.model_dir,
"log_freq": args.log_freq,
"hist_freq": args.hist_freq,
"checkpoint_interval": args.checkpoint_interval,
"device": args.device,
"alpha": args.alpha,
@@ -139,7 +140,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
"no_robust": args.no_robust,
"eta_ux": args.eta_ux,
"reward_profit_weight": args.reward_profit_weight,
"revenue_weight": args.revenue_weight,
"price_low": args.price_low,
"price_high": args.price_high,
"action_levels": args.action_levels,