From b1f583be39d0a4831cb99e3b5988e213f90aa325 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Thu, 12 Mar 2026 09:16:50 +0100 Subject: [PATCH] nightly benchmark run configureation --- engine/benchmark.py | 6 +-- scripts/ray_distributed_train.py | 78 ++++++++++++++++++++++++++++---- submit_ray_job.sh | 39 +++++++++++++++- 3 files changed, 109 insertions(+), 14 deletions(-) diff --git a/engine/benchmark.py b/engine/benchmark.py index 47fb780..fc0205f 100644 --- a/engine/benchmark.py +++ b/engine/benchmark.py @@ -7,7 +7,7 @@ import sys import argparse import json import logging -from datetime import datetime, UTC +from datetime import datetime, timezone from pathlib import Path # clear stale TPU locks on startup @@ -449,7 +449,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None): out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) - stamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") + stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") csv_path = out_dir / f"benchmark_{stamp}.csv" trace_path = out_dir / f"benchmark_traces_{stamp}.json" df.to_csv(csv_path, index=False) @@ -580,7 +580,7 @@ def run_cli(raw_args: list[str] | None = None): tiers = _parse_list(args.tiers) alpha_values = _parse_float_list(args.alpha_values) - run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S") + run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S") compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) compare_tag = "robust-compare" if compare_enabled else "single-mode" modes = ( diff --git a/scripts/ray_distributed_train.py b/scripts/ray_distributed_train.py index f918f33..7e2fc23 100644 --- a/scripts/ray_distributed_train.py +++ b/scripts/ray_distributed_train.py @@ -15,6 +15,15 @@ def _has_flag(tokens: list[str], name: str) -> bool: return any(tok == name or tok.startswith(f"{name}=") for tok in tokens) +def _entry_tokens(run_kind: str, entry_args: str) -> list[str]: + tokens = shlex.split(entry_args) + if run_kind == "benchmark" and not ( + _has_flag(tokens, "--run-kind") or _has_flag(tokens, "--run-mode") + ): + return ["--run-kind", "benchmark", *tokens] + return tokens + + def _alive_node_ips() -> list[str]: seen: set[str] = set() ips: list[str] = [] @@ -33,13 +42,18 @@ def _alive_node_ips() -> list[str]: def _train_on_node( *, root: str, - train_args: str, + run_kind: str, + entry_args: str, rank: int, world_size: int, coordinator_ip: str, coordinator_port: int, base_seed: int, run_group: str, + compare_robust: bool, + output_root: str, + wandb_entity: str, + wandb_project: str, sync_jax: bool, ) -> int: env = dict(os.environ) @@ -53,6 +67,12 @@ def _train_on_node( env["JAX_PLATFORMS"] = requested_platform # Keep each train process in single-host mode to avoid accidental global stalls. env["CLOUD_TPU_TASK_ID"] = "0" + if run_kind == "benchmark": + env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0" + if wandb_entity: + env["WANDB_ENTITY"] = wandb_entity + if wandb_project: + env["WANDB_PROJECT"] = wandb_project cwd = str(Path(root)) @@ -75,30 +95,62 @@ def _train_on_node( [sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True ) - tokens = shlex.split(train_args) + tokens = _entry_tokens(run_kind, entry_args) + seed = int(base_seed + rank) if not _has_flag(tokens, "--seed"): - tokens.extend(["--seed", str(base_seed + rank)]) - if not _has_flag(tokens, "--group"): + tokens.extend(["--seed", str(seed)]) + + if run_kind == "train" and not _has_flag(tokens, "--group"): tokens.extend(["--group", run_group]) + if ( + run_kind == "benchmark" + and output_root + and not _has_flag(tokens, "--output-dir") + ): + out_dir = Path(output_root) / f"rank_{rank}" / f"seed_{seed}" + out_dir.parent.mkdir(parents=True, exist_ok=True) + tokens.extend(["--output-dir", str(out_dir)]) + cmd = [sys.executable, "-m", "engine.train", *tokens] + print( + { + "rank": int(rank), + "run_kind": run_kind, + "seed": int(seed), + "compare_robust": bool(compare_robust), + "wandb_entity": str(env.get("WANDB_ENTITY", "")), + "wandb_project": str(env.get("WANDB_PROJECT", "")), + "command": " ".join(cmd), + } + ) proc = subprocess.run(cmd, cwd=cwd, env=env) return int(proc.returncode) def main() -> None: parser = argparse.ArgumentParser( - description="Launch one train run per Ray TPU node" + description="Launch one train/benchmark run per Ray TPU node" ) - parser.add_argument("--train-args", type=str, required=True) + parser.add_argument("--run-kind", choices=["train", "benchmark"], default="train") + parser.add_argument("--entry-args", type=str, default="") + parser.add_argument("--train-args", type=str, default="") parser.add_argument("--num-nodes", type=int, default=0) parser.add_argument("--tpu-per-task", type=float, default=8.0) parser.add_argument("--base-seed", type=int, default=42) parser.add_argument("--sync-jax", action="store_true") parser.add_argument("--coordinator-port", type=int, default=12355) parser.add_argument("--run-group", type=str, default="") + parser.add_argument("--compare-robust", action="store_true") + parser.add_argument("--output-root", type=str, default="") + parser.add_argument("--wandb-entity", type=str, default="") + parser.add_argument("--wandb-project", type=str, default="") args = parser.parse_args() + entry_args = str(args.entry_args or args.train_args).strip() + if not entry_args: + raise ValueError("--entry-args (or legacy --train-args) is required") + ray.init(address="auto") node_ips = _alive_node_ips() @@ -118,8 +170,11 @@ def main() -> None: "nodes": node_ips, "world_size": world_size, "coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}", - "train_args": args.train_args, + "run_kind": str(args.run_kind), + "entry_args": entry_args, "run_group": run_group, + "compare_robust": bool(args.compare_robust), + "output_root": str(args.output_root), } ) @@ -130,14 +185,19 @@ def main() -> None: futures.append( _train_on_node.options(resources=resources).remote( root=root, - train_args=args.train_args, + run_kind=str(args.run_kind), + entry_args=entry_args, rank=rank, world_size=world_size, coordinator_ip=coordinator_ip, coordinator_port=int(args.coordinator_port), base_seed=int(args.base_seed), run_group=run_group, - sync_jax=bool(args.sync_jax), + compare_robust=bool(args.compare_robust), + output_root=str(args.output_root), + wandb_entity=str(args.wandb_entity), + wandb_project=str(args.wandb_project), + sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"), ) ) diff --git a/submit_ray_job.sh b/submit_ray_job.sh index 11775d6..c2f9709 100755 --- a/submit_ray_job.sh +++ b/submit_ray_job.sh @@ -3,6 +3,7 @@ # Modes: # RAY_MODE=single -> one run (default) # RAY_MODE=distributed -> one run per TPU node (experimental) +# RAY_MODE=benchmark -> one benchmark run per TPU node (overnight) set -euo pipefail @@ -27,6 +28,9 @@ env = dotenv_values(".env") # Filter out empty/None values env_vars = {k: v for k, v in env.items() if v} env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0")) +for k in ("WANDB_ENTITY", "WANDB_PROJECT", "PHANTOM_BENCHMARK_COMPARE_ROBUST"): + if os.getenv(k): + env_vars[k] = os.getenv(k) print(json.dumps({ "pip": [ @@ -38,7 +42,8 @@ print(json.dumps({ "pandas", "pydantic", "graphviz", - "huggingface_hub" + "huggingface_hub", + "matplotlib" ], "env_vars": env_vars })) @@ -46,12 +51,22 @@ print(json.dumps({ RAY_MODE="${RAY_MODE:-single}" TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}" +BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}" + +SUBMIT_ARGS=() +if [ "${RAY_NO_WAIT:-0}" = "1" ]; then + SUBMIT_ARGS+=(--no-wait) +fi +if [ -n "${SUBMISSION_ID:-}" ]; then + SUBMIT_ARGS+=(--submission-id "$SUBMISSION_ID") +fi COMMON_ARGS=( job submit --address http://localhost:8265 --working-dir "$ROOT" --runtime-env-json "$RUNTIME_ENV_JSON" + "${SUBMIT_ARGS[@]}" -- ) @@ -77,5 +92,25 @@ if [ "$RAY_MODE" = "distributed" ]; then exit 0 fi -echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single' or 'distributed')." >&2 +if [ "$RAY_MODE" = "benchmark" ]; then + DIST_ARGS=( + python + scripts/ray_distributed_train.py + --run-kind benchmark + --entry-args "$BENCHMARK_ARGS" + --num-nodes "${NUM_NODES:-4}" + --tpu-per-task "${TPU_PER_TASK:-8}" + --base-seed "${BASE_SEED:-42}" + --output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}" + --wandb-entity "${WANDB_ENTITY:-lusiana}" + --wandb-project "${WANDB_PROJECT:-capstone_tpu}" + ) + if [ "${COMPARE_ROBUST:-1}" = "1" ]; then + DIST_ARGS+=(--compare-robust) + fi + "$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}" + exit 0 +fi + +echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', or 'benchmark')." >&2 exit 1