From 8404a88ef10c46b574b312345d95c37e7fb46f85 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Tue, 10 Mar 2026 14:54:44 +0100 Subject: [PATCH] fix: logging into benchmark of wandb --- engine/benchmark.py | 86 +++++++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/engine/benchmark.py b/engine/benchmark.py index 0e2da26..7e0afaf 100644 --- a/engine/benchmark.py +++ b/engine/benchmark.py @@ -559,6 +559,7 @@ def run_cli(raw_args: list[str] | None = None): return tiers = _parse_list(args.tiers) + alpha_values = _parse_float_list(args.alpha_values) run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S") compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) compare_tag = "robust-compare" if compare_enabled else "single-mode" @@ -571,44 +572,53 @@ def run_cli(raw_args: list[str] | None = None): run_idx = 0 for tier in tiers: for mode_label, no_robust in modes: - run_idx += 1 - tier_args = argparse.Namespace(**vars(args)) - tier_args.tiers = tier - tier_args.no_robust = bool(no_robust) - run = wandb.init( - project=args.project, - name=f"benchmark-{tier}-{mode_label}-{run_stamp}-{run_idx}", - tags=[ - "benchmark", - compare_tag, - f"backend:{tier}", - f"mode:{mode_label}", - ], - config={ - "run.kind": "benchmark", - "runtime/backend": tier, - "study/mode": mode_label, - "study/no_robust": float(no_robust), - "tiers": tier, - "alpha_values": args.alpha_values, - "episodes": args.episodes, - "total_timesteps": args.total_timesteps, - "lambda_coi": args.lambda_coi, - "robust_radius": args.robust_radius, - "robust_points": args.robust_points, - "robust_rollouts": args.robust_rollouts, - "eta_ux": args.eta_ux, - "reward_profit_weight": args.reward_profit_weight, - "learning_rate": args.learning_rate, - "device": args.device, - }, - mode="offline" if args.offline else "online", - ) - try: - _run_with_args(tier_args, compare_robust_override=False) - finally: - if run is not None: - wandb.finish() + for alpha in alpha_values: + run_idx += 1 + alpha_token = ( + f"{float(alpha):.2f}".rstrip("0").rstrip(".").replace(".", "p") + ) + tier_args = argparse.Namespace(**vars(args)) + tier_args.tiers = tier + tier_args.alpha_values = str(float(alpha)) + tier_args.no_robust = bool(no_robust) + run = wandb.init( + project=args.project, + name=( + f"benchmark-{tier}-{mode_label}-a{alpha_token}-{run_stamp}-{run_idx}" + ), + tags=[ + "benchmark", + compare_tag, + f"backend:{tier}", + f"mode:{mode_label}", + f"alpha:{alpha_token}", + ], + config={ + "run.kind": "benchmark", + "runtime/backend": tier, + "study/mode": mode_label, + "study/no_robust": float(no_robust), + "study/alpha": float(alpha), + "tiers": tier, + "alpha_values": str(float(alpha)), + "episodes": args.episodes, + "total_timesteps": args.total_timesteps, + "lambda_coi": args.lambda_coi, + "robust_radius": args.robust_radius, + "robust_points": args.robust_points, + "robust_rollouts": args.robust_rollouts, + "eta_ux": args.eta_ux, + "reward_profit_weight": args.reward_profit_weight, + "learning_rate": args.learning_rate, + "device": args.device, + }, + mode="offline" if args.offline else "online", + ) + try: + _run_with_args(tier_args, compare_robust_override=False) + finally: + if run is not None: + wandb.finish() if __name__ == "__main__":