mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating engine training for training
This commit is contained in:
@@ -132,15 +132,15 @@ def evaluate(
|
||||
shifted_env.close()
|
||||
shifted_rows.append((tag, alpha, shifted_metrics))
|
||||
|
||||
metrics["eval/robust_alpha_low"] = low_alpha
|
||||
metrics["eval/robust_alpha_high"] = high_alpha
|
||||
metrics["eval/robust_reward_worst"] = float(
|
||||
metrics["eval/stress_alpha_low"] = low_alpha
|
||||
metrics["eval/stress_alpha_high"] = high_alpha
|
||||
metrics["eval/stress_reward_worst"] = float(
|
||||
min(row[2]["eval/reward_mean"] for row in shifted_rows)
|
||||
)
|
||||
metrics["eval/robust_revenue_worst"] = float(
|
||||
metrics["eval/stress_revenue_worst"] = float(
|
||||
min(row[2]["eval/revenue_mean"] for row in shifted_rows)
|
||||
)
|
||||
metrics["eval/robust_coi_leakage_worst"] = float(
|
||||
metrics["eval/stress_coi_leakage_worst"] = float(
|
||||
max(row[2]["eval/coi_leakage_mean"] for row in shifted_rows)
|
||||
)
|
||||
for tag, alpha, shifted_metrics in shifted_rows:
|
||||
|
||||
@@ -80,7 +80,11 @@ def train_qtable(
|
||||
"train/global_step": int(steps),
|
||||
}
|
||||
if wandb_live:
|
||||
try:
|
||||
wandb.log(dict(event), step=step_offset + int(steps))
|
||||
except Exception:
|
||||
wandb_live = False
|
||||
train_events.append(event)
|
||||
else:
|
||||
train_events.append(event)
|
||||
if console_progress:
|
||||
@@ -113,7 +117,11 @@ def train_qtable(
|
||||
"train/global_step": int(steps),
|
||||
}
|
||||
if wandb_live:
|
||||
try:
|
||||
wandb.log(dict(tail_event), step=step_offset + int(steps))
|
||||
except Exception:
|
||||
wandb_live = False
|
||||
train_events.append(tail_event)
|
||||
else:
|
||||
train_events.append(tail_event)
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..lib.callbacks import MetricsCallback
|
||||
from ..lib.callbacks import EvalMetricsCallback, MetricsCallback
|
||||
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
|
||||
from .common import evaluate, make_env
|
||||
|
||||
|
||||
@@ -117,7 +119,6 @@ def build_model(cfg: Mapping[str, Any], env: Any):
|
||||
|
||||
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
try:
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
except ImportError as exc:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models") from exc
|
||||
@@ -144,20 +145,20 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
metrics_callback = MetricsCallback(
|
||||
log_histograms=False,
|
||||
log_histograms=True,
|
||||
log_freq=int(cfg["log_freq"]),
|
||||
hist_freq=int(cfg.get("hist_freq", 500)),
|
||||
step_offset=int(cfg.get("wandb_step_offset", 0)),
|
||||
)
|
||||
callbacks = [metrics_callback]
|
||||
callbacks.append(
|
||||
EvalCallback(
|
||||
eval_callback = EvalMetricsCallback(
|
||||
eval_env,
|
||||
eval_freq=int(cfg["eval_freq"]),
|
||||
n_eval_episodes=int(cfg["eval_episodes"]),
|
||||
step_offset=int(cfg.get("wandb_step_offset", 0)),
|
||||
deterministic=True,
|
||||
verbose=0,
|
||||
)
|
||||
)
|
||||
callbacks = [metrics_callback, eval_callback]
|
||||
|
||||
target_steps = int(cfg["total_timesteps"])
|
||||
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||
@@ -173,6 +174,29 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
model_path = model_dir / f"phantom_{cfg['algo']}"
|
||||
model.save(str(model_path))
|
||||
|
||||
artifact_name = checkpoint_artifact_name(
|
||||
cfg,
|
||||
backend="sb3",
|
||||
sweep_id=os.getenv("WANDB_SWEEP_ID"),
|
||||
)
|
||||
artifact_logged = False
|
||||
try:
|
||||
artifact_logged = bool(
|
||||
log_checkpoint_file(
|
||||
artifact_name,
|
||||
file_path=model_path.with_suffix(".zip"),
|
||||
artifact_file_name="model.zip",
|
||||
metadata={
|
||||
"algo": str(cfg.get("algo", "ppo")),
|
||||
"backend": "sb3",
|
||||
"seed": int(cfg.get("seed", 0)),
|
||||
"step": int(getattr(model, "num_timesteps", 0)),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
artifact_logged = False
|
||||
|
||||
metrics: dict[str, Any] = evaluate(
|
||||
model,
|
||||
eval_env,
|
||||
@@ -181,7 +205,12 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
)
|
||||
metrics["train/global_step"] = int(model.num_timesteps)
|
||||
metrics["model/path"] = str(model_path.with_suffix(".zip"))
|
||||
metrics["_train_events"] = list(metrics_callback.events)
|
||||
metrics["model/artifact_name"] = str(artifact_name)
|
||||
metrics["model/artifact_logged"] = float(artifact_logged)
|
||||
metrics["_train_events"] = sorted(
|
||||
[*metrics_callback.events, *eval_callback.events],
|
||||
key=lambda event: int(event.get("train/global_step", 0)),
|
||||
)
|
||||
|
||||
env.close()
|
||||
eval_env.close()
|
||||
|
||||
@@ -45,6 +45,10 @@ def _log(message: str) -> None:
|
||||
logger.info(message)
|
||||
|
||||
|
||||
def _wandb_run_active() -> bool:
|
||||
return bool(HAS_WANDB and getattr(wandb, "run", None) is not None)
|
||||
|
||||
|
||||
def _parse_list(raw: str) -> list[str]:
|
||||
return [x.strip().lower() for x in str(raw).split(",") if x.strip()]
|
||||
|
||||
@@ -61,6 +65,10 @@ def _truthy(value: str | bool | None) -> bool:
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _mode_label_from_baseline(is_baseline: bool) -> str:
|
||||
return "baseline" if bool(is_baseline) else "defended"
|
||||
|
||||
|
||||
def _action(policy, obs: np.ndarray):
|
||||
out = policy.predict(obs, deterministic=True)
|
||||
action = out[0] if isinstance(out, tuple) else out
|
||||
@@ -166,7 +174,7 @@ def _log_train_events(
|
||||
alpha: float,
|
||||
step_offset: int,
|
||||
) -> int:
|
||||
if not (HAS_WANDB and wandb.run is not None):
|
||||
if not _wandb_run_active():
|
||||
return int(step_offset)
|
||||
if not events:
|
||||
return int(step_offset)
|
||||
@@ -187,11 +195,14 @@ def _log_train_events(
|
||||
"run.kind": "benchmark",
|
||||
"runtime/backend": tier_name,
|
||||
"study/mode": mode_label,
|
||||
"study/no_robust": float(mode_label == "no_robust"),
|
||||
"study/baseline_mode": float(mode_label == "baseline"),
|
||||
"study/alpha": float(alpha),
|
||||
}
|
||||
)
|
||||
try:
|
||||
wandb.log(payload, step=cursor + rel_step)
|
||||
except Exception:
|
||||
return int(step_offset)
|
||||
max_rel = max(max(1, int(evt.get("train/global_step", 0))) for evt in ordered)
|
||||
return cursor + max_rel + 1
|
||||
|
||||
@@ -203,6 +214,7 @@ def run_benchmark(
|
||||
n_episodes: int,
|
||||
mode_label: str,
|
||||
step_cursor_start: int = 0,
|
||||
eval_alpha_values: list[float] | None = None,
|
||||
):
|
||||
from .backends.common import make_env
|
||||
|
||||
@@ -239,14 +251,22 @@ def run_benchmark(
|
||||
"dqn",
|
||||
}:
|
||||
wandb_step_cursor += max(1, int(cfg.get("total_timesteps", 1))) + 1
|
||||
env = make_env({**cfg, "alpha": float(alpha)})
|
||||
eval_targets = (
|
||||
[float(value) for value in eval_alpha_values]
|
||||
if eval_alpha_values
|
||||
else [float(alpha)]
|
||||
)
|
||||
for eval_alpha in eval_targets:
|
||||
env = make_env({**cfg, "alpha": float(eval_alpha)})
|
||||
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))]
|
||||
env.close()
|
||||
|
||||
row = {
|
||||
"tier": tier_name,
|
||||
"mode": mode_label,
|
||||
"alpha": float(alpha),
|
||||
"alpha": float(eval_alpha),
|
||||
"train_alpha": float(alpha),
|
||||
"eval_alpha": float(eval_alpha),
|
||||
"episodes": int(n_episodes),
|
||||
"mean_reward": float(np.mean([e["reward"] for e in eps])),
|
||||
"mean_revenue": float(np.mean([e["revenue"] for e in eps])),
|
||||
@@ -257,7 +277,8 @@ def run_benchmark(
|
||||
row["objective_score"] = row["mean_reward"]
|
||||
rows.append(row)
|
||||
_log(
|
||||
f"[{run_index}/{total_runs}] alpha={float(alpha):.2f} tier={tier_name}: "
|
||||
f"[{run_index}/{total_runs}] train_alpha={float(alpha):.2f} "
|
||||
f"eval_alpha={float(eval_alpha):.2f} tier={tier_name}: "
|
||||
f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} "
|
||||
f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}"
|
||||
)
|
||||
@@ -266,25 +287,32 @@ def run_benchmark(
|
||||
step_means = []
|
||||
for step in range(max_len):
|
||||
vals = [
|
||||
e["price_trace"][step] for e in eps if step < len(e["price_trace"])
|
||||
e["price_trace"][step]
|
||||
for e in eps
|
||||
if step < len(e["price_trace"])
|
||||
]
|
||||
step_means.append(float(np.mean(vals)) if vals else np.nan)
|
||||
traces.append(
|
||||
{
|
||||
"tier": tier_name,
|
||||
"alpha": float(alpha),
|
||||
"alpha": float(eval_alpha),
|
||||
"train_alpha": float(alpha),
|
||||
"eval_alpha": float(eval_alpha),
|
||||
"mean_price_trace": step_means,
|
||||
}
|
||||
)
|
||||
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
if _wandb_run_active():
|
||||
try:
|
||||
wandb.log(
|
||||
{
|
||||
"run.kind": "benchmark",
|
||||
"runtime/backend": tier_name,
|
||||
"study/mode": mode_label,
|
||||
"study/no_robust": float(mode_label == "no_robust"),
|
||||
"study/alpha": float(alpha),
|
||||
"study/baseline_mode": float(mode_label == "baseline"),
|
||||
"study/alpha": float(eval_alpha),
|
||||
"study/train_alpha": float(alpha),
|
||||
"study/eval_alpha": float(eval_alpha),
|
||||
"eval/reward_mean": row["mean_reward"],
|
||||
"eval/revenue_mean": row["mean_revenue"],
|
||||
"eval/margin_mean": row["mean_margin"],
|
||||
@@ -294,6 +322,8 @@ def run_benchmark(
|
||||
},
|
||||
step=wandb_step_cursor,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
wandb_step_cursor += 1
|
||||
|
||||
return pd.DataFrame(rows), traces, int(wandb_step_cursor)
|
||||
@@ -378,7 +408,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
|
||||
if compare_robust_override is not None
|
||||
else _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
|
||||
)
|
||||
robust_modes = [False, True] if compare_robust else [bool(args.no_robust)]
|
||||
baseline_modes = [False, True] if compare_robust else [bool(args.no_robust)]
|
||||
|
||||
base_overrides = {
|
||||
"seed": args.seed,
|
||||
@@ -389,6 +419,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
|
||||
"robust_radius": args.robust_radius,
|
||||
"robust_points": args.robust_points,
|
||||
"robust_rollouts": args.robust_rollouts,
|
||||
"margin_floor": args.margin_floor,
|
||||
"eta_ux": args.eta_ux,
|
||||
"reward_profit_weight": args.reward_profit_weight,
|
||||
"price_low": args.price_low,
|
||||
@@ -405,12 +436,20 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
|
||||
}
|
||||
tiers = _parse_list(args.tiers)
|
||||
alpha_values = _parse_float_list(args.alpha_values)
|
||||
eval_alpha_values = (
|
||||
_parse_float_list(args.eval_alpha_values)
|
||||
if str(getattr(args, "eval_alpha_values", "")).strip()
|
||||
else []
|
||||
)
|
||||
_log(
|
||||
"starting run "
|
||||
+ json.dumps(
|
||||
{
|
||||
"tiers": tiers,
|
||||
"alpha_values": alpha_values,
|
||||
"eval_alpha_values": (
|
||||
eval_alpha_values if eval_alpha_values else alpha_values
|
||||
),
|
||||
"episodes": int(args.episodes),
|
||||
"total_timesteps": int(args.total_timesteps),
|
||||
"device": str(args.device),
|
||||
@@ -421,14 +460,14 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
|
||||
all_frames: list[pd.DataFrame] = []
|
||||
all_traces: list[dict] = []
|
||||
wandb_step_cursor = 0
|
||||
for no_robust in robust_modes:
|
||||
for baseline_mode in baseline_modes:
|
||||
overrides = dict(base_overrides)
|
||||
overrides["no_robust"] = bool(no_robust)
|
||||
overrides["baseline_mode"] = bool(baseline_mode)
|
||||
cfg = TrainSpec.from_flat(
|
||||
{k: v for k, v in overrides.items() if v is not None}
|
||||
).to_flat_dict()
|
||||
cfg["linear_warmup_steps"] = int(args.linear_warmup_steps)
|
||||
mode_label = "no_robust" if no_robust else "robust"
|
||||
mode_label = _mode_label_from_baseline(bool(baseline_mode))
|
||||
_log(f"mode={mode_label}: begin")
|
||||
df_mode, traces_mode, wandb_step_cursor = run_benchmark(
|
||||
cfg,
|
||||
@@ -437,6 +476,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
|
||||
args.episodes,
|
||||
mode_label=mode_label,
|
||||
step_cursor_start=wandb_step_cursor,
|
||||
eval_alpha_values=eval_alpha_values,
|
||||
)
|
||||
_log(f"mode={mode_label}: complete ({len(df_mode)} rows)")
|
||||
for trace in traces_mode:
|
||||
@@ -465,7 +505,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
|
||||
+ json.dumps(
|
||||
{
|
||||
"tier": best["tier"],
|
||||
"mode": best.get("mode", "robust"),
|
||||
"mode": best.get("mode", "defended"),
|
||||
"alpha": float(best["alpha"]),
|
||||
"objective_score": float(best["objective_score"]),
|
||||
"mean_revenue": float(best["mean_revenue"]),
|
||||
@@ -486,6 +526,7 @@ def run_cli(raw_args: list[str] | None = None):
|
||||
parser.add_argument("--project", default="capstone")
|
||||
parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo")
|
||||
parser.add_argument("--alpha-values", default="0.0,0.3,0.6")
|
||||
parser.add_argument("--eval-alpha-values", default="")
|
||||
parser.add_argument("--episodes", type=int, default=10)
|
||||
parser.add_argument("--output-dir", default="engine/studies/results")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
@@ -496,6 +537,7 @@ def run_cli(raw_args: list[str] | None = None):
|
||||
parser.add_argument("--robust-radius", type=float, default=0.15)
|
||||
parser.add_argument("--robust-points", type=int, default=5)
|
||||
parser.add_argument("--robust-rollouts", type=int, default=1)
|
||||
parser.add_argument("--margin-floor", type=float, default=0.85)
|
||||
parser.add_argument("--eta-ux", type=float, default=0.5)
|
||||
parser.add_argument("--reward-profit-weight", type=float, default=1.0)
|
||||
parser.add_argument("--price-low", type=float, default=10.0)
|
||||
@@ -529,35 +571,47 @@ def run_cli(raw_args: list[str] | None = None):
|
||||
key_to_attr = {
|
||||
"tiers": "tiers",
|
||||
"alpha_values": "alpha_values",
|
||||
"eval_alpha_values": "eval_alpha_values",
|
||||
"episodes": "episodes",
|
||||
"total_timesteps": "total_timesteps",
|
||||
"lambda_coi": "lambda_coi",
|
||||
"robust_radius": "robust_radius",
|
||||
"robust_points": "robust_points",
|
||||
"robust_rollouts": "robust_rollouts",
|
||||
"ambiguity_radius": "robust_radius",
|
||||
"ambiguity_points": "robust_points",
|
||||
"ambiguity_rollouts": "robust_rollouts",
|
||||
"eta_ux": "eta_ux",
|
||||
"reward_profit_weight": "reward_profit_weight",
|
||||
"learning_rate": "learning_rate",
|
||||
"batch_size": "batch_size",
|
||||
"n_steps": "n_steps",
|
||||
"baseline_mode": "no_robust",
|
||||
"no_robust": "no_robust",
|
||||
"margin_floor": "margin_floor",
|
||||
"device": "device",
|
||||
}
|
||||
for key in (
|
||||
"tiers",
|
||||
"alpha_values",
|
||||
"eval_alpha_values",
|
||||
"episodes",
|
||||
"total_timesteps",
|
||||
"lambda_coi",
|
||||
"robust_radius",
|
||||
"robust_points",
|
||||
"robust_rollouts",
|
||||
"ambiguity_radius",
|
||||
"ambiguity_points",
|
||||
"ambiguity_rollouts",
|
||||
"eta_ux",
|
||||
"reward_profit_weight",
|
||||
"learning_rate",
|
||||
"batch_size",
|
||||
"n_steps",
|
||||
"baseline_mode",
|
||||
"no_robust",
|
||||
"margin_floor",
|
||||
"device",
|
||||
):
|
||||
if key in wandb.config:
|
||||
@@ -582,16 +636,16 @@ def run_cli(raw_args: list[str] | None = None):
|
||||
alpha_values = _parse_float_list(args.alpha_values)
|
||||
run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S")
|
||||
compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
|
||||
compare_tag = "robust-compare" if compare_enabled else "single-mode"
|
||||
compare_tag = "defended-compare" if compare_enabled else "single-mode"
|
||||
modes = (
|
||||
[("no_robust", True), ("robust", False)]
|
||||
[("baseline", True), ("defended", False)]
|
||||
if compare_enabled
|
||||
else [("no_robust" if bool(args.no_robust) else "robust", bool(args.no_robust))]
|
||||
else [(_mode_label_from_baseline(bool(args.no_robust)), bool(args.no_robust))]
|
||||
)
|
||||
|
||||
run_idx = 0
|
||||
for tier in tiers:
|
||||
for mode_label, no_robust in modes:
|
||||
for mode_label, baseline_mode in modes:
|
||||
for alpha in alpha_values:
|
||||
run_idx += 1
|
||||
alpha_token = (
|
||||
@@ -600,7 +654,7 @@ def run_cli(raw_args: list[str] | None = None):
|
||||
tier_args = argparse.Namespace(**vars(args))
|
||||
tier_args.tiers = tier
|
||||
tier_args.alpha_values = str(float(alpha))
|
||||
tier_args.no_robust = bool(no_robust)
|
||||
tier_args.no_robust = bool(baseline_mode)
|
||||
run = wandb.init(
|
||||
project=args.project,
|
||||
name=(
|
||||
@@ -617,16 +671,19 @@ def run_cli(raw_args: list[str] | None = None):
|
||||
"run.kind": "benchmark",
|
||||
"runtime/backend": tier,
|
||||
"study/mode": mode_label,
|
||||
"study/no_robust": float(no_robust),
|
||||
"study/baseline_mode": float(baseline_mode),
|
||||
"study/alpha": float(alpha),
|
||||
"tiers": tier,
|
||||
"alpha_values": str(float(alpha)),
|
||||
"eval_alpha_values": args.eval_alpha_values,
|
||||
"episodes": args.episodes,
|
||||
"total_timesteps": args.total_timesteps,
|
||||
"lambda_coi": args.lambda_coi,
|
||||
"robust_radius": args.robust_radius,
|
||||
"robust_points": args.robust_points,
|
||||
"robust_rollouts": args.robust_rollouts,
|
||||
"ambiguity_radius": args.robust_radius,
|
||||
"ambiguity_points": args.robust_points,
|
||||
"ambiguity_rollouts": args.robust_rollouts,
|
||||
"margin_floor": args.margin_floor,
|
||||
"baseline_mode": float(baseline_mode),
|
||||
"eta_ux": args.eta_ux,
|
||||
"reward_profit_weight": args.reward_profit_weight,
|
||||
"learning_rate": args.learning_rate,
|
||||
|
||||
@@ -15,15 +15,19 @@ class MetricsCallback(BaseCallback):
|
||||
self,
|
||||
log_histograms: bool = False,
|
||||
log_freq: int = 100,
|
||||
hist_freq: int = 500,
|
||||
step_offset: int = 0,
|
||||
verbose: int = 0,
|
||||
):
|
||||
super().__init__(verbose)
|
||||
self.log_histograms = log_histograms
|
||||
self.log_freq = max(1, int(log_freq))
|
||||
self.hist_freq = max(1, int(hist_freq))
|
||||
self.step_offset = max(0, int(step_offset))
|
||||
self._wandb = get_wandb_module()
|
||||
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
|
||||
self._price_samples: list[float] = []
|
||||
self._demand_samples: list[float] = []
|
||||
self._window_sums = {
|
||||
"train/revenue_mean": 0.0,
|
||||
"train/margin_mean": 0.0,
|
||||
@@ -74,9 +78,67 @@ class MetricsCallback(BaseCallback):
|
||||
)
|
||||
self._window_count += 1
|
||||
|
||||
def _flush(self, step: int) -> None:
|
||||
if self._window_count <= 0:
|
||||
def _accumulate_histograms(self, info: dict[str, Any]) -> None:
|
||||
if not self.log_histograms:
|
||||
return
|
||||
|
||||
for key in ("effective_prices", "prices"):
|
||||
if key not in info:
|
||||
continue
|
||||
try:
|
||||
values = np.asarray(info.get(key), dtype=float).reshape(-1)
|
||||
except Exception:
|
||||
continue
|
||||
if values.size <= 0:
|
||||
continue
|
||||
finite_values = values[np.isfinite(values)]
|
||||
if finite_values.size > 0:
|
||||
self._price_samples.extend(finite_values.tolist())
|
||||
break
|
||||
|
||||
if "demand" in info:
|
||||
try:
|
||||
demand_values = np.asarray(info.get("demand"), dtype=float).reshape(-1)
|
||||
except Exception:
|
||||
demand_values = np.array([], dtype=float)
|
||||
if demand_values.size > 0:
|
||||
finite_demand = demand_values[np.isfinite(demand_values)]
|
||||
if finite_demand.size > 0:
|
||||
self._demand_samples.extend(finite_demand.tolist())
|
||||
|
||||
def _flush_histograms(self, step: int, force: bool = False) -> None:
|
||||
if not self.log_histograms:
|
||||
return
|
||||
if not force and step % self.hist_freq != 0:
|
||||
return
|
||||
if not self._price_samples and not self._demand_samples:
|
||||
return
|
||||
if self._wandb is None:
|
||||
self._price_samples.clear()
|
||||
self._demand_samples.clear()
|
||||
return
|
||||
|
||||
payload: dict[str, Any] = {}
|
||||
if self._price_samples:
|
||||
payload["train/price_dist"] = self._wandb.Histogram(
|
||||
np.asarray(self._price_samples, dtype=np.float32)
|
||||
)
|
||||
if self._demand_samples:
|
||||
payload["train/demand_dist"] = self._wandb.Histogram(
|
||||
np.asarray(self._demand_samples, dtype=np.float32)
|
||||
)
|
||||
|
||||
if payload and self._wandb_live:
|
||||
try:
|
||||
self._wandb.log(payload, step=self.step_offset + int(step))
|
||||
except Exception:
|
||||
self._wandb_live = False
|
||||
|
||||
self._price_samples.clear()
|
||||
self._demand_samples.clear()
|
||||
|
||||
def _flush(self, step: int, *, force_hist: bool = False) -> None:
|
||||
if self._window_count > 0:
|
||||
denom = float(self._window_count)
|
||||
payload = {
|
||||
key: (value / denom)
|
||||
@@ -92,17 +154,24 @@ class MetricsCallback(BaseCallback):
|
||||
}
|
||||
payload["train/global_step"] = int(step)
|
||||
if self._wandb_live:
|
||||
try:
|
||||
self._wandb.log(dict(payload), step=self.step_offset + int(step))
|
||||
except Exception:
|
||||
self._wandb_live = False
|
||||
self.events.append(payload)
|
||||
else:
|
||||
self.events.append(payload)
|
||||
for key in self._window_sums:
|
||||
self._window_sums[key] = 0.0
|
||||
self._window_count = 0
|
||||
|
||||
self._flush_histograms(step=step, force=force_hist)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
for info in self.locals.get("infos", []):
|
||||
if isinstance(info, dict):
|
||||
self._accumulate(info)
|
||||
self._accumulate_histograms(info)
|
||||
|
||||
if self.num_timesteps % self.log_freq == 0:
|
||||
self._flush(step=self.num_timesteps)
|
||||
@@ -110,39 +179,81 @@ class MetricsCallback(BaseCallback):
|
||||
return True
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
self._flush(step=self.num_timesteps)
|
||||
self._flush(step=self.num_timesteps, force_hist=True)
|
||||
|
||||
|
||||
class EvalMetricsCallback(EvalCallback):
|
||||
"""Deterministic evaluation collector detached from logging backends."""
|
||||
|
||||
def __init__(
|
||||
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
|
||||
self,
|
||||
eval_env,
|
||||
eval_freq: int = 1000,
|
||||
n_eval_episodes: int = 5,
|
||||
step_offset: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
|
||||
)
|
||||
self._eval_revenues: list[float] = []
|
||||
self.step_offset = max(0, int(step_offset))
|
||||
self._wandb = get_wandb_module()
|
||||
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
|
||||
self._eval_stats: dict[str, list[float]] = {
|
||||
"eval/revenue_mean": [],
|
||||
"eval/margin_mean": [],
|
||||
"eval/coi_level_mean": [],
|
||||
"eval/coi_leakage_mean": [],
|
||||
"eval/volatility_mean": [],
|
||||
"eval/agent_prob_mean": [],
|
||||
}
|
||||
self.events: list[dict[str, float | int]] = []
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
result = super()._on_step()
|
||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||
self.events.append(
|
||||
{
|
||||
payload: dict[str, float | int] = {
|
||||
"eval/reward_mean": float(self.last_mean_reward),
|
||||
"eval/revenue_mean": float(np.mean(self._eval_revenues))
|
||||
if self._eval_revenues
|
||||
else 0.0,
|
||||
"train/global_step": int(self.num_timesteps),
|
||||
}
|
||||
for key, values in self._eval_stats.items():
|
||||
payload[key] = float(np.mean(values)) if values else 0.0
|
||||
|
||||
if self._wandb_live:
|
||||
try:
|
||||
self._wandb.log(
|
||||
dict(payload),
|
||||
step=self.step_offset + int(self.num_timesteps),
|
||||
)
|
||||
self._eval_revenues = []
|
||||
except Exception:
|
||||
self._wandb_live = False
|
||||
self.events.append(payload)
|
||||
else:
|
||||
self.events.append(payload)
|
||||
|
||||
for values in self._eval_stats.values():
|
||||
values.clear()
|
||||
|
||||
return result
|
||||
|
||||
def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
|
||||
# called after each eval episode
|
||||
info = locals_.get("info", {})
|
||||
if "economics" in info:
|
||||
self._eval_revenues.append(info["economics"]["revenue"])
|
||||
econ = info.get("economics") if isinstance(info, dict) else None
|
||||
if not isinstance(econ, dict):
|
||||
return
|
||||
|
||||
self._eval_stats["eval/revenue_mean"].append(float(econ.get("revenue", 0.0)))
|
||||
self._eval_stats["eval/margin_mean"].append(float(econ.get("margin", 0.0)))
|
||||
self._eval_stats["eval/coi_level_mean"].append(
|
||||
float(econ.get("coi_level", 0.0))
|
||||
)
|
||||
self._eval_stats["eval/coi_leakage_mean"].append(
|
||||
float(econ.get("coi_leakage", 0.0))
|
||||
)
|
||||
self._eval_stats["eval/volatility_mean"].append(
|
||||
float(econ.get("volatility", 0.0))
|
||||
)
|
||||
self._eval_stats["eval/agent_prob_mean"].append(
|
||||
float(econ.get("agent_prob", 0.0))
|
||||
)
|
||||
|
||||
@@ -156,6 +156,7 @@ class ProviderBenchmark:
|
||||
|
||||
# log to wandb if available
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
try:
|
||||
wandb.log(
|
||||
{
|
||||
f"benchmark/{name}/revenue": result.mean_revenue,
|
||||
@@ -164,6 +165,8 @@ class ProviderBenchmark:
|
||||
"benchmark/alpha": alpha,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return self.results
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from ..telemetry.wandb import (
|
||||
get_wandb_module,
|
||||
init_run,
|
||||
run_agent,
|
||||
update_summary,
|
||||
)
|
||||
from .train import run_with_active_sweep_run
|
||||
|
||||
@@ -43,6 +44,7 @@ def run_sweep_agent(
|
||||
spec = TrainSpec.from_flat(merged)
|
||||
if run is not None:
|
||||
run.name = run_name(spec, kind=kind, scenario=scenario)
|
||||
try:
|
||||
run_with_active_sweep_run(
|
||||
spec,
|
||||
kind=kind,
|
||||
@@ -50,6 +52,15 @@ def run_sweep_agent(
|
||||
group=group,
|
||||
extra_tags=extra_tags,
|
||||
)
|
||||
update_summary({"run/status": "finished"})
|
||||
except Exception as exc:
|
||||
update_summary(
|
||||
{
|
||||
"run/status": "crashed",
|
||||
"run/error": str(exc),
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
finish_run()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def _tags_for_run(spec: TrainSpec, kind: str, extra_tags: Sequence[str]) -> list
|
||||
kind,
|
||||
spec.algorithm.name,
|
||||
spec.runtime.backend,
|
||||
"vanilla" if spec.study.no_robust else "robust",
|
||||
"baseline" if spec.study.no_robust else "defended",
|
||||
]
|
||||
tags.extend([tag for tag in extra_tags if tag])
|
||||
return tags
|
||||
|
||||
@@ -91,6 +91,44 @@
|
||||
"command": "bash scripts/nx_research.sh docker-train-publish",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"whoclicked-publish": {
|
||||
"executor": "nx:run-commands",
|
||||
"dependsOn": [
|
||||
"install"
|
||||
],
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh whoclicked-publish",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"tpu-ray-bootstrap": {
|
||||
"executor": "nx:run-commands",
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh tpu-ray-bootstrap",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"tpu-ray-deps": {
|
||||
"executor": "nx:run-commands",
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh tpu-ray-deps",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"tpu-ray-verify": {
|
||||
"executor": "nx:run-commands",
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh tpu-ray-verify",
|
||||
"cwd": "."
|
||||
}
|
||||
},
|
||||
"tpu-ray-teardown": {
|
||||
"executor": "nx:run-commands",
|
||||
"options": {
|
||||
"command": "bash scripts/nx_research.sh tpu-ray-teardown",
|
||||
"cwd": "."
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
|
||||
@@ -32,10 +32,17 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"study.robust_radius": "robust_radius",
|
||||
"study.robust_points": "robust_points",
|
||||
"study.robust_rollouts": "robust_rollouts",
|
||||
"study.ambiguity_radius": "robust_radius",
|
||||
"study.ambiguity_points": "robust_points",
|
||||
"study.ambiguity_rollouts": "robust_rollouts",
|
||||
"study.info_value": "info_value",
|
||||
"study.eta_ux": "eta_ux",
|
||||
"study.reward_profit_weight": "reward_profit_weight",
|
||||
"study.revenue_weight": "revenue_weight",
|
||||
"ambiguity_radius": "robust_radius",
|
||||
"ambiguity_points": "robust_points",
|
||||
"ambiguity_rollouts": "robust_rollouts",
|
||||
"baseline_mode": "no_robust",
|
||||
"stress_eval_enabled": "robust_eval_enabled",
|
||||
"optimizer.learning_rate": "learning_rate",
|
||||
"optimizer.gamma": "gamma",
|
||||
"optimizer.batch_size": "batch_size",
|
||||
@@ -45,6 +52,7 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"runtime.seed": "seed",
|
||||
"runtime.total_timesteps": "total_timesteps",
|
||||
"runtime.checkpoint_interval": "checkpoint_interval",
|
||||
"runtime.hist_freq": "hist_freq",
|
||||
"eval.eval_freq": "eval_freq",
|
||||
"eval.eval_episodes": "eval_episodes",
|
||||
}
|
||||
@@ -86,7 +94,6 @@ class StudySpec:
|
||||
info_value: float = 1.0
|
||||
eta_ux: float = 0.5
|
||||
reward_profit_weight: float = 1.0
|
||||
revenue_weight: float = 0.01
|
||||
no_robust: bool = False
|
||||
|
||||
|
||||
@@ -128,6 +135,7 @@ class RuntimeSpec:
|
||||
checkpoint_interval: int = 200_000
|
||||
model_dir: str = "engine/models"
|
||||
log_freq: int = 100
|
||||
hist_freq: int = 500
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -159,6 +167,7 @@ class TrainSpec:
|
||||
"backend": self.runtime.backend,
|
||||
"device": self.runtime.device,
|
||||
"checkpoint_interval": self.runtime.checkpoint_interval,
|
||||
"hist_freq": self.runtime.hist_freq,
|
||||
"n_products": self.env.n_products,
|
||||
"N": self.env.n_sessions,
|
||||
"price_low": self.env.price_low,
|
||||
@@ -179,7 +188,6 @@ class TrainSpec:
|
||||
"info_value": self.study.info_value,
|
||||
"eta_ux": self.study.eta_ux,
|
||||
"reward_profit_weight": self.study.reward_profit_weight,
|
||||
"revenue_weight": self.study.revenue_weight,
|
||||
"no_robust": self.study.no_robust,
|
||||
"learning_rate": self.optimizer.learning_rate,
|
||||
"gamma": self.optimizer.gamma,
|
||||
@@ -262,7 +270,6 @@ class TrainSpec:
|
||||
info_value=float(base["info_value"]),
|
||||
eta_ux=float(base["eta_ux"]),
|
||||
reward_profit_weight=float(base["reward_profit_weight"]),
|
||||
revenue_weight=float(base["revenue_weight"]),
|
||||
no_robust=no_robust,
|
||||
),
|
||||
optimizer=OptimizerSpec(
|
||||
@@ -300,6 +307,7 @@ class TrainSpec:
|
||||
checkpoint_interval=int(base["checkpoint_interval"]),
|
||||
model_dir=str(base["model_dir"]),
|
||||
log_freq=int(base["log_freq"]),
|
||||
hist_freq=int(base["hist_freq"]),
|
||||
),
|
||||
eval=EvalSpec(
|
||||
eval_freq=int(base["eval_freq"]),
|
||||
@@ -310,9 +318,11 @@ class TrainSpec:
|
||||
|
||||
|
||||
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
|
||||
alpha_token = f"{float(spec.study.alpha):.2f}".rstrip("0").rstrip(".")
|
||||
mode = "baseline" if bool(spec.study.no_robust) else "defended"
|
||||
return (
|
||||
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
|
||||
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
|
||||
f"{spec.runtime.device}/{scenario}/a{alpha_token}/{mode}/s{spec.runtime.seed}"
|
||||
)
|
||||
|
||||
|
||||
@@ -324,6 +334,7 @@ def run_metadata(
|
||||
group: str | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
) -> dict[str, Any]:
|
||||
mode = "baseline" if bool(spec.study.no_robust) else "defended"
|
||||
metadata: dict[str, Any] = {
|
||||
"run.kind": str(kind),
|
||||
"run.algo": spec.algorithm.name,
|
||||
@@ -332,6 +343,10 @@ def run_metadata(
|
||||
"run.scenario": str(scenario),
|
||||
"run.seed": spec.runtime.seed,
|
||||
"run.tags": list(tags),
|
||||
"study/alpha": float(spec.study.alpha),
|
||||
"study/mode": mode,
|
||||
"study/baseline_mode": float(bool(spec.study.no_robust)),
|
||||
"tiers": spec.algorithm.name,
|
||||
}
|
||||
if group:
|
||||
metadata["run.group"] = group
|
||||
|
||||
@@ -36,7 +36,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
|
||||
|
||||
eval_reward = (
|
||||
_as_float(
|
||||
metrics.get("eval/robust_reward_worst", metrics.get("eval/reward_mean")),
|
||||
metrics.get(
|
||||
"eval/stress_reward_worst",
|
||||
metrics.get(
|
||||
"eval/robust_reward_worst", metrics.get("eval/reward_mean")
|
||||
),
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
or 0.0
|
||||
@@ -51,9 +56,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
|
||||
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
|
||||
|
||||
metrics["study/alpha"] = spec.study.alpha
|
||||
metrics["study/mode"] = "baseline" if bool(spec.study.no_robust) else "defended"
|
||||
metrics["study/baseline_mode"] = float(bool(spec.study.no_robust))
|
||||
metrics["study/lambda_coi"] = spec.study.lambda_coi
|
||||
metrics["study/robust_radius"] = spec.study.robust_radius
|
||||
metrics["study/ambiguity_radius"] = spec.study.robust_radius
|
||||
metrics["study/info_value"] = spec.study.info_value
|
||||
metrics["tiers"] = spec.algorithm.name
|
||||
|
||||
metrics["runtime/backend"] = spec.runtime.backend
|
||||
metrics["runtime/device"] = spec.runtime.device
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Iterable, Mapping
|
||||
|
||||
|
||||
@@ -19,6 +21,42 @@ def _require_wandb():
|
||||
return wandb
|
||||
|
||||
|
||||
def _warn(message: str) -> None:
|
||||
print(f"PHANTOM_WANDB_WARNING: {message}")
|
||||
|
||||
|
||||
def _sanitize_key(raw_key: str) -> str | None:
|
||||
key = str(raw_key)
|
||||
replacements = {
|
||||
"no_robust": "baseline_mode",
|
||||
"study/no_robust": "study/baseline_mode",
|
||||
"study/robust_radius": "study/ambiguity_radius",
|
||||
"robust_radius": "ambiguity_radius",
|
||||
"robust_points": "ambiguity_points",
|
||||
"robust_rollouts": "ambiguity_rollouts",
|
||||
"robust_eval_enabled": "stress_eval_enabled",
|
||||
"eval/robust_alpha_high": "eval/stress_alpha_high",
|
||||
"eval/robust_alpha_low": "eval/stress_alpha_low",
|
||||
"eval/robust_reward_worst": "eval/stress_reward_worst",
|
||||
"eval/robust_revenue_worst": "eval/stress_revenue_worst",
|
||||
"eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst",
|
||||
}
|
||||
key = replacements.get(key, key)
|
||||
if "robust" in key.lower():
|
||||
return None
|
||||
return key
|
||||
|
||||
|
||||
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, value in payload.items():
|
||||
clean_key = _sanitize_key(str(key))
|
||||
if clean_key is None:
|
||||
continue
|
||||
sanitized[clean_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def init_run(
|
||||
*,
|
||||
mode: str,
|
||||
@@ -34,7 +72,11 @@ def init_run(
|
||||
if group:
|
||||
kwargs["group"] = group
|
||||
if sweep_mode:
|
||||
try:
|
||||
run = wandb.init(**kwargs)
|
||||
except Exception as exc:
|
||||
_warn(f"init failed in sweep mode ({exc})")
|
||||
return None
|
||||
if name and run is not None:
|
||||
run.name = name
|
||||
return run
|
||||
@@ -42,18 +84,25 @@ def init_run(
|
||||
init_kwargs = dict(kwargs)
|
||||
init_kwargs["project"] = project
|
||||
if config is not None:
|
||||
init_kwargs["config"] = dict(config)
|
||||
init_kwargs["config"] = _sanitize_payload(dict(config))
|
||||
if name:
|
||||
init_kwargs["name"] = name
|
||||
if tags:
|
||||
init_kwargs["tags"] = list(tags)
|
||||
try:
|
||||
return wandb.init(**init_kwargs)
|
||||
except Exception as exc:
|
||||
_warn(f"init failed ({exc})")
|
||||
return None
|
||||
|
||||
|
||||
def finish_run() -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is not None and wandb.run is not None:
|
||||
try:
|
||||
wandb.finish()
|
||||
except Exception as exc:
|
||||
_warn(f"finish failed ({exc})")
|
||||
|
||||
|
||||
def current_config() -> dict[str, Any]:
|
||||
@@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
payload = _sanitize_payload(dict(config))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
wandb.config.update(dict(config), allow_val_change=True)
|
||||
wandb.config.update(payload, allow_val_change=True)
|
||||
except TypeError:
|
||||
wandb.config.update(dict(config))
|
||||
try:
|
||||
wandb.config.update(payload)
|
||||
except Exception as exc:
|
||||
_warn(f"config update failed ({exc})")
|
||||
except Exception as exc:
|
||||
_warn(f"config update failed ({exc})")
|
||||
|
||||
|
||||
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
wandb.log(dict(metrics), step=step)
|
||||
payload = _sanitize_payload(dict(metrics))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
wandb.log(payload, step=step)
|
||||
except Exception as exc:
|
||||
_warn(f"log failed at step {step} ({exc})")
|
||||
|
||||
|
||||
def update_summary(metrics: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
for key, value in metrics.items():
|
||||
payload = _sanitize_payload(dict(metrics))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
for key, value in payload.items():
|
||||
wandb.run.summary[key] = value
|
||||
except Exception as exc:
|
||||
_warn(f"summary update failed ({exc})")
|
||||
|
||||
|
||||
def run_agent(
|
||||
@@ -95,4 +164,39 @@ def run_agent(
|
||||
count: int | None = None,
|
||||
) -> None:
|
||||
wandb = _require_wandb()
|
||||
wandb.agent(sweep_id, function=fn, count=count)
|
||||
retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8")))
|
||||
retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5")))
|
||||
retry_backoff = max(
|
||||
1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5"))
|
||||
)
|
||||
retry_max_delay = max(
|
||||
retry_delay,
|
||||
float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")),
|
||||
)
|
||||
|
||||
target = None if count is None else max(0, int(count))
|
||||
completed = 0
|
||||
|
||||
def _wrapped() -> None:
|
||||
nonlocal completed
|
||||
fn()
|
||||
completed += 1
|
||||
|
||||
attempt = 0
|
||||
while True:
|
||||
remaining = None if target is None else max(0, int(target - completed))
|
||||
if target is not None and remaining == 0:
|
||||
return
|
||||
try:
|
||||
wandb.agent(sweep_id, function=_wrapped, count=remaining)
|
||||
return
|
||||
except Exception as exc:
|
||||
attempt += 1
|
||||
if attempt > retry_max:
|
||||
raise
|
||||
wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1)))
|
||||
_warn(
|
||||
f"agent disconnected (attempt {attempt}/{retry_max}, "
|
||||
f"completed={completed}, remaining={remaining}): {exc}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
|
||||
@@ -54,6 +54,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--total-timesteps", type=int)
|
||||
parser.add_argument("--model-dir", type=str)
|
||||
parser.add_argument("--log-freq", type=int)
|
||||
parser.add_argument("--hist-freq", type=int)
|
||||
parser.add_argument("--checkpoint-interval", type=int)
|
||||
parser.add_argument("--device", type=str)
|
||||
|
||||
@@ -68,7 +69,6 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--no-robust", action="store_true")
|
||||
parser.add_argument("--eta-ux", type=float)
|
||||
parser.add_argument("--reward-profit-weight", type=float)
|
||||
parser.add_argument("--revenue-weight", type=float)
|
||||
|
||||
parser.add_argument("--price-low", type=float)
|
||||
parser.add_argument("--price-high", type=float)
|
||||
@@ -126,6 +126,7 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
||||
"total_timesteps": args.total_timesteps,
|
||||
"model_dir": args.model_dir,
|
||||
"log_freq": args.log_freq,
|
||||
"hist_freq": args.hist_freq,
|
||||
"checkpoint_interval": args.checkpoint_interval,
|
||||
"device": args.device,
|
||||
"alpha": args.alpha,
|
||||
@@ -139,7 +140,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
||||
"no_robust": args.no_robust,
|
||||
"eta_ux": args.eta_ux,
|
||||
"reward_profit_weight": args.reward_profit_weight,
|
||||
"revenue_weight": args.revenue_weight,
|
||||
"price_low": args.price_low,
|
||||
"price_high": args.price_high,
|
||||
"action_levels": args.action_levels,
|
||||
|
||||
Reference in New Issue
Block a user