first meaningful runs

This commit is contained in:
2026-03-08 21:37:13 +01:00
parent 4c658a93a7
commit 73a1dafc6e
5 changed files with 96 additions and 25 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import argparse
import json
import logging
import os
from datetime import datetime, UTC
from pathlib import Path
@@ -11,11 +12,17 @@ import numpy as np
import pandas as pd
from .lib.tiers import LinearElasticityPolicy, StaticPolicy, SurgePolicy
from .logging_utils import configure_logging
from .spec import TrainSpec
from .telemetry.wandb import get_wandb_module
wandb = get_wandb_module()
HAS_WANDB = wandb is not None
logger = logging.getLogger(__name__)
def _log(message: str) -> None:
logger.info(message)
def _parse_list(raw: str) -> list[str]:
@@ -78,8 +85,6 @@ def _run_eval_episode(env, policy) -> dict:
def _build_tier(name: str, cfg: dict, alpha: float):
from .backends.common import make_env
from .backends.qtable import train_qtable
from .backends.sb3 import train_sb3
tier = name.lower().strip()
run_cfg = dict(cfg)
@@ -111,10 +116,15 @@ def _build_tier(name: str, cfg: dict, alpha: float):
return policy
if tier == "qtable":
from .backends.qtable import train_qtable
run_cfg["console_progress"] = True
agent, _ = train_qtable(run_cfg)
return agent
if tier in {"ppo", "a2c", "dqn"}:
from .backends.sb3 import train_sb3
run_cfg["algo"] = tier
agent, _ = train_sb3(run_cfg)
return agent
@@ -129,9 +139,15 @@ def run_benchmark(
rows: list[dict] = []
traces: list[dict] = []
total_runs = max(1, len(alpha_values) * len(tiers))
run_index = 0
for alpha in alpha_values:
for tier_name in tiers:
run_index += 1
_log(
f"[{run_index}/{total_runs}] alpha={float(alpha):.2f} tier={tier_name}: training"
)
policy = _build_tier(tier_name, cfg, alpha)
env = make_env({**cfg, "alpha": float(alpha)})
eps = [_run_eval_episode(env, policy) for _ in range(int(n_episodes))]
@@ -152,6 +168,11 @@ def run_benchmark(
+ float(cfg.get("revenue_weight", 0.01)) * row["mean_revenue"]
)
rows.append(row)
_log(
f"[{run_index}/{total_runs}] alpha={float(alpha):.2f} tier={tier_name}: "
f"reward={row['mean_reward']:.3f} revenue={row['mean_revenue']:.3f} "
f"coi={row['mean_coi']:.4f} score={row['objective_score']:.3f}"
)
max_len = max((len(e["price_trace"]) for e in eps), default=0)
step_means = []
@@ -282,6 +303,18 @@ def _run_with_args(args):
}
tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values)
_log(
"starting run "
+ json.dumps(
{
"tiers": tiers,
"alpha_values": alpha_values,
"episodes": int(args.episodes),
"total_timesteps": int(args.total_timesteps),
"device": str(args.device),
}
)
)
all_frames: list[pd.DataFrame] = []
all_traces: list[dict] = []
@@ -292,8 +325,10 @@ def _run_with_args(args):
{k: v for k, v in overrides.items() if v is not None}
).to_flat_dict()
cfg["linear_warmup_steps"] = int(args.linear_warmup_steps)
df_mode, traces_mode = run_benchmark(cfg, tiers, alpha_values, args.episodes)
mode_label = "no_robust" if no_robust else "robust"
_log(f"mode={mode_label}: begin")
df_mode, traces_mode = run_benchmark(cfg, tiers, alpha_values, args.episodes)
_log(f"mode={mode_label}: complete ({len(df_mode)} rows)")
df_mode["mode"] = mode_label
for trace in traces_mode:
trace["mode"] = mode_label
@@ -311,11 +346,12 @@ def _run_with_args(args):
df.to_csv(csv_path, index=False)
trace_path.write_text(json.dumps(traces, indent=2))
rev_path, coi_path, price_path = _plot_outputs(df, traces, out_dir, stamp)
_log(f"artifacts written in {out_dir}")
if not df.empty:
best_idx = int(df["mean_revenue"].idxmax())
best = df.iloc[best_idx]
print(
_log(
"BEST_TIER="
+ json.dumps(
{
@@ -327,14 +363,15 @@ def _run_with_args(args):
}
)
)
print(f"BENCHMARK_CSV={csv_path}")
print(f"BENCHMARK_TRACES={trace_path}")
print(f"BENCHMARK_PLOT_REVENUE={rev_path}")
print(f"BENCHMARK_PLOT_COI={coi_path}")
print(f"BENCHMARK_PLOT_PRICE={price_path}")
_log(f"BENCHMARK_CSV={csv_path}")
_log(f"BENCHMARK_TRACES={trace_path}")
_log(f"BENCHMARK_PLOT_REVENUE={rev_path}")
_log(f"BENCHMARK_PLOT_COI={coi_path}")
_log(f"BENCHMARK_PLOT_PRICE={price_path}")
def run_cli(raw_args: list[str] | None = None):
configure_logging()
parser = argparse.ArgumentParser(description="PHANTOM benchmark orchestrator")
parser.add_argument("--project", default="capstone")
parser.add_argument("--tiers", default="static,surge,linear,qtable,ppo")