fix: wandb logging in benchmark (one run per alpha value)

This commit is contained in:
2026-03-10 14:54:44 +01:00
parent 1c2935dc87
commit 8404a88ef1

View File

@@ -559,6 +559,7 @@ def run_cli(raw_args: list[str] | None = None):
return return
tiers = _parse_list(args.tiers) tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values)
run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S") run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S")
compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
compare_tag = "robust-compare" if compare_enabled else "single-mode" compare_tag = "robust-compare" if compare_enabled else "single-mode"
@@ -571,44 +572,53 @@ def run_cli(raw_args: list[str] | None = None):
run_idx = 0 run_idx = 0
for tier in tiers: for tier in tiers:
for mode_label, no_robust in modes: for mode_label, no_robust in modes:
run_idx += 1 for alpha in alpha_values:
tier_args = argparse.Namespace(**vars(args)) run_idx += 1
tier_args.tiers = tier alpha_token = (
tier_args.no_robust = bool(no_robust) f"{float(alpha):.2f}".rstrip("0").rstrip(".").replace(".", "p")
run = wandb.init( )
project=args.project, tier_args = argparse.Namespace(**vars(args))
name=f"benchmark-{tier}-{mode_label}-{run_stamp}-{run_idx}", tier_args.tiers = tier
tags=[ tier_args.alpha_values = str(float(alpha))
"benchmark", tier_args.no_robust = bool(no_robust)
compare_tag, run = wandb.init(
f"backend:{tier}", project=args.project,
f"mode:{mode_label}", name=(
], f"benchmark-{tier}-{mode_label}-a{alpha_token}-{run_stamp}-{run_idx}"
config={ ),
"run.kind": "benchmark", tags=[
"runtime/backend": tier, "benchmark",
"study/mode": mode_label, compare_tag,
"study/no_robust": float(no_robust), f"backend:{tier}",
"tiers": tier, f"mode:{mode_label}",
"alpha_values": args.alpha_values, f"alpha:{alpha_token}",
"episodes": args.episodes, ],
"total_timesteps": args.total_timesteps, config={
"lambda_coi": args.lambda_coi, "run.kind": "benchmark",
"robust_radius": args.robust_radius, "runtime/backend": tier,
"robust_points": args.robust_points, "study/mode": mode_label,
"robust_rollouts": args.robust_rollouts, "study/no_robust": float(no_robust),
"eta_ux": args.eta_ux, "study/alpha": float(alpha),
"reward_profit_weight": args.reward_profit_weight, "tiers": tier,
"learning_rate": args.learning_rate, "alpha_values": str(float(alpha)),
"device": args.device, "episodes": args.episodes,
}, "total_timesteps": args.total_timesteps,
mode="offline" if args.offline else "online", "lambda_coi": args.lambda_coi,
) "robust_radius": args.robust_radius,
try: "robust_points": args.robust_points,
_run_with_args(tier_args, compare_robust_override=False) "robust_rollouts": args.robust_rollouts,
finally: "eta_ux": args.eta_ux,
if run is not None: "reward_profit_weight": args.reward_profit_weight,
wandb.finish() "learning_rate": args.learning_rate,
"device": args.device,
},
mode="offline" if args.offline else "online",
)
try:
_run_with_args(tier_args, compare_robust_override=False)
finally:
if run is not None:
wandb.finish()
if __name__ == "__main__": if __name__ == "__main__":