fix: logging into benchmark of wandb

This commit is contained in:
2026-03-10 14:54:44 +01:00
parent 1c2935dc87
commit 8404a88ef1

View File

@@ -559,6 +559,7 @@ def run_cli(raw_args: list[str] | None = None):
return return
tiers = _parse_list(args.tiers) tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values)
run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S") run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S")
compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
compare_tag = "robust-compare" if compare_enabled else "single-mode" compare_tag = "robust-compare" if compare_enabled else "single-mode"
@@ -571,26 +572,35 @@ def run_cli(raw_args: list[str] | None = None):
run_idx = 0 run_idx = 0
for tier in tiers: for tier in tiers:
for mode_label, no_robust in modes: for mode_label, no_robust in modes:
for alpha in alpha_values:
run_idx += 1 run_idx += 1
alpha_token = (
f"{float(alpha):.2f}".rstrip("0").rstrip(".").replace(".", "p")
)
tier_args = argparse.Namespace(**vars(args)) tier_args = argparse.Namespace(**vars(args))
tier_args.tiers = tier tier_args.tiers = tier
tier_args.alpha_values = str(float(alpha))
tier_args.no_robust = bool(no_robust) tier_args.no_robust = bool(no_robust)
run = wandb.init( run = wandb.init(
project=args.project, project=args.project,
name=f"benchmark-{tier}-{mode_label}-{run_stamp}-{run_idx}", name=(
f"benchmark-{tier}-{mode_label}-a{alpha_token}-{run_stamp}-{run_idx}"
),
tags=[ tags=[
"benchmark", "benchmark",
compare_tag, compare_tag,
f"backend:{tier}", f"backend:{tier}",
f"mode:{mode_label}", f"mode:{mode_label}",
f"alpha:{alpha_token}",
], ],
config={ config={
"run.kind": "benchmark", "run.kind": "benchmark",
"runtime/backend": tier, "runtime/backend": tier,
"study/mode": mode_label, "study/mode": mode_label,
"study/no_robust": float(no_robust), "study/no_robust": float(no_robust),
"study/alpha": float(alpha),
"tiers": tier, "tiers": tier,
"alpha_values": args.alpha_values, "alpha_values": str(float(alpha)),
"episodes": args.episodes, "episodes": args.episodes,
"total_timesteps": args.total_timesteps, "total_timesteps": args.total_timesteps,
"lambda_coi": args.lambda_coi, "lambda_coi": args.lambda_coi,