updating engine training for training

This commit is contained in:
2026-03-15 21:14:11 +01:00
parent 19b47aa699
commit 52b4dcdce3
13 changed files with 544 additions and 160 deletions

View File

@@ -132,15 +132,15 @@ def evaluate(
shifted_env.close() shifted_env.close()
shifted_rows.append((tag, alpha, shifted_metrics)) shifted_rows.append((tag, alpha, shifted_metrics))
metrics["eval/robust_alpha_low"] = low_alpha metrics["eval/stress_alpha_low"] = low_alpha
metrics["eval/robust_alpha_high"] = high_alpha metrics["eval/stress_alpha_high"] = high_alpha
metrics["eval/robust_reward_worst"] = float( metrics["eval/stress_reward_worst"] = float(
min(row[2]["eval/reward_mean"] for row in shifted_rows) min(row[2]["eval/reward_mean"] for row in shifted_rows)
) )
metrics["eval/robust_revenue_worst"] = float( metrics["eval/stress_revenue_worst"] = float(
min(row[2]["eval/revenue_mean"] for row in shifted_rows) min(row[2]["eval/revenue_mean"] for row in shifted_rows)
) )
metrics["eval/robust_coi_leakage_worst"] = float( metrics["eval/stress_coi_leakage_worst"] = float(
max(row[2]["eval/coi_leakage_mean"] for row in shifted_rows) max(row[2]["eval/coi_leakage_mean"] for row in shifted_rows)
) )
for tag, alpha, shifted_metrics in shifted_rows: for tag, alpha, shifted_metrics in shifted_rows:

View File

@@ -80,7 +80,11 @@ def train_qtable(
"train/global_step": int(steps), "train/global_step": int(steps),
} }
if wandb_live: if wandb_live:
wandb.log(dict(event), step=step_offset + int(steps)) try:
wandb.log(dict(event), step=step_offset + int(steps))
except Exception:
wandb_live = False
train_events.append(event)
else: else:
train_events.append(event) train_events.append(event)
if console_progress: if console_progress:
@@ -113,7 +117,11 @@ def train_qtable(
"train/global_step": int(steps), "train/global_step": int(steps),
} }
if wandb_live: if wandb_live:
wandb.log(dict(tail_event), step=step_offset + int(steps)) try:
wandb.log(dict(tail_event), step=step_offset + int(steps))
except Exception:
wandb_live = False
train_events.append(tail_event)
else: else:
train_events.append(tail_event) train_events.append(tail_event)

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Any, Mapping from typing import Any, Mapping
from ..lib.callbacks import MetricsCallback from ..lib.callbacks import EvalMetricsCallback, MetricsCallback
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
from .common import evaluate, make_env from .common import evaluate, make_env
@@ -117,7 +119,6 @@ def build_model(cfg: Mapping[str, Any], env: Any):
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]: def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
try: try:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
except ImportError as exc: except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 models") from exc raise ImportError("stable-baselines3 is required for SB3 models") from exc
@@ -144,20 +145,20 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
pass pass
metrics_callback = MetricsCallback( metrics_callback = MetricsCallback(
log_histograms=False, log_histograms=True,
log_freq=int(cfg["log_freq"]), log_freq=int(cfg["log_freq"]),
hist_freq=int(cfg.get("hist_freq", 500)),
step_offset=int(cfg.get("wandb_step_offset", 0)), step_offset=int(cfg.get("wandb_step_offset", 0)),
) )
callbacks = [metrics_callback] eval_callback = EvalMetricsCallback(
callbacks.append( eval_env,
EvalCallback( eval_freq=int(cfg["eval_freq"]),
eval_env, n_eval_episodes=int(cfg["eval_episodes"]),
eval_freq=int(cfg["eval_freq"]), step_offset=int(cfg.get("wandb_step_offset", 0)),
n_eval_episodes=int(cfg["eval_episodes"]), deterministic=True,
deterministic=True, verbose=0,
verbose=0,
)
) )
callbacks = [metrics_callback, eval_callback]
target_steps = int(cfg["total_timesteps"]) target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0))) remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
@@ -173,6 +174,29 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
model_path = model_dir / f"phantom_{cfg['algo']}" model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path)) model.save(str(model_path))
artifact_name = checkpoint_artifact_name(
cfg,
backend="sb3",
sweep_id=os.getenv("WANDB_SWEEP_ID"),
)
artifact_logged = False
try:
artifact_logged = bool(
log_checkpoint_file(
artifact_name,
file_path=model_path.with_suffix(".zip"),
artifact_file_name="model.zip",
metadata={
"algo": str(cfg.get("algo", "ppo")),
"backend": "sb3",
"seed": int(cfg.get("seed", 0)),
"step": int(getattr(model, "num_timesteps", 0)),
},
)
)
except Exception:
artifact_logged = False
metrics: dict[str, Any] = evaluate( metrics: dict[str, Any] = evaluate(
model, model,
eval_env, eval_env,
@@ -181,7 +205,12 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
) )
metrics["train/global_step"] = int(model.num_timesteps) metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip")) metrics["model/path"] = str(model_path.with_suffix(".zip"))
metrics["_train_events"] = list(metrics_callback.events) metrics["model/artifact_name"] = str(artifact_name)
metrics["model/artifact_logged"] = float(artifact_logged)
metrics["_train_events"] = sorted(
[*metrics_callback.events, *eval_callback.events],
key=lambda event: int(event.get("train/global_step", 0)),
)
env.close() env.close()
eval_env.close() eval_env.close()

View File

@@ -45,6 +45,10 @@ def _log(message: str) -> None:
logger.info(message) logger.info(message)
def _wandb_run_active() -> bool:
return bool(HAS_WANDB and getattr(wandb, "run", None) is not None)
def _parse_list(raw: str) -> list[str]: def _parse_list(raw: str) -> list[str]:
return [x.strip().lower() for x in str(raw).split(",") if x.strip()] return [x.strip().lower() for x in str(raw).split(",") if x.strip()]
@@ -61,6 +65,10 @@ def _truthy(value: str | bool | None) -> bool:
return str(value).strip().lower() in {"1", "true", "yes", "on"} return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _mode_label_from_baseline(is_baseline: bool) -> str:
return "baseline" if bool(is_baseline) else "defended"
def _action(policy, obs: np.ndarray): def _action(policy, obs: np.ndarray):
out = policy.predict(obs, deterministic=True) out = policy.predict(obs, deterministic=True)
action = out[0] if isinstance(out, tuple) else out action = out[0] if isinstance(out, tuple) else out
@@ -166,7 +174,7 @@ def _log_train_events(
alpha: float, alpha: float,
step_offset: int, step_offset: int,
) -> int: ) -> int:
if not (HAS_WANDB and wandb.run is not None): if not _wandb_run_active():
return int(step_offset) return int(step_offset)
if not events: if not events:
return int(step_offset) return int(step_offset)
@@ -187,11 +195,14 @@ def _log_train_events(
"run.kind": "benchmark", "run.kind": "benchmark",
"runtime/backend": tier_name, "runtime/backend": tier_name,
"study/mode": mode_label, "study/mode": mode_label,
"study/no_robust": float(mode_label == "no_robust"), "study/baseline_mode": float(mode_label == "baseline"),
"study/alpha": float(alpha), "study/alpha": float(alpha),
} }
) )
wandb.log(payload, step=cursor + rel_step) try:
wandb.log(payload, step=cursor + rel_step)
except Exception:
return int(step_offset)
max_rel = max(max(1, int(evt.get("train/global_step", 0))) for evt in ordered) max_rel = max(max(1, int(evt.get("train/global_step", 0))) for evt in ordered)
return cursor + max_rel + 1 return cursor + max_rel + 1
@@ -203,6 +214,7 @@ def run_benchmark(
n_episodes: int, n_episodes: int,
mode_label: str, mode_label: str,
step_cursor_start: int = 0, step_cursor_start: int = 0,
eval_alpha_values: list[float] | None = None,
): ):
from .backends.common import make_env from .backends.common import make_env
@@ -239,62 +251,80 @@ def run_benchmark(
"dqn", "dqn",
}: }:
wandb_step_cursor += max(1, int(cfg.get("total_timesteps", 1))) + 1 wandb_step_cursor += max(1, int(cfg.get("total_timesteps", 1))) + 1
env = make_env({**cfg, "alpha": float(alpha)}) eval_targets = (
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))] [float(value) for value in eval_alpha_values]
env.close() if eval_alpha_values
else [float(alpha)]
row = {
"tier": tier_name,
"mode": mode_label,
"alpha": float(alpha),
"episodes": int(n_episodes),
"mean_reward": float(np.mean([e["reward"] for e in eps])),
"mean_revenue": float(np.mean([e["revenue"] for e in eps])),
"mean_margin": float(np.mean([e["mean_margin"] for e in eps])),
"mean_coi": float(np.mean([e["mean_coi"] for e in eps])),
"std_revenue": float(np.std([e["revenue"] for e in eps])),
}
row["objective_score"] = row["mean_reward"]
rows.append(row)
_log(
f"[{run_index}/{total_runs}] alpha={float(alpha):.2f} tier={tier_name}: "
f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} "
f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}"
) )
for eval_alpha in eval_targets:
env = make_env({**cfg, "alpha": float(eval_alpha)})
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))]
env.close()
max_len = max((len(e["price_trace"]) for e in eps), default=0) row = {
step_means = []
for step in range(max_len):
vals = [
e["price_trace"][step] for e in eps if step < len(e["price_trace"])
]
step_means.append(float(np.mean(vals)) if vals else np.nan)
traces.append(
{
"tier": tier_name, "tier": tier_name,
"alpha": float(alpha), "mode": mode_label,
"mean_price_trace": step_means, "alpha": float(eval_alpha),
"train_alpha": float(alpha),
"eval_alpha": float(eval_alpha),
"episodes": int(n_episodes),
"mean_reward": float(np.mean([e["reward"] for e in eps])),
"mean_revenue": float(np.mean([e["revenue"] for e in eps])),
"mean_margin": float(np.mean([e["mean_margin"] for e in eps])),
"mean_coi": float(np.mean([e["mean_coi"] for e in eps])),
"std_revenue": float(np.std([e["revenue"] for e in eps])),
} }
) row["objective_score"] = row["mean_reward"]
rows.append(row)
if HAS_WANDB and wandb.run is not None: _log(
wandb.log( f"[{run_index}/{total_runs}] train_alpha={float(alpha):.2f} "
{ f"eval_alpha={float(eval_alpha):.2f} tier={tier_name}: "
"run.kind": "benchmark", f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} "
"runtime/backend": tier_name, f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}"
"study/mode": mode_label,
"study/no_robust": float(mode_label == "no_robust"),
"study/alpha": float(alpha),
"eval/reward_mean": row["mean_reward"],
"eval/revenue_mean": row["mean_revenue"],
"eval/margin_mean": row["mean_margin"],
"eval/coi_level_mean": row["mean_coi"],
"objective/score": row["objective_score"],
"objective/coi_preserved": row["mean_coi"],
},
step=wandb_step_cursor,
) )
wandb_step_cursor += 1
max_len = max((len(e["price_trace"]) for e in eps), default=0)
step_means = []
for step in range(max_len):
vals = [
e["price_trace"][step]
for e in eps
if step < len(e["price_trace"])
]
step_means.append(float(np.mean(vals)) if vals else np.nan)
traces.append(
{
"tier": tier_name,
"alpha": float(eval_alpha),
"train_alpha": float(alpha),
"eval_alpha": float(eval_alpha),
"mean_price_trace": step_means,
}
)
if _wandb_run_active():
try:
wandb.log(
{
"run.kind": "benchmark",
"runtime/backend": tier_name,
"study/mode": mode_label,
"study/baseline_mode": float(mode_label == "baseline"),
"study/alpha": float(eval_alpha),
"study/train_alpha": float(alpha),
"study/eval_alpha": float(eval_alpha),
"eval/reward_mean": row["mean_reward"],
"eval/revenue_mean": row["mean_revenue"],
"eval/margin_mean": row["mean_margin"],
"eval/coi_level_mean": row["mean_coi"],
"objective/score": row["objective_score"],
"objective/coi_preserved": row["mean_coi"],
},
step=wandb_step_cursor,
)
except Exception:
pass
wandb_step_cursor += 1
return pd.DataFrame(rows), traces, int(wandb_step_cursor) return pd.DataFrame(rows), traces, int(wandb_step_cursor)
@@ -378,7 +408,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
if compare_robust_override is not None if compare_robust_override is not None
else _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) else _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
) )
robust_modes = [False, True] if compare_robust else [bool(args.no_robust)] baseline_modes = [False, True] if compare_robust else [bool(args.no_robust)]
base_overrides = { base_overrides = {
"seed": args.seed, "seed": args.seed,
@@ -389,6 +419,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
"robust_radius": args.robust_radius, "robust_radius": args.robust_radius,
"robust_points": args.robust_points, "robust_points": args.robust_points,
"robust_rollouts": args.robust_rollouts, "robust_rollouts": args.robust_rollouts,
"margin_floor": args.margin_floor,
"eta_ux": args.eta_ux, "eta_ux": args.eta_ux,
"reward_profit_weight": args.reward_profit_weight, "reward_profit_weight": args.reward_profit_weight,
"price_low": args.price_low, "price_low": args.price_low,
@@ -405,12 +436,20 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
} }
tiers = _parse_list(args.tiers) tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values) alpha_values = _parse_float_list(args.alpha_values)
eval_alpha_values = (
_parse_float_list(args.eval_alpha_values)
if str(getattr(args, "eval_alpha_values", "")).strip()
else []
)
_log( _log(
"starting run " "starting run "
+ json.dumps( + json.dumps(
{ {
"tiers": tiers, "tiers": tiers,
"alpha_values": alpha_values, "alpha_values": alpha_values,
"eval_alpha_values": (
eval_alpha_values if eval_alpha_values else alpha_values
),
"episodes": int(args.episodes), "episodes": int(args.episodes),
"total_timesteps": int(args.total_timesteps), "total_timesteps": int(args.total_timesteps),
"device": str(args.device), "device": str(args.device),
@@ -421,14 +460,14 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
all_frames: list[pd.DataFrame] = [] all_frames: list[pd.DataFrame] = []
all_traces: list[dict] = [] all_traces: list[dict] = []
wandb_step_cursor = 0 wandb_step_cursor = 0
for no_robust in robust_modes: for baseline_mode in baseline_modes:
overrides = dict(base_overrides) overrides = dict(base_overrides)
overrides["no_robust"] = bool(no_robust) overrides["baseline_mode"] = bool(baseline_mode)
cfg = TrainSpec.from_flat( cfg = TrainSpec.from_flat(
{k: v for k, v in overrides.items() if v is not None} {k: v for k, v in overrides.items() if v is not None}
).to_flat_dict() ).to_flat_dict()
cfg["linear_warmup_steps"] = int(args.linear_warmup_steps) cfg["linear_warmup_steps"] = int(args.linear_warmup_steps)
mode_label = "no_robust" if no_robust else "robust" mode_label = _mode_label_from_baseline(bool(baseline_mode))
_log(f"mode={mode_label}: begin") _log(f"mode={mode_label}: begin")
df_mode, traces_mode, wandb_step_cursor = run_benchmark( df_mode, traces_mode, wandb_step_cursor = run_benchmark(
cfg, cfg,
@@ -437,6 +476,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
args.episodes, args.episodes,
mode_label=mode_label, mode_label=mode_label,
step_cursor_start=wandb_step_cursor, step_cursor_start=wandb_step_cursor,
eval_alpha_values=eval_alpha_values,
) )
_log(f"mode={mode_label}: complete ({len(df_mode)} rows)") _log(f"mode={mode_label}: complete ({len(df_mode)} rows)")
for trace in traces_mode: for trace in traces_mode:
@@ -465,7 +505,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
+ json.dumps( + json.dumps(
{ {
"tier": best["tier"], "tier": best["tier"],
"mode": best.get("mode", "robust"), "mode": best.get("mode", "defended"),
"alpha": float(best["alpha"]), "alpha": float(best["alpha"]),
"objective_score": float(best["objective_score"]), "objective_score": float(best["objective_score"]),
"mean_revenue": float(best["mean_revenue"]), "mean_revenue": float(best["mean_revenue"]),
@@ -486,6 +526,7 @@ def run_cli(raw_args: list[str] | None = None):
parser.add_argument("--project", default="capstone") parser.add_argument("--project", default="capstone")
parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo") parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo")
parser.add_argument("--alpha-values", default="0.0,0.3,0.6") parser.add_argument("--alpha-values", default="0.0,0.3,0.6")
parser.add_argument("--eval-alpha-values", default="")
parser.add_argument("--episodes", type=int, default=10) parser.add_argument("--episodes", type=int, default=10)
parser.add_argument("--output-dir", default="engine/studies/results") parser.add_argument("--output-dir", default="engine/studies/results")
parser.add_argument("--seed", type=int, default=42) parser.add_argument("--seed", type=int, default=42)
@@ -496,6 +537,7 @@ def run_cli(raw_args: list[str] | None = None):
parser.add_argument("--robust-radius", type=float, default=0.15) parser.add_argument("--robust-radius", type=float, default=0.15)
parser.add_argument("--robust-points", type=int, default=5) parser.add_argument("--robust-points", type=int, default=5)
parser.add_argument("--robust-rollouts", type=int, default=1) parser.add_argument("--robust-rollouts", type=int, default=1)
parser.add_argument("--margin-floor", type=float, default=0.85)
parser.add_argument("--eta-ux", type=float, default=0.5) parser.add_argument("--eta-ux", type=float, default=0.5)
parser.add_argument("--reward-profit-weight", type=float, default=1.0) parser.add_argument("--reward-profit-weight", type=float, default=1.0)
parser.add_argument("--price-low", type=float, default=10.0) parser.add_argument("--price-low", type=float, default=10.0)
@@ -529,35 +571,47 @@ def run_cli(raw_args: list[str] | None = None):
key_to_attr = { key_to_attr = {
"tiers": "tiers", "tiers": "tiers",
"alpha_values": "alpha_values", "alpha_values": "alpha_values",
"eval_alpha_values": "eval_alpha_values",
"episodes": "episodes", "episodes": "episodes",
"total_timesteps": "total_timesteps", "total_timesteps": "total_timesteps",
"lambda_coi": "lambda_coi", "lambda_coi": "lambda_coi",
"robust_radius": "robust_radius", "robust_radius": "robust_radius",
"robust_points": "robust_points", "robust_points": "robust_points",
"robust_rollouts": "robust_rollouts", "robust_rollouts": "robust_rollouts",
"ambiguity_radius": "robust_radius",
"ambiguity_points": "robust_points",
"ambiguity_rollouts": "robust_rollouts",
"eta_ux": "eta_ux", "eta_ux": "eta_ux",
"reward_profit_weight": "reward_profit_weight", "reward_profit_weight": "reward_profit_weight",
"learning_rate": "learning_rate", "learning_rate": "learning_rate",
"batch_size": "batch_size", "batch_size": "batch_size",
"n_steps": "n_steps", "n_steps": "n_steps",
"baseline_mode": "no_robust",
"no_robust": "no_robust", "no_robust": "no_robust",
"margin_floor": "margin_floor",
"device": "device", "device": "device",
} }
for key in ( for key in (
"tiers", "tiers",
"alpha_values", "alpha_values",
"eval_alpha_values",
"episodes", "episodes",
"total_timesteps", "total_timesteps",
"lambda_coi", "lambda_coi",
"robust_radius", "robust_radius",
"robust_points", "robust_points",
"robust_rollouts", "robust_rollouts",
"ambiguity_radius",
"ambiguity_points",
"ambiguity_rollouts",
"eta_ux", "eta_ux",
"reward_profit_weight", "reward_profit_weight",
"learning_rate", "learning_rate",
"batch_size", "batch_size",
"n_steps", "n_steps",
"baseline_mode",
"no_robust", "no_robust",
"margin_floor",
"device", "device",
): ):
if key in wandb.config: if key in wandb.config:
@@ -582,16 +636,16 @@ def run_cli(raw_args: list[str] | None = None):
alpha_values = _parse_float_list(args.alpha_values) alpha_values = _parse_float_list(args.alpha_values)
run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S") run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S")
compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
compare_tag = "robust-compare" if compare_enabled else "single-mode" compare_tag = "defended-compare" if compare_enabled else "single-mode"
modes = ( modes = (
[("no_robust", True), ("robust", False)] [("baseline", True), ("defended", False)]
if compare_enabled if compare_enabled
else [("no_robust" if bool(args.no_robust) else "robust", bool(args.no_robust))] else [(_mode_label_from_baseline(bool(args.no_robust)), bool(args.no_robust))]
) )
run_idx = 0 run_idx = 0
for tier in tiers: for tier in tiers:
for mode_label, no_robust in modes: for mode_label, baseline_mode in modes:
for alpha in alpha_values: for alpha in alpha_values:
run_idx += 1 run_idx += 1
alpha_token = ( alpha_token = (
@@ -600,7 +654,7 @@ def run_cli(raw_args: list[str] | None = None):
tier_args = argparse.Namespace(**vars(args)) tier_args = argparse.Namespace(**vars(args))
tier_args.tiers = tier tier_args.tiers = tier
tier_args.alpha_values = str(float(alpha)) tier_args.alpha_values = str(float(alpha))
tier_args.no_robust = bool(no_robust) tier_args.no_robust = bool(baseline_mode)
run = wandb.init( run = wandb.init(
project=args.project, project=args.project,
name=( name=(
@@ -617,16 +671,19 @@ def run_cli(raw_args: list[str] | None = None):
"run.kind": "benchmark", "run.kind": "benchmark",
"runtime/backend": tier, "runtime/backend": tier,
"study/mode": mode_label, "study/mode": mode_label,
"study/no_robust": float(no_robust), "study/baseline_mode": float(baseline_mode),
"study/alpha": float(alpha), "study/alpha": float(alpha),
"tiers": tier, "tiers": tier,
"alpha_values": str(float(alpha)), "alpha_values": str(float(alpha)),
"eval_alpha_values": args.eval_alpha_values,
"episodes": args.episodes, "episodes": args.episodes,
"total_timesteps": args.total_timesteps, "total_timesteps": args.total_timesteps,
"lambda_coi": args.lambda_coi, "lambda_coi": args.lambda_coi,
"robust_radius": args.robust_radius, "ambiguity_radius": args.robust_radius,
"robust_points": args.robust_points, "ambiguity_points": args.robust_points,
"robust_rollouts": args.robust_rollouts, "ambiguity_rollouts": args.robust_rollouts,
"margin_floor": args.margin_floor,
"baseline_mode": float(baseline_mode),
"eta_ux": args.eta_ux, "eta_ux": args.eta_ux,
"reward_profit_weight": args.reward_profit_weight, "reward_profit_weight": args.reward_profit_weight,
"learning_rate": args.learning_rate, "learning_rate": args.learning_rate,

View File

@@ -15,15 +15,19 @@ class MetricsCallback(BaseCallback):
self, self,
log_histograms: bool = False, log_histograms: bool = False,
log_freq: int = 100, log_freq: int = 100,
hist_freq: int = 500,
step_offset: int = 0, step_offset: int = 0,
verbose: int = 0, verbose: int = 0,
): ):
super().__init__(verbose) super().__init__(verbose)
self.log_histograms = log_histograms self.log_histograms = log_histograms
self.log_freq = max(1, int(log_freq)) self.log_freq = max(1, int(log_freq))
self.hist_freq = max(1, int(hist_freq))
self.step_offset = max(0, int(step_offset)) self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module() self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None) self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._price_samples: list[float] = []
self._demand_samples: list[float] = []
self._window_sums = { self._window_sums = {
"train/revenue_mean": 0.0, "train/revenue_mean": 0.0,
"train/margin_mean": 0.0, "train/margin_mean": 0.0,
@@ -74,35 +78,100 @@ class MetricsCallback(BaseCallback):
) )
self._window_count += 1 self._window_count += 1
def _flush(self, step: int) -> None: def _accumulate_histograms(self, info: dict[str, Any]) -> None:
if self._window_count <= 0: if not self.log_histograms:
return return
denom = float(self._window_count)
payload = { for key in ("effective_prices", "prices"):
key: (value / denom) if key not in info:
for key, value in self._window_sums.items() continue
if value != 0.0 try:
or key values = np.asarray(info.get(key), dtype=float).reshape(-1)
in { except Exception:
"train/revenue_mean", continue
"train/margin_mean", if values.size <= 0:
"train/coi_level_mean", continue
"train/regret_mean", finite_values = values[np.isfinite(values)]
if finite_values.size > 0:
self._price_samples.extend(finite_values.tolist())
break
if "demand" in info:
try:
demand_values = np.asarray(info.get("demand"), dtype=float).reshape(-1)
except Exception:
demand_values = np.array([], dtype=float)
if demand_values.size > 0:
finite_demand = demand_values[np.isfinite(demand_values)]
if finite_demand.size > 0:
self._demand_samples.extend(finite_demand.tolist())
def _flush_histograms(self, step: int, force: bool = False) -> None:
if not self.log_histograms:
return
if not force and step % self.hist_freq != 0:
return
if not self._price_samples and not self._demand_samples:
return
if self._wandb is None:
self._price_samples.clear()
self._demand_samples.clear()
return
payload: dict[str, Any] = {}
if self._price_samples:
payload["train/price_dist"] = self._wandb.Histogram(
np.asarray(self._price_samples, dtype=np.float32)
)
if self._demand_samples:
payload["train/demand_dist"] = self._wandb.Histogram(
np.asarray(self._demand_samples, dtype=np.float32)
)
if payload and self._wandb_live:
try:
self._wandb.log(payload, step=self.step_offset + int(step))
except Exception:
self._wandb_live = False
self._price_samples.clear()
self._demand_samples.clear()
def _flush(self, step: int, *, force_hist: bool = False) -> None:
if self._window_count > 0:
denom = float(self._window_count)
payload = {
key: (value / denom)
for key, value in self._window_sums.items()
if value != 0.0
or key
in {
"train/revenue_mean",
"train/margin_mean",
"train/coi_level_mean",
"train/regret_mean",
}
} }
} payload["train/global_step"] = int(step)
payload["train/global_step"] = int(step) if self._wandb_live:
if self._wandb_live: try:
self._wandb.log(dict(payload), step=self.step_offset + int(step)) self._wandb.log(dict(payload), step=self.step_offset + int(step))
else: except Exception:
self.events.append(payload) self._wandb_live = False
for key in self._window_sums: self.events.append(payload)
self._window_sums[key] = 0.0 else:
self._window_count = 0 self.events.append(payload)
for key in self._window_sums:
self._window_sums[key] = 0.0
self._window_count = 0
self._flush_histograms(step=step, force=force_hist)
def _on_step(self) -> bool: def _on_step(self) -> bool:
for info in self.locals.get("infos", []): for info in self.locals.get("infos", []):
if isinstance(info, dict): if isinstance(info, dict):
self._accumulate(info) self._accumulate(info)
self._accumulate_histograms(info)
if self.num_timesteps % self.log_freq == 0: if self.num_timesteps % self.log_freq == 0:
self._flush(step=self.num_timesteps) self._flush(step=self.num_timesteps)
@@ -110,39 +179,81 @@ class MetricsCallback(BaseCallback):
return True return True
def _on_training_end(self) -> None: def _on_training_end(self) -> None:
self._flush(step=self.num_timesteps) self._flush(step=self.num_timesteps, force_hist=True)
class EvalMetricsCallback(EvalCallback): class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation collector detached from logging backends.""" """Deterministic evaluation collector detached from logging backends."""
def __init__( def __init__(
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs self,
eval_env,
eval_freq: int = 1000,
n_eval_episodes: int = 5,
step_offset: int = 0,
**kwargs,
): ):
super().__init__( super().__init__(
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
) )
self._eval_revenues: list[float] = [] self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._eval_stats: dict[str, list[float]] = {
"eval/revenue_mean": [],
"eval/margin_mean": [],
"eval/coi_level_mean": [],
"eval/coi_leakage_mean": [],
"eval/volatility_mean": [],
"eval/agent_prob_mean": [],
}
self.events: list[dict[str, float | int]] = [] self.events: list[dict[str, float | int]] = []
def _on_step(self) -> bool: def _on_step(self) -> bool:
result = super()._on_step() result = super()._on_step()
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
self.events.append( payload: dict[str, float | int] = {
{ "eval/reward_mean": float(self.last_mean_reward),
"eval/reward_mean": float(self.last_mean_reward), "train/global_step": int(self.num_timesteps),
"eval/revenue_mean": float(np.mean(self._eval_revenues)) }
if self._eval_revenues for key, values in self._eval_stats.items():
else 0.0, payload[key] = float(np.mean(values)) if values else 0.0
"train/global_step": int(self.num_timesteps),
} if self._wandb_live:
) try:
self._eval_revenues = [] self._wandb.log(
dict(payload),
step=self.step_offset + int(self.num_timesteps),
)
except Exception:
self._wandb_live = False
self.events.append(payload)
else:
self.events.append(payload)
for values in self._eval_stats.values():
values.clear()
return result return result
def _log_success_callback(self, locals_: dict, globals_: dict) -> None: def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
# called after each eval episode # called after each eval episode
info = locals_.get("info", {}) info = locals_.get("info", {})
if "economics" in info: econ = info.get("economics") if isinstance(info, dict) else None
self._eval_revenues.append(info["economics"]["revenue"]) if not isinstance(econ, dict):
return
self._eval_stats["eval/revenue_mean"].append(float(econ.get("revenue", 0.0)))
self._eval_stats["eval/margin_mean"].append(float(econ.get("margin", 0.0)))
self._eval_stats["eval/coi_level_mean"].append(
float(econ.get("coi_level", 0.0))
)
self._eval_stats["eval/coi_leakage_mean"].append(
float(econ.get("coi_leakage", 0.0))
)
self._eval_stats["eval/volatility_mean"].append(
float(econ.get("volatility", 0.0))
)
self._eval_stats["eval/agent_prob_mean"].append(
float(econ.get("agent_prob", 0.0))
)

View File

@@ -156,14 +156,17 @@ class ProviderBenchmark:
# log to wandb if available # log to wandb if available
if HAS_WANDB and wandb.run is not None: if HAS_WANDB and wandb.run is not None:
wandb.log( try:
{ wandb.log(
f"benchmark/{name}/revenue": result.mean_revenue, {
f"benchmark/{name}/coi_preserved": result.coi_preserved_pct, f"benchmark/{name}/revenue": result.mean_revenue,
f"benchmark/{name}/margin": result.margin_integrity, f"benchmark/{name}/coi_preserved": result.coi_preserved_pct,
"benchmark/alpha": alpha, f"benchmark/{name}/margin": result.margin_integrity,
} "benchmark/alpha": alpha,
) }
)
except Exception:
pass
return self.results return self.results

View File

@@ -9,6 +9,7 @@ from ..telemetry.wandb import (
get_wandb_module, get_wandb_module,
init_run, init_run,
run_agent, run_agent,
update_summary,
) )
from .train import run_with_active_sweep_run from .train import run_with_active_sweep_run
@@ -43,13 +44,23 @@ def run_sweep_agent(
spec = TrainSpec.from_flat(merged) spec = TrainSpec.from_flat(merged)
if run is not None: if run is not None:
run.name = run_name(spec, kind=kind, scenario=scenario) run.name = run_name(spec, kind=kind, scenario=scenario)
run_with_active_sweep_run( try:
spec, run_with_active_sweep_run(
kind=kind, spec,
scenario=scenario, kind=kind,
group=group, scenario=scenario,
extra_tags=extra_tags, group=group,
) extra_tags=extra_tags,
)
update_summary({"run/status": "finished"})
except Exception as exc:
update_summary(
{
"run/status": "crashed",
"run/error": str(exc),
}
)
raise
finally: finally:
finish_run() finish_run()

View File

@@ -20,7 +20,7 @@ def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list
kind, kind,
spec.algorithm.name, spec.algorithm.name,
spec.runtime.backend, spec.runtime.backend,
"vanilla" if spec.study.no_robust else "robust", "baseline" if spec.study.no_robust else "defended",
] ]
tags.extend([tag for tag in extra_tags if tag]) tags.extend([tag for tag in extra_tags if tag])
return tags return tags

View File

@@ -91,6 +91,44 @@
"command": "bash scripts/nx_research.sh docker-train-publish", "command": "bash scripts/nx_research.sh docker-train-publish",
"cwd": "." "cwd": "."
} }
},
"whoclicked-publish": {
"executor": "nx:run-commands",
"dependsOn": [
"install"
],
"options": {
"command": "bash scripts/nx_research.sh whoclicked-publish",
"cwd": "."
}
},
"tpu-ray-bootstrap": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-bootstrap",
"cwd": "."
}
},
"tpu-ray-deps": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-deps",
"cwd": "."
}
},
"tpu-ray-verify": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-verify",
"cwd": "."
}
},
"tpu-ray-teardown": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh tpu-ray-teardown",
"cwd": "."
}
} }
}, },
"tags": [ "tags": [

View File

@@ -32,10 +32,17 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
"study.robust_radius": "robust_radius", "study.robust_radius": "robust_radius",
"study.robust_points": "robust_points", "study.robust_points": "robust_points",
"study.robust_rollouts": "robust_rollouts", "study.robust_rollouts": "robust_rollouts",
"study.ambiguity_radius": "robust_radius",
"study.ambiguity_points": "robust_points",
"study.ambiguity_rollouts": "robust_rollouts",
"study.info_value": "info_value", "study.info_value": "info_value",
"study.eta_ux": "eta_ux", "study.eta_ux": "eta_ux",
"study.reward_profit_weight": "reward_profit_weight", "study.reward_profit_weight": "reward_profit_weight",
"study.revenue_weight": "revenue_weight", "ambiguity_radius": "robust_radius",
"ambiguity_points": "robust_points",
"ambiguity_rollouts": "robust_rollouts",
"baseline_mode": "no_robust",
"stress_eval_enabled": "robust_eval_enabled",
"optimizer.learning_rate": "learning_rate", "optimizer.learning_rate": "learning_rate",
"optimizer.gamma": "gamma", "optimizer.gamma": "gamma",
"optimizer.batch_size": "batch_size", "optimizer.batch_size": "batch_size",
@@ -45,6 +52,7 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
"runtime.seed": "seed", "runtime.seed": "seed",
"runtime.total_timesteps": "total_timesteps", "runtime.total_timesteps": "total_timesteps",
"runtime.checkpoint_interval": "checkpoint_interval", "runtime.checkpoint_interval": "checkpoint_interval",
"runtime.hist_freq": "hist_freq",
"eval.eval_freq": "eval_freq", "eval.eval_freq": "eval_freq",
"eval.eval_episodes": "eval_episodes", "eval.eval_episodes": "eval_episodes",
} }
@@ -86,7 +94,6 @@ class StudySpec:
info_value: float = 1.0 info_value: float = 1.0
eta_ux: float = 0.5 eta_ux: float = 0.5
reward_profit_weight: float = 1.0 reward_profit_weight: float = 1.0
revenue_weight: float = 0.01
no_robust: bool = False no_robust: bool = False
@@ -128,6 +135,7 @@ class RuntimeSpec:
checkpoint_interval: int = 200_000 checkpoint_interval: int = 200_000
model_dir: str = "engine/models" model_dir: str = "engine/models"
log_freq: int = 100 log_freq: int = 100
hist_freq: int = 500
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -159,6 +167,7 @@ class TrainSpec:
"backend": self.runtime.backend, "backend": self.runtime.backend,
"device": self.runtime.device, "device": self.runtime.device,
"checkpoint_interval": self.runtime.checkpoint_interval, "checkpoint_interval": self.runtime.checkpoint_interval,
"hist_freq": self.runtime.hist_freq,
"n_products": self.env.n_products, "n_products": self.env.n_products,
"N": self.env.n_sessions, "N": self.env.n_sessions,
"price_low": self.env.price_low, "price_low": self.env.price_low,
@@ -179,7 +188,6 @@ class TrainSpec:
"info_value": self.study.info_value, "info_value": self.study.info_value,
"eta_ux": self.study.eta_ux, "eta_ux": self.study.eta_ux,
"reward_profit_weight": self.study.reward_profit_weight, "reward_profit_weight": self.study.reward_profit_weight,
"revenue_weight": self.study.revenue_weight,
"no_robust": self.study.no_robust, "no_robust": self.study.no_robust,
"learning_rate": self.optimizer.learning_rate, "learning_rate": self.optimizer.learning_rate,
"gamma": self.optimizer.gamma, "gamma": self.optimizer.gamma,
@@ -262,7 +270,6 @@ class TrainSpec:
info_value=float(base["info_value"]), info_value=float(base["info_value"]),
eta_ux=float(base["eta_ux"]), eta_ux=float(base["eta_ux"]),
reward_profit_weight=float(base["reward_profit_weight"]), reward_profit_weight=float(base["reward_profit_weight"]),
revenue_weight=float(base["revenue_weight"]),
no_robust=no_robust, no_robust=no_robust,
), ),
optimizer=OptimizerSpec( optimizer=OptimizerSpec(
@@ -300,6 +307,7 @@ class TrainSpec:
checkpoint_interval=int(base["checkpoint_interval"]), checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]), model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]), log_freq=int(base["log_freq"]),
hist_freq=int(base["hist_freq"]),
), ),
eval=EvalSpec( eval=EvalSpec(
eval_freq=int(base["eval_freq"]), eval_freq=int(base["eval_freq"]),
@@ -310,9 +318,11 @@ class TrainSpec:
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str: def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
alpha_token = f"{float(spec.study.alpha):.2f}".rstrip("0").rstrip(".")
mode = "baseline" if bool(spec.study.no_robust) else "defended"
return ( return (
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/" f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}" f"{spec.runtime.device}/{scenario}/a{alpha_token}/{mode}/s{spec.runtime.seed}"
) )
@@ -324,6 +334,7 @@ def run_metadata(
group: str | None = None, group: str | None = None,
tags: Sequence[str] = (), tags: Sequence[str] = (),
) -> dict[str, Any]: ) -> dict[str, Any]:
mode = "baseline" if bool(spec.study.no_robust) else "defended"
metadata: dict[str, Any] = { metadata: dict[str, Any] = {
"run.kind": str(kind), "run.kind": str(kind),
"run.algo": spec.algorithm.name, "run.algo": spec.algorithm.name,
@@ -332,6 +343,10 @@ def run_metadata(
"run.scenario": str(scenario), "run.scenario": str(scenario),
"run.seed": spec.runtime.seed, "run.seed": spec.runtime.seed,
"run.tags": list(tags), "run.tags": list(tags),
"study/alpha": float(spec.study.alpha),
"study/mode": mode,
"study/baseline_mode": float(bool(spec.study.no_robust)),
"tiers": spec.algorithm.name,
} }
if group: if group:
metadata["run.group"] = group metadata["run.group"] = group

View File

@@ -36,7 +36,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
eval_reward = ( eval_reward = (
_as_float( _as_float(
metrics.get("eval/robust_reward_worst", metrics.get("eval/reward_mean")), metrics.get(
"eval/stress_reward_worst",
metrics.get(
"eval/robust_reward_worst", metrics.get("eval/reward_mean")
),
),
0.0, 0.0,
) )
or 0.0 or 0.0
@@ -51,9 +56,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
metrics["study/alpha"] = spec.study.alpha metrics["study/alpha"] = spec.study.alpha
metrics["study/mode"] = "baseline" if bool(spec.study.no_robust) else "defended"
metrics["study/baseline_mode"] = float(bool(spec.study.no_robust))
metrics["study/lambda_coi"] = spec.study.lambda_coi metrics["study/lambda_coi"] = spec.study.lambda_coi
metrics["study/robust_radius"] = spec.study.robust_radius metrics["study/ambiguity_radius"] = spec.study.robust_radius
metrics["study/info_value"] = spec.study.info_value metrics["study/info_value"] = spec.study.info_value
metrics["tiers"] = spec.algorithm.name
metrics["runtime/backend"] = spec.runtime.backend metrics["runtime/backend"] = spec.runtime.backend
metrics["runtime/device"] = spec.runtime.device metrics["runtime/device"] = spec.runtime.device

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os
import time
from typing import Any, Callable, Iterable, Mapping from typing import Any, Callable, Iterable, Mapping
@@ -19,6 +21,42 @@ def _require_wandb():
return wandb return wandb
def _warn(message: str) -> None:
print(f"PHANTOM_WANDB_WARNING: {message}")
def _sanitize_key(raw_key: str) -> str | None:
key = str(raw_key)
replacements = {
"no_robust": "baseline_mode",
"study/no_robust": "study/baseline_mode",
"study/robust_radius": "study/ambiguity_radius",
"robust_radius": "ambiguity_radius",
"robust_points": "ambiguity_points",
"robust_rollouts": "ambiguity_rollouts",
"robust_eval_enabled": "stress_eval_enabled",
"eval/robust_alpha_high": "eval/stress_alpha_high",
"eval/robust_alpha_low": "eval/stress_alpha_low",
"eval/robust_reward_worst": "eval/stress_reward_worst",
"eval/robust_revenue_worst": "eval/stress_revenue_worst",
"eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst",
}
key = replacements.get(key, key)
if "robust" in key.lower():
return None
return key
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
sanitized: dict[str, Any] = {}
for key, value in payload.items():
clean_key = _sanitize_key(str(key))
if clean_key is None:
continue
sanitized[clean_key] = value
return sanitized
def init_run( def init_run(
*, *,
mode: str, mode: str,
@@ -34,7 +72,11 @@ def init_run(
if group: if group:
kwargs["group"] = group kwargs["group"] = group
if sweep_mode: if sweep_mode:
run = wandb.init(**kwargs) try:
run = wandb.init(**kwargs)
except Exception as exc:
_warn(f"init failed in sweep mode ({exc})")
return None
if name and run is not None: if name and run is not None:
run.name = name run.name = name
return run return run
@@ -42,18 +84,25 @@ def init_run(
init_kwargs = dict(kwargs) init_kwargs = dict(kwargs)
init_kwargs["project"] = project init_kwargs["project"] = project
if config is not None: if config is not None:
init_kwargs["config"] = dict(config) init_kwargs["config"] = _sanitize_payload(dict(config))
if name: if name:
init_kwargs["name"] = name init_kwargs["name"] = name
if tags: if tags:
init_kwargs["tags"] = list(tags) init_kwargs["tags"] = list(tags)
return wandb.init(**init_kwargs) try:
return wandb.init(**init_kwargs)
except Exception as exc:
_warn(f"init failed ({exc})")
return None
def finish_run() -> None: def finish_run() -> None:
wandb = get_wandb_module() wandb = get_wandb_module()
if wandb is not None and wandb.run is not None: if wandb is not None and wandb.run is not None:
wandb.finish() try:
wandb.finish()
except Exception as exc:
_warn(f"finish failed ({exc})")
def current_config() -> dict[str, Any]: def current_config() -> dict[str, Any]:
@@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None:
wandb = get_wandb_module() wandb = get_wandb_module()
if wandb is None or wandb.run is None: if wandb is None or wandb.run is None:
return return
payload = _sanitize_payload(dict(config))
if not payload:
return
try: try:
wandb.config.update(dict(config), allow_val_change=True) wandb.config.update(payload, allow_val_change=True)
except TypeError: except TypeError:
wandb.config.update(dict(config)) try:
wandb.config.update(payload)
except Exception as exc:
_warn(f"config update failed ({exc})")
except Exception as exc:
_warn(f"config update failed ({exc})")
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None: def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
wandb = get_wandb_module() wandb = get_wandb_module()
if wandb is None or wandb.run is None: if wandb is None or wandb.run is None:
return return
wandb.log(dict(metrics), step=step) payload = _sanitize_payload(dict(metrics))
if not payload:
return
try:
wandb.log(payload, step=step)
except Exception as exc:
_warn(f"log failed at step {step} ({exc})")
def update_summary(metrics: Mapping[str, Any]) -> None: def update_summary(metrics: Mapping[str, Any]) -> None:
wandb = get_wandb_module() wandb = get_wandb_module()
if wandb is None or wandb.run is None: if wandb is None or wandb.run is None:
return return
for key, value in metrics.items(): payload = _sanitize_payload(dict(metrics))
wandb.run.summary[key] = value if not payload:
return
try:
for key, value in payload.items():
wandb.run.summary[key] = value
except Exception as exc:
_warn(f"summary update failed ({exc})")
def run_agent( def run_agent(
@@ -95,4 +164,39 @@ def run_agent(
count: int | None = None, count: int | None = None,
) -> None: ) -> None:
wandb = _require_wandb() wandb = _require_wandb()
wandb.agent(sweep_id, function=fn, count=count) retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8")))
retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5")))
retry_backoff = max(
1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5"))
)
retry_max_delay = max(
retry_delay,
float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")),
)
target = None if count is None else max(0, int(count))
completed = 0
def _wrapped() -> None:
nonlocal completed
fn()
completed += 1
attempt = 0
while True:
remaining = None if target is None else max(0, int(target - completed))
if target is not None and remaining == 0:
return
try:
wandb.agent(sweep_id, function=_wrapped, count=remaining)
return
except Exception as exc:
attempt += 1
if attempt > retry_max:
raise
wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1)))
_warn(
f"agent disconnected (attempt {attempt}/{retry_max}, "
f"completed={completed}, remaining={remaining}): {exc}"
)
time.sleep(wait)

View File

@@ -54,6 +54,7 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--total-timesteps", type=int) parser.add_argument("--total-timesteps", type=int)
parser.add_argument("--model-dir", type=str) parser.add_argument("--model-dir", type=str)
parser.add_argument("--log-freq", type=int) parser.add_argument("--log-freq", type=int)
parser.add_argument("--hist-freq", type=int)
parser.add_argument("--checkpoint-interval", type=int) parser.add_argument("--checkpoint-interval", type=int)
parser.add_argument("--device", type=str) parser.add_argument("--device", type=str)
@@ -68,7 +69,6 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--no-robust", action="store_true") parser.add_argument("--no-robust", action="store_true")
parser.add_argument("--eta-ux", type=float) parser.add_argument("--eta-ux", type=float)
parser.add_argument("--reward-profit-weight", type=float) parser.add_argument("--reward-profit-weight", type=float)
parser.add_argument("--revenue-weight", type=float)
parser.add_argument("--price-low", type=float) parser.add_argument("--price-low", type=float)
parser.add_argument("--price-high", type=float) parser.add_argument("--price-high", type=float)
@@ -126,6 +126,7 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
"total_timesteps": args.total_timesteps, "total_timesteps": args.total_timesteps,
"model_dir": args.model_dir, "model_dir": args.model_dir,
"log_freq": args.log_freq, "log_freq": args.log_freq,
"hist_freq": args.hist_freq,
"checkpoint_interval": args.checkpoint_interval, "checkpoint_interval": args.checkpoint_interval,
"device": args.device, "device": args.device,
"alpha": args.alpha, "alpha": args.alpha,
@@ -139,7 +140,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
"no_robust": args.no_robust, "no_robust": args.no_robust,
"eta_ux": args.eta_ux, "eta_ux": args.eta_ux,
"reward_profit_weight": args.reward_profit_weight, "reward_profit_weight": args.reward_profit_weight,
"revenue_weight": args.revenue_weight,
"price_low": args.price_low, "price_low": args.price_low,
"price_high": args.price_high, "price_high": args.price_high,
"action_levels": args.action_levels, "action_levels": args.action_levels,