From 52b4dcdce3cca5099ddcbd6458af46b50185f935 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Sun, 15 Mar 2026 21:14:11 +0100 Subject: [PATCH] updating engine training for training --- engine/backends/common.py | 10 +- engine/backends/qtable.py | 12 +- engine/backends/sb3.py | 55 ++++++-- engine/benchmark.py | 193 ++++++++++++++++++---------- engine/lib/callbacks.py | 185 ++++++++++++++++++++------ engine/lib/providers.py | 19 +-- engine/orchestrators/sweep_agent.py | 25 +++- engine/orchestrators/train.py | 2 +- engine/project.json | 38 ++++++ engine/spec.py | 25 +++- engine/telemetry/metrics.py | 12 +- engine/telemetry/wandb.py | 124 ++++++++++++++++-- engine/train.py | 4 +- 13 files changed, 544 insertions(+), 160 deletions(-) diff --git a/engine/backends/common.py b/engine/backends/common.py index ca508f7..45f03e7 100644 --- a/engine/backends/common.py +++ b/engine/backends/common.py @@ -132,15 +132,15 @@ def evaluate( shifted_env.close() shifted_rows.append((tag, alpha, shifted_metrics)) - metrics["eval/robust_alpha_low"] = low_alpha - metrics["eval/robust_alpha_high"] = high_alpha - metrics["eval/robust_reward_worst"] = float( + metrics["eval/stress_alpha_low"] = low_alpha + metrics["eval/stress_alpha_high"] = high_alpha + metrics["eval/stress_reward_worst"] = float( min(row[2]["eval/reward_mean"] for row in shifted_rows) ) - metrics["eval/robust_revenue_worst"] = float( + metrics["eval/stress_revenue_worst"] = float( min(row[2]["eval/revenue_mean"] for row in shifted_rows) ) - metrics["eval/robust_coi_leakage_worst"] = float( + metrics["eval/stress_coi_leakage_worst"] = float( max(row[2]["eval/coi_leakage_mean"] for row in shifted_rows) ) for tag, alpha, shifted_metrics in shifted_rows: diff --git a/engine/backends/qtable.py b/engine/backends/qtable.py index b314fdb..cfb79d1 100644 --- a/engine/backends/qtable.py +++ b/engine/backends/qtable.py @@ -80,7 +80,11 @@ def train_qtable( "train/global_step": int(steps), } if wandb_live: - wandb.log(dict(event), step=step_offset + int(steps)) + try: + wandb.log(dict(event), step=step_offset + int(steps)) + except Exception: + wandb_live = False + train_events.append(event) else: train_events.append(event) if console_progress: @@ -113,7 +117,11 @@ def train_qtable( "train/global_step": int(steps), } if wandb_live: - wandb.log(dict(tail_event), step=step_offset + int(steps)) + try: + wandb.log(dict(tail_event), step=step_offset + int(steps)) + except Exception: + wandb_live = False + train_events.append(tail_event) else: train_events.append(tail_event) diff --git a/engine/backends/sb3.py b/engine/backends/sb3.py index 37f23c5..7a62d81 100644 --- a/engine/backends/sb3.py +++ b/engine/backends/sb3.py @@ -1,10 +1,12 @@ from __future__ import annotations import json +import os from pathlib import Path from typing import Any, Mapping -from ..lib.callbacks import MetricsCallback +from ..lib.callbacks import EvalMetricsCallback, MetricsCallback +from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file from .common import evaluate, make_env @@ -117,7 +119,6 @@ def build_model(cfg: Mapping[str, Any], env: Any): def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]: try: - from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.monitor import Monitor except ImportError as exc: raise ImportError("stable-baselines3 is required for SB3 models") from exc @@ -144,20 +145,20 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]: pass metrics_callback = MetricsCallback( - log_histograms=False, + log_histograms=True, log_freq=int(cfg["log_freq"]), + hist_freq=int(cfg.get("hist_freq", 500)), step_offset=int(cfg.get("wandb_step_offset", 0)), ) - callbacks = [metrics_callback] - callbacks.append( - EvalCallback( - eval_env, - eval_freq=int(cfg["eval_freq"]), - n_eval_episodes=int(cfg["eval_episodes"]), - deterministic=True, - verbose=0, - ) + eval_callback = EvalMetricsCallback( + eval_env, + eval_freq=int(cfg["eval_freq"]), + n_eval_episodes=int(cfg["eval_episodes"]), + step_offset=int(cfg.get("wandb_step_offset", 0)), + deterministic=True, + verbose=0, ) + callbacks = [metrics_callback, eval_callback] target_steps = int(cfg["total_timesteps"]) remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0))) @@ -173,6 +174,29 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]: model_path = model_dir / f"phantom_{cfg['algo']}" model.save(str(model_path)) + artifact_name = checkpoint_artifact_name( + cfg, + backend="sb3", + sweep_id=os.getenv("WANDB_SWEEP_ID"), + ) + artifact_logged = False + try: + artifact_logged = bool( + log_checkpoint_file( + artifact_name, + file_path=model_path.with_suffix(".zip"), + artifact_file_name="model.zip", + metadata={ + "algo": str(cfg.get("algo", "ppo")), + "backend": "sb3", + "seed": int(cfg.get("seed", 0)), + "step": int(getattr(model, "num_timesteps", 0)), + }, + ) + ) + except Exception: + artifact_logged = False + metrics: dict[str, Any] = evaluate( model, eval_env, @@ -181,7 +205,12 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]: ) metrics["train/global_step"] = int(model.num_timesteps) metrics["model/path"] = str(model_path.with_suffix(".zip")) - metrics["_train_events"] = list(metrics_callback.events) + metrics["model/artifact_name"] = str(artifact_name) + metrics["model/artifact_logged"] = float(artifact_logged) + metrics["_train_events"] = sorted( + [*metrics_callback.events, *eval_callback.events], + key=lambda event: int(event.get("train/global_step", 0)), + ) env.close() eval_env.close() diff --git a/engine/benchmark.py b/engine/benchmark.py index fc0205f..1cc6acc 100644 --- a/engine/benchmark.py +++ b/engine/benchmark.py @@ -45,6 +45,10 @@ def _log(message: str) -> None: logger.info(message) +def _wandb_run_active() -> bool: + return bool(HAS_WANDB and getattr(wandb, "run", None) is not None) + + def _parse_list(raw: str) -> list[str]: return [x.strip().lower() for x in str(raw).split(",") if x.strip()] @@ -61,6 +65,10 @@ def _truthy(value: str | bool | None) -> bool: return str(value).strip().lower() in {"1", "true", "yes", "on"} +def _mode_label_from_baseline(is_baseline: bool) -> str: + return "baseline" if bool(is_baseline) else "defended" + + def _action(policy, obs: np.ndarray): out = policy.predict(obs, deterministic=True) action = out[0] if isinstance(out, tuple) else out @@ -166,7 +174,7 @@ def _log_train_events( alpha: float, step_offset: int, ) -> int: - if not (HAS_WANDB and wandb.run is not None): + if not _wandb_run_active(): return int(step_offset) if not events: return int(step_offset) @@ -187,11 +195,14 @@ def _log_train_events( "run.kind": "benchmark", "runtime/backend": tier_name, "study/mode": mode_label, - "study/no_robust": float(mode_label == "no_robust"), + "study/baseline_mode": float(mode_label == "baseline"), "study/alpha": float(alpha), } ) - wandb.log(payload, step=cursor + rel_step) + try: + wandb.log(payload, step=cursor + rel_step) + except Exception: + return int(step_offset) max_rel = max(max(1, int(evt.get("train/global_step", 0))) for evt in ordered) return cursor + max_rel + 1 @@ -203,6 +214,7 @@ def run_benchmark( n_episodes: int, mode_label: str, step_cursor_start: int = 0, + eval_alpha_values: list[float] | None = None, ): from .backends.common import make_env @@ -239,62 +251,80 @@ def run_benchmark( "dqn", }: wandb_step_cursor += max(1, int(cfg.get("total_timesteps", 1))) + 1 - env = make_env({**cfg, "alpha": float(alpha)}) - eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))] - env.close() - - row = { - "tier": tier_name, - "mode": mode_label, - "alpha": float(alpha), - "episodes": int(n_episodes), - "mean_reward": float(np.mean([e["reward"] for e in eps])), - "mean_revenue": float(np.mean([e["revenue"] for e in eps])), - "mean_margin": float(np.mean([e["mean_margin"] for e in eps])), - "mean_coi": float(np.mean([e["mean_coi"] for e in eps])), - "std_revenue": float(np.std([e["revenue"] for e in eps])), - } - row["objective_score"] = row["mean_reward"] - rows.append(row) - _log( - f"[{run_index}/{total_runs}] alpha={float(alpha):.2f} tier={tier_name}: " - f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} " - f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}" + eval_targets = ( + [float(value) for value in eval_alpha_values] + if eval_alpha_values + else [float(alpha)] ) + for eval_alpha in eval_targets: + env = make_env({**cfg, "alpha": float(eval_alpha)}) + eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))] + env.close() - max_len = max((len(e["price_trace"]) for e in eps), default=0) - step_means = [] - for step in range(max_len): - vals = [ - e["price_trace"][step] for e in eps if step < len(e["price_trace"]) - ] - step_means.append(float(np.mean(vals)) if vals else np.nan) - traces.append( - { + row = { "tier": tier_name, - "alpha": float(alpha), - "mean_price_trace": step_means, + "mode": mode_label, + "alpha": float(eval_alpha), + "train_alpha": float(alpha), + "eval_alpha": float(eval_alpha), + "episodes": int(n_episodes), + "mean_reward": float(np.mean([e["reward"] for e in eps])), + "mean_revenue": float(np.mean([e["revenue"] for e in eps])), + "mean_margin": float(np.mean([e["mean_margin"] for e in eps])), + "mean_coi": float(np.mean([e["mean_coi"] for e in eps])), + "std_revenue": float(np.std([e["revenue"] for e in eps])), } - ) - - if HAS_WANDB and wandb.run is not None: - wandb.log( - { - "run.kind": "benchmark", - "runtime/backend": tier_name, - "study/mode": mode_label, - "study/no_robust": float(mode_label == "no_robust"), - "study/alpha": float(alpha), - "eval/reward_mean": row["mean_reward"], - "eval/revenue_mean": row["mean_revenue"], - "eval/margin_mean": row["mean_margin"], - "eval/coi_level_mean": row["mean_coi"], - "objective/score": row["objective_score"], - "objective/coi_preserved": row["mean_coi"], - }, - step=wandb_step_cursor, + row["objective_score"] = row["mean_reward"] + rows.append(row) + _log( + f"[{run_index}/{total_runs}] train_alpha={float(alpha):.2f} " + f"eval_alpha={float(eval_alpha):.2f} tier={tier_name}: " + f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} " + f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}" ) - wandb_step_cursor += 1 + + max_len = max((len(e["price_trace"]) for e in eps), default=0) + step_means = [] + for step in range(max_len): + vals = [ + e["price_trace"][step] + for e in eps + if step < len(e["price_trace"]) + ] + step_means.append(float(np.mean(vals)) if vals else np.nan) + traces.append( + { + "tier": tier_name, + "alpha": float(eval_alpha), + "train_alpha": float(alpha), + "eval_alpha": float(eval_alpha), + "mean_price_trace": step_means, + } + ) + + if _wandb_run_active(): + try: + wandb.log( + { + "run.kind": "benchmark", + "runtime/backend": tier_name, + "study/mode": mode_label, + "study/baseline_mode": float(mode_label == "baseline"), + "study/alpha": float(eval_alpha), + "study/train_alpha": float(alpha), + "study/eval_alpha": float(eval_alpha), + "eval/reward_mean": row["mean_reward"], + "eval/revenue_mean": row["mean_revenue"], + "eval/margin_mean": row["mean_margin"], + "eval/coi_level_mean": row["mean_coi"], + "objective/score": row["objective_score"], + "objective/coi_preserved": row["mean_coi"], + }, + step=wandb_step_cursor, + ) + except Exception: + pass + wandb_step_cursor += 1 return pd.DataFrame(rows), traces, int(wandb_step_cursor) @@ -378,7 +408,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None): if compare_robust_override is not None else _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) ) - robust_modes = [False, True] if compare_robust else [bool(args.no_robust)] + baseline_modes = [False, True] if compare_robust else [bool(args.no_robust)] base_overrides = { "seed": args.seed, @@ -389,6 +419,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None): "robust_radius": args.robust_radius, "robust_points": args.robust_points, "robust_rollouts": args.robust_rollouts, + "margin_floor": args.margin_floor, "eta_ux": args.eta_ux, "reward_profit_weight": args.reward_profit_weight, "price_low": args.price_low, @@ -405,12 +436,20 @@ def _run_with_args(args, compare_robust_override: bool | None = None): } tiers = _parse_list(args.tiers) alpha_values = _parse_float_list(args.alpha_values) + eval_alpha_values = ( + _parse_float_list(args.eval_alpha_values) + if str(getattr(args, "eval_alpha_values", "")).strip() + else [] + ) _log( "starting run " + json.dumps( { "tiers": tiers, "alpha_values": alpha_values, + "eval_alpha_values": ( + eval_alpha_values if eval_alpha_values else alpha_values + ), "episodes": int(args.episodes), "total_timesteps": int(args.total_timesteps), "device": str(args.device), @@ -421,14 +460,14 @@ def _run_with_args(args, compare_robust_override: bool | None = None): all_frames: list[pd.DataFrame] = [] all_traces: list[dict] = [] wandb_step_cursor = 0 - for no_robust in robust_modes: + for baseline_mode in baseline_modes: overrides = dict(base_overrides) - overrides["no_robust"] = bool(no_robust) + overrides["baseline_mode"] = bool(baseline_mode) cfg = TrainSpec.from_flat( {k: v for k, v in overrides.items() if v is not None} ).to_flat_dict() cfg["linear_warmup_steps"] = int(args.linear_warmup_steps) - mode_label = "no_robust" if no_robust else "robust" + mode_label = _mode_label_from_baseline(bool(baseline_mode)) _log(f"mode={mode_label}: begin") df_mode, traces_mode, wandb_step_cursor = run_benchmark( cfg, @@ -437,6 +476,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None): args.episodes, mode_label=mode_label, step_cursor_start=wandb_step_cursor, + eval_alpha_values=eval_alpha_values, ) _log(f"mode={mode_label}: complete ({len(df_mode)} rows)") for trace in traces_mode: @@ -465,7 +505,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None): + json.dumps( { "tier": best["tier"], - "mode": best.get("mode", "robust"), + "mode": best.get("mode", "defended"), "alpha": float(best["alpha"]), "objective_score": float(best["objective_score"]), "mean_revenue": float(best["mean_revenue"]), @@ -486,6 +526,7 @@ def run_cli(raw_args: list[str] | None = None): parser.add_argument("--project", default="capstone") parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo") parser.add_argument("--alpha-values", default="0.0,0.3,0.6") + parser.add_argument("--eval-alpha-values", default="") parser.add_argument("--episodes", type=int, default=10) parser.add_argument("--output-dir", default="engine/studies/results") parser.add_argument("--seed", type=int, default=42) @@ -496,6 +537,7 @@ def run_cli(raw_args: list[str] | None = None): parser.add_argument("--robust-radius", type=float, default=0.15) parser.add_argument("--robust-points", type=int, default=5) parser.add_argument("--robust-rollouts", type=int, default=1) + parser.add_argument("--margin-floor", type=float, default=0.85) parser.add_argument("--eta-ux", type=float, default=0.5) parser.add_argument("--reward-profit-weight", type=float, default=1.0) parser.add_argument("--price-low", type=float, default=10.0) @@ -529,35 +571,47 @@ def run_cli(raw_args: list[str] | None = None): key_to_attr = { "tiers": "tiers", "alpha_values": "alpha_values", + "eval_alpha_values": "eval_alpha_values", "episodes": "episodes", "total_timesteps": "total_timesteps", "lambda_coi": "lambda_coi", "robust_radius": "robust_radius", "robust_points": "robust_points", "robust_rollouts": "robust_rollouts", + "ambiguity_radius": "robust_radius", + "ambiguity_points": "robust_points", + "ambiguity_rollouts": "robust_rollouts", "eta_ux": "eta_ux", "reward_profit_weight": "reward_profit_weight", "learning_rate": "learning_rate", "batch_size": "batch_size", "n_steps": "n_steps", + "baseline_mode": "no_robust", "no_robust": "no_robust", + "margin_floor": "margin_floor", "device": "device", } for key in ( "tiers", "alpha_values", + "eval_alpha_values", "episodes", "total_timesteps", "lambda_coi", "robust_radius", "robust_points", "robust_rollouts", + "ambiguity_radius", + "ambiguity_points", + "ambiguity_rollouts", "eta_ux", "reward_profit_weight", "learning_rate", "batch_size", "n_steps", + "baseline_mode", "no_robust", + "margin_floor", "device", ): if key in wandb.config: @@ -582,16 +636,16 @@ def run_cli(raw_args: list[str] | None = None): alpha_values = _parse_float_list(args.alpha_values) run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S") compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) - compare_tag = "robust-compare" if compare_enabled else "single-mode" + compare_tag = "defended-compare" if compare_enabled else "single-mode" modes = ( - [("no_robust", True), ("robust", False)] + [("baseline", True), ("defended", False)] if compare_enabled - else [("no_robust" if bool(args.no_robust) else "robust", bool(args.no_robust))] + else [(_mode_label_from_baseline(bool(args.no_robust)), bool(args.no_robust))] ) run_idx = 0 for tier in tiers: - for mode_label, no_robust in modes: + for mode_label, baseline_mode in modes: for alpha in alpha_values: run_idx += 1 alpha_token = ( @@ -600,7 +654,7 @@ def run_cli(raw_args: list[str] | None = None): tier_args = argparse.Namespace(**vars(args)) tier_args.tiers = tier tier_args.alpha_values = str(float(alpha)) - tier_args.no_robust = bool(no_robust) + tier_args.no_robust = bool(baseline_mode) run = wandb.init( project=args.project, name=( @@ -617,16 +671,19 @@ def run_cli(raw_args: list[str] | None = None): "run.kind": "benchmark", "runtime/backend": tier, "study/mode": mode_label, - "study/no_robust": float(no_robust), + "study/baseline_mode": float(baseline_mode), "study/alpha": float(alpha), "tiers": tier, "alpha_values": str(float(alpha)), + "eval_alpha_values": args.eval_alpha_values, "episodes": args.episodes, "total_timesteps": args.total_timesteps, "lambda_coi": args.lambda_coi, - "robust_radius": args.robust_radius, - "robust_points": args.robust_points, - "robust_rollouts": args.robust_rollouts, + "ambiguity_radius": args.robust_radius, + "ambiguity_points": args.robust_points, + "ambiguity_rollouts": args.robust_rollouts, + "margin_floor": args.margin_floor, + "baseline_mode": float(baseline_mode), "eta_ux": args.eta_ux, "reward_profit_weight": args.reward_profit_weight, "learning_rate": args.learning_rate, diff --git a/engine/lib/callbacks.py b/engine/lib/callbacks.py index 2193894..ec5c6ef 100644 --- a/engine/lib/callbacks.py +++ b/engine/lib/callbacks.py @@ -15,15 +15,19 @@ class MetricsCallback(BaseCallback): self, log_histograms: bool = False, log_freq: int = 100, + hist_freq: int = 500, step_offset: int = 0, verbose: int = 0, ): super().__init__(verbose) self.log_histograms = log_histograms self.log_freq = max(1, int(log_freq)) + self.hist_freq = max(1, int(hist_freq)) self.step_offset = max(0, int(step_offset)) self._wandb = get_wandb_module() self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None) + self._price_samples: list[float] = [] + self._demand_samples: list[float] = [] self._window_sums = { "train/revenue_mean": 0.0, "train/margin_mean": 0.0, @@ -74,35 +78,100 @@ class MetricsCallback(BaseCallback): ) self._window_count += 1 - def _flush(self, step: int) -> None: - if self._window_count <= 0: + def _accumulate_histograms(self, info: dict[str, Any]) -> None: + if not self.log_histograms: return - denom = float(self._window_count) - payload = { - key: (value / denom) - for key, value in self._window_sums.items() - if value != 0.0 - or key - in { - "train/revenue_mean", - "train/margin_mean", - "train/coi_level_mean", - "train/regret_mean", + + for key in ("effective_prices", "prices"): + if key not in info: + continue + try: + values = np.asarray(info.get(key), dtype=float).reshape(-1) + except Exception: + continue + if values.size <= 0: + continue + finite_values = values[np.isfinite(values)] + if finite_values.size > 0: + self._price_samples.extend(finite_values.tolist()) + break + + if "demand" in info: + try: + demand_values = np.asarray(info.get("demand"), dtype=float).reshape(-1) + except Exception: + demand_values = np.array([], dtype=float) + if demand_values.size > 0: + finite_demand = demand_values[np.isfinite(demand_values)] + if finite_demand.size > 0: + self._demand_samples.extend(finite_demand.tolist()) + + def _flush_histograms(self, step: int, force: bool = False) -> None: + if not self.log_histograms: + return + if not force and step % self.hist_freq != 0: + return + if not self._price_samples and not self._demand_samples: + return + if self._wandb is None: + self._price_samples.clear() + self._demand_samples.clear() + return + + payload: dict[str, Any] = {} + if self._price_samples: + payload["train/price_dist"] = self._wandb.Histogram( + np.asarray(self._price_samples, dtype=np.float32) + ) + if self._demand_samples: + payload["train/demand_dist"] = self._wandb.Histogram( + np.asarray(self._demand_samples, dtype=np.float32) + ) + + if payload and self._wandb_live: + try: + self._wandb.log(payload, step=self.step_offset + int(step)) + except Exception: + self._wandb_live = False + + self._price_samples.clear() + self._demand_samples.clear() + + def _flush(self, step: int, *, force_hist: bool = False) -> None: + if self._window_count > 0: + denom = float(self._window_count) + payload = { + key: (value / denom) + for key, value in self._window_sums.items() + if value != 0.0 + or key + in { + "train/revenue_mean", + "train/margin_mean", + "train/coi_level_mean", + "train/regret_mean", + } } - } - payload["train/global_step"] = int(step) - if self._wandb_live: - self._wandb.log(dict(payload), step=self.step_offset + int(step)) - else: - self.events.append(payload) - for key in self._window_sums: - self._window_sums[key] = 0.0 - self._window_count = 0 + payload["train/global_step"] = int(step) + if self._wandb_live: + try: + self._wandb.log(dict(payload), step=self.step_offset + int(step)) + except Exception: + self._wandb_live = False + self.events.append(payload) + else: + self.events.append(payload) + for key in self._window_sums: + self._window_sums[key] = 0.0 + self._window_count = 0 + + self._flush_histograms(step=step, force=force_hist) def _on_step(self) -> bool: for info in self.locals.get("infos", []): if isinstance(info, dict): self._accumulate(info) + self._accumulate_histograms(info) if self.num_timesteps % self.log_freq == 0: self._flush(step=self.num_timesteps) @@ -110,39 +179,81 @@ class MetricsCallback(BaseCallback): return True def _on_training_end(self) -> None: - self._flush(step=self.num_timesteps) + self._flush(step=self.num_timesteps, force_hist=True) class EvalMetricsCallback(EvalCallback): """Deterministic evaluation collector detached from logging backends.""" def __init__( - self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs + self, + eval_env, + eval_freq: int = 1000, + n_eval_episodes: int = 5, + step_offset: int = 0, + **kwargs, ): super().__init__( eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs ) - self._eval_revenues: list[float] = [] + self.step_offset = max(0, int(step_offset)) + self._wandb = get_wandb_module() + self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None) + self._eval_stats: dict[str, list[float]] = { + "eval/revenue_mean": [], + "eval/margin_mean": [], + "eval/coi_level_mean": [], + "eval/coi_leakage_mean": [], + "eval/volatility_mean": [], + "eval/agent_prob_mean": [], + } self.events: list[dict[str, float | int]] = [] def _on_step(self) -> bool: result = super()._on_step() if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): - self.events.append( - { - "eval/reward_mean": float(self.last_mean_reward), - "eval/revenue_mean": float(np.mean(self._eval_revenues)) - if self._eval_revenues - else 0.0, - "train/global_step": int(self.num_timesteps), - } - ) - self._eval_revenues = [] + payload: dict[str, float | int] = { + "eval/reward_mean": float(self.last_mean_reward), + "train/global_step": int(self.num_timesteps), + } + for key, values in self._eval_stats.items(): + payload[key] = float(np.mean(values)) if values else 0.0 + + if self._wandb_live: + try: + self._wandb.log( + dict(payload), + step=self.step_offset + int(self.num_timesteps), + ) + except Exception: + self._wandb_live = False + self.events.append(payload) + else: + self.events.append(payload) + + for values in self._eval_stats.values(): + values.clear() return result def _log_success_callback(self, locals_: dict, globals_: dict) -> None: # called after each eval episode info = locals_.get("info", {}) - if "economics" in info: - self._eval_revenues.append(info["economics"]["revenue"]) + econ = info.get("economics") if isinstance(info, dict) else None + if not isinstance(econ, dict): + return + + self._eval_stats["eval/revenue_mean"].append(float(econ.get("revenue", 0.0))) + self._eval_stats["eval/margin_mean"].append(float(econ.get("margin", 0.0))) + self._eval_stats["eval/coi_level_mean"].append( + float(econ.get("coi_level", 0.0)) + ) + self._eval_stats["eval/coi_leakage_mean"].append( + float(econ.get("coi_leakage", 0.0)) + ) + self._eval_stats["eval/volatility_mean"].append( + float(econ.get("volatility", 0.0)) + ) + self._eval_stats["eval/agent_prob_mean"].append( + float(econ.get("agent_prob", 0.0)) + ) diff --git a/engine/lib/providers.py b/engine/lib/providers.py index 19d2788..2fa6d8f 100644 --- a/engine/lib/providers.py +++ b/engine/lib/providers.py @@ -156,14 +156,17 @@ class ProviderBenchmark: # log to wandb if available if HAS_WANDB and wandb.run is not None: - wandb.log( - { - f"benchmark/{name}/revenue": result.mean_revenue, - f"benchmark/{name}/coi_preserved": result.coi_preserved_pct, - f"benchmark/{name}/margin": result.margin_integrity, - "benchmark/alpha": alpha, - } - ) + try: + wandb.log( + { + f"benchmark/{name}/revenue": result.mean_revenue, + f"benchmark/{name}/coi_preserved": result.coi_preserved_pct, + f"benchmark/{name}/margin": result.margin_integrity, + "benchmark/alpha": alpha, + } + ) + except Exception: + pass return self.results diff --git a/engine/orchestrators/sweep_agent.py b/engine/orchestrators/sweep_agent.py index 9f3dcfc..6afeaa2 100644 --- a/engine/orchestrators/sweep_agent.py +++ b/engine/orchestrators/sweep_agent.py @@ -9,6 +9,7 @@ from ..telemetry.wandb import ( get_wandb_module, init_run, run_agent, + update_summary, ) from .train import run_with_active_sweep_run @@ -43,13 +44,23 @@ def run_sweep_agent( spec = TrainSpec.from_flat(merged) if run is not None: run.name = run_name(spec, kind=kind, scenario=scenario) - run_with_active_sweep_run( - spec, - kind=kind, - scenario=scenario, - group=group, - extra_tags=extra_tags, - ) + try: + run_with_active_sweep_run( + spec, + kind=kind, + scenario=scenario, + group=group, + extra_tags=extra_tags, + ) + update_summary({"run/status": "finished"}) + except Exception as exc: + update_summary( + { + "run/status": "crashed", + "run/error": str(exc), + } + ) + raise finally: finish_run() diff --git a/engine/orchestrators/train.py b/engine/orchestrators/train.py index 81ebdb5..4be8997 100644 --- a/engine/orchestrators/train.py +++ b/engine/orchestrators/train.py @@ -20,7 +20,7 @@ def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list kind, spec.algorithm.name, spec.runtime.backend, - "vanilla" if spec.study.no_robust else "robust", + "baseline" if spec.study.no_robust else "defended", ] tags.extend([tag for tag in extra_tags if tag]) return tags diff --git a/engine/project.json b/engine/project.json index 4d5d041..1fb18e4 100644 --- a/engine/project.json +++ b/engine/project.json @@ -91,6 +91,44 @@ "command": "bash scripts/nx_research.sh docker-train-publish", "cwd": "." } + }, + "whoclicked-publish": { + "executor": "nx:run-commands", + "dependsOn": [ + "install" + ], + "options": { + "command": "bash scripts/nx_research.sh whoclicked-publish", + "cwd": "." + } + }, + "tpu-ray-bootstrap": { + "executor": "nx:run-commands", + "options": { + "command": "bash scripts/nx_research.sh tpu-ray-bootstrap", + "cwd": "." + } + }, + "tpu-ray-deps": { + "executor": "nx:run-commands", + "options": { + "command": "bash scripts/nx_research.sh tpu-ray-deps", + "cwd": "." + } + }, + "tpu-ray-verify": { + "executor": "nx:run-commands", + "options": { + "command": "bash scripts/nx_research.sh tpu-ray-verify", + "cwd": "." + } + }, + "tpu-ray-teardown": { + "executor": "nx:run-commands", + "options": { + "command": "bash scripts/nx_research.sh tpu-ray-teardown", + "cwd": "." + } } }, "tags": [ diff --git a/engine/spec.py b/engine/spec.py index 5ddd0ce..8cc3ea9 100644 --- a/engine/spec.py +++ b/engine/spec.py @@ -32,10 +32,17 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]: "study.robust_radius": "robust_radius", "study.robust_points": "robust_points", "study.robust_rollouts": "robust_rollouts", + "study.ambiguity_radius": "robust_radius", + "study.ambiguity_points": "robust_points", + "study.ambiguity_rollouts": "robust_rollouts", "study.info_value": "info_value", "study.eta_ux": "eta_ux", "study.reward_profit_weight": "reward_profit_weight", - "study.revenue_weight": "revenue_weight", + "ambiguity_radius": "robust_radius", + "ambiguity_points": "robust_points", + "ambiguity_rollouts": "robust_rollouts", + "baseline_mode": "no_robust", + "stress_eval_enabled": "robust_eval_enabled", "optimizer.learning_rate": "learning_rate", "optimizer.gamma": "gamma", "optimizer.batch_size": "batch_size", @@ -45,6 +52,7 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]: "runtime.seed": "seed", "runtime.total_timesteps": "total_timesteps", "runtime.checkpoint_interval": "checkpoint_interval", + "runtime.hist_freq": "hist_freq", "eval.eval_freq": "eval_freq", "eval.eval_episodes": "eval_episodes", } @@ -86,7 +94,6 @@ class StudySpec: info_value: float = 1.0 eta_ux: float = 0.5 reward_profit_weight: float = 1.0 - revenue_weight: float = 0.01 no_robust: bool = False @@ -128,6 +135,7 @@ class RuntimeSpec: checkpoint_interval: int = 200_000 model_dir: str = "engine/models" log_freq: int = 100 + hist_freq: int = 500 @dataclass(frozen=True) @@ -159,6 +167,7 @@ class TrainSpec: "backend": self.runtime.backend, "device": self.runtime.device, "checkpoint_interval": self.runtime.checkpoint_interval, + "hist_freq": self.runtime.hist_freq, "n_products": self.env.n_products, "N": self.env.n_sessions, "price_low": self.env.price_low, @@ -179,7 +188,6 @@ class TrainSpec: "info_value": self.study.info_value, "eta_ux": self.study.eta_ux, "reward_profit_weight": self.study.reward_profit_weight, - "revenue_weight": self.study.revenue_weight, "no_robust": self.study.no_robust, "learning_rate": self.optimizer.learning_rate, "gamma": self.optimizer.gamma, @@ -262,7 +270,6 @@ class TrainSpec: info_value=float(base["info_value"]), eta_ux=float(base["eta_ux"]), reward_profit_weight=float(base["reward_profit_weight"]), - revenue_weight=float(base["revenue_weight"]), no_robust=no_robust, ), optimizer=OptimizerSpec( @@ -300,6 +307,7 @@ class TrainSpec: checkpoint_interval=int(base["checkpoint_interval"]), model_dir=str(base["model_dir"]), log_freq=int(base["log_freq"]), + hist_freq=int(base["hist_freq"]), ), eval=EvalSpec( eval_freq=int(base["eval_freq"]), @@ -310,9 +318,11 @@ class TrainSpec: def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str: + alpha_token = f"{float(spec.study.alpha):.2f}".rstrip("0").rstrip(".") + mode = "baseline" if bool(spec.study.no_robust) else "defended" return ( f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/" - f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}" + f"{spec.runtime.device}/{scenario}/a{alpha_token}/{mode}/s{spec.runtime.seed}" ) @@ -324,6 +334,7 @@ def run_metadata( group: str | None = None, tags: Sequence[str] = (), ) -> dict[str, Any]: + mode = "baseline" if bool(spec.study.no_robust) else "defended" metadata: dict[str, Any] = { "run.kind": str(kind), "run.algo": spec.algorithm.name, @@ -332,6 +343,10 @@ def run_metadata( "run.scenario": str(scenario), "run.seed": spec.runtime.seed, "run.tags": list(tags), + "study/alpha": float(spec.study.alpha), + "study/mode": mode, + "study/baseline_mode": float(bool(spec.study.no_robust)), + "tiers": spec.algorithm.name, } if group: metadata["run.group"] = group diff --git a/engine/telemetry/metrics.py b/engine/telemetry/metrics.py index aa080d8..ccfea58 100644 --- a/engine/telemetry/metrics.py +++ b/engine/telemetry/metrics.py @@ -36,7 +36,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A eval_reward = ( _as_float( - metrics.get("eval/robust_reward_worst", metrics.get("eval/reward_mean")), + metrics.get( + "eval/stress_reward_worst", + metrics.get( + "eval/robust_reward_worst", metrics.get("eval/reward_mean") + ), + ), 0.0, ) or 0.0 @@ -51,9 +56,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level metrics["study/alpha"] = spec.study.alpha + metrics["study/mode"] = "baseline" if bool(spec.study.no_robust) else "defended" + metrics["study/baseline_mode"] = float(bool(spec.study.no_robust)) metrics["study/lambda_coi"] = spec.study.lambda_coi - metrics["study/robust_radius"] = spec.study.robust_radius + metrics["study/ambiguity_radius"] = spec.study.robust_radius metrics["study/info_value"] = spec.study.info_value + metrics["tiers"] = spec.algorithm.name metrics["runtime/backend"] = spec.runtime.backend metrics["runtime/device"] = spec.runtime.device diff --git a/engine/telemetry/wandb.py b/engine/telemetry/wandb.py index 5e6fb85..4181a80 100644 --- a/engine/telemetry/wandb.py +++ b/engine/telemetry/wandb.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +import time from typing import Any, Callable, Iterable, Mapping @@ -19,6 +21,42 @@ def _require_wandb(): return wandb +def _warn(message: str) -> None: + print(f"PHANTOM_WANDB_WARNING: {message}") + + +def _sanitize_key(raw_key: str) -> str | None: + key = str(raw_key) + replacements = { + "no_robust": "baseline_mode", + "study/no_robust": "study/baseline_mode", + "study/robust_radius": "study/ambiguity_radius", + "robust_radius": "ambiguity_radius", + "robust_points": "ambiguity_points", + "robust_rollouts": "ambiguity_rollouts", + "robust_eval_enabled": "stress_eval_enabled", + "eval/robust_alpha_high": "eval/stress_alpha_high", + "eval/robust_alpha_low": "eval/stress_alpha_low", + "eval/robust_reward_worst": "eval/stress_reward_worst", + "eval/robust_revenue_worst": "eval/stress_revenue_worst", + "eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst", + } + key = replacements.get(key, key) + if "robust" in key.lower(): + return None + return key + + +def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]: + sanitized: dict[str, Any] = {} + for key, value in payload.items(): + clean_key = _sanitize_key(str(key)) + if clean_key is None: + continue + sanitized[clean_key] = value + return sanitized + + def init_run( *, mode: str, @@ -34,7 +72,11 @@ def init_run( if group: kwargs["group"] = group if sweep_mode: - run = wandb.init(**kwargs) + try: + run = wandb.init(**kwargs) + except Exception as exc: + _warn(f"init failed in sweep mode ({exc})") + return None if name and run is not None: run.name = name return run @@ -42,18 +84,25 @@ def init_run( init_kwargs = dict(kwargs) init_kwargs["project"] = project if config is not None: - init_kwargs["config"] = dict(config) + init_kwargs["config"] = _sanitize_payload(dict(config)) if name: init_kwargs["name"] = name if tags: init_kwargs["tags"] = list(tags) - return wandb.init(**init_kwargs) + try: + return wandb.init(**init_kwargs) + except Exception as exc: + _warn(f"init failed ({exc})") + return None def finish_run() -> None: wandb = get_wandb_module() if wandb is not None and wandb.run is not None: - wandb.finish() + try: + wandb.finish() + except Exception as exc: + _warn(f"finish failed ({exc})") def current_config() -> dict[str, Any]: @@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None: wandb = get_wandb_module() if wandb is None or wandb.run is None: return + payload = _sanitize_payload(dict(config)) + if not payload: + return try: - wandb.config.update(dict(config), allow_val_change=True) + wandb.config.update(payload, allow_val_change=True) except TypeError: - wandb.config.update(dict(config)) + try: + wandb.config.update(payload) + except Exception as exc: + _warn(f"config update failed ({exc})") + except Exception as exc: + _warn(f"config update failed ({exc})") def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None: wandb = get_wandb_module() if wandb is None or wandb.run is None: return - wandb.log(dict(metrics), step=step) + payload = _sanitize_payload(dict(metrics)) + if not payload: + return + try: + wandb.log(payload, step=step) + except Exception as exc: + _warn(f"log failed at step {step} ({exc})") def update_summary(metrics: Mapping[str, Any]) -> None: wandb = get_wandb_module() if wandb is None or wandb.run is None: return - for key, value in metrics.items(): - wandb.run.summary[key] = value + payload = _sanitize_payload(dict(metrics)) + if not payload: + return + try: + for key, value in payload.items(): + wandb.run.summary[key] = value + except Exception as exc: + _warn(f"summary update failed ({exc})") def run_agent( @@ -95,4 +164,39 @@ def run_agent( count: int | None = None, ) -> None: wandb = _require_wandb() - wandb.agent(sweep_id, function=fn, count=count) + retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8"))) + retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5"))) + retry_backoff = max( + 1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5")) + ) + retry_max_delay = max( + retry_delay, + float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")), + ) + + target = None if count is None else max(0, int(count)) + completed = 0 + + def _wrapped() -> None: + nonlocal completed + fn() + completed += 1 + + attempt = 0 + while True: + remaining = None if target is None else max(0, int(target - completed)) + if target is not None and remaining == 0: + return + try: + wandb.agent(sweep_id, function=_wrapped, count=remaining) + return + except Exception as exc: + attempt += 1 + if attempt > retry_max: + raise + wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1))) + _warn( + f"agent disconnected (attempt {attempt}/{retry_max}, " + f"completed={completed}, remaining={remaining}): {exc}" + ) + time.sleep(wait) diff --git a/engine/train.py b/engine/train.py index aafd02c..3fc235d 100644 --- a/engine/train.py +++ b/engine/train.py @@ -54,6 +54,7 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument("--total-timesteps", type=int) parser.add_argument("--model-dir", type=str) parser.add_argument("--log-freq", type=int) + parser.add_argument("--hist-freq", type=int) parser.add_argument("--checkpoint-interval", type=int) parser.add_argument("--device", type=str) @@ -68,7 +69,6 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument("--no-robust", action="store_true") parser.add_argument("--eta-ux", type=float) parser.add_argument("--reward-profit-weight", type=float) - parser.add_argument("--revenue-weight", type=float) parser.add_argument("--price-low", type=float) parser.add_argument("--price-high", type=float) @@ -126,6 +126,7 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]: "total_timesteps": args.total_timesteps, "model_dir": args.model_dir, "log_freq": args.log_freq, + "hist_freq": args.hist_freq, "checkpoint_interval": args.checkpoint_interval, "device": args.device, "alpha": args.alpha, @@ -139,7 +140,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]: "no_robust": args.no_robust, "eta_ux": args.eta_ux, "reward_profit_weight": args.reward_profit_weight, - "revenue_weight": args.revenue_weight, "price_low": args.price_low, "price_high": args.price_high, "action_levels": args.action_levels,