nightly benchmark run configureation

This commit is contained in:
2026-03-12 09:16:50 +01:00
parent 22e50aac4a
commit b1f583be39
3 changed files with 109 additions and 14 deletions

View File

@@ -7,7 +7,7 @@ import sys
import argparse import argparse
import json import json
import logging import logging
from datetime import datetime, UTC from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
# clear stale TPU locks on startup # clear stale TPU locks on startup
@@ -449,7 +449,7 @@ def _run_with_args(args, compare_robust_override: bool | None = None):
out_dir = Path(args.output_dir) out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
stamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
csv_path = out_dir / f"benchmark_{stamp}.csv" csv_path = out_dir / f"benchmark_{stamp}.csv"
trace_path = out_dir / f"benchmark_traces_{stamp}.json" trace_path = out_dir / f"benchmark_traces_{stamp}.json"
df.to_csv(csv_path, index=False) df.to_csv(csv_path, index=False)
@@ -580,7 +580,7 @@ def run_cli(raw_args: list[str] | None = None):
tiers = _parse_list(args.tiers) tiers = _parse_list(args.tiers)
alpha_values = _parse_float_list(args.alpha_values) alpha_values = _parse_float_list(args.alpha_values)
run_stamp = datetime.now(UTC).strftime("%m%d-%H%M%S") run_stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M%S")
compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST")) compare_enabled = _truthy(os.environ.get("PHANTOM_BENCHMARK_COMPARE_ROBUST"))
compare_tag = "robust-compare" if compare_enabled else "single-mode" compare_tag = "robust-compare" if compare_enabled else "single-mode"
modes = ( modes = (

View File

@@ -15,6 +15,15 @@ def _has_flag(tokens: list[str], name: str) -> bool:
return any(tok == name or tok.startswith(f"{name}=") for tok in tokens) return any(tok == name or tok.startswith(f"{name}=") for tok in tokens)
def _entry_tokens(run_kind: str, entry_args: str) -> list[str]:
tokens = shlex.split(entry_args)
if run_kind == "benchmark" and not (
_has_flag(tokens, "--run-kind") or _has_flag(tokens, "--run-mode")
):
return ["--run-kind", "benchmark", *tokens]
return tokens
def _alive_node_ips() -> list[str]: def _alive_node_ips() -> list[str]:
seen: set[str] = set() seen: set[str] = set()
ips: list[str] = [] ips: list[str] = []
@@ -33,13 +42,18 @@ def _alive_node_ips() -> list[str]:
def _train_on_node( def _train_on_node(
*, *,
root: str, root: str,
train_args: str, run_kind: str,
entry_args: str,
rank: int, rank: int,
world_size: int, world_size: int,
coordinator_ip: str, coordinator_ip: str,
coordinator_port: int, coordinator_port: int,
base_seed: int, base_seed: int,
run_group: str, run_group: str,
compare_robust: bool,
output_root: str,
wandb_entity: str,
wandb_project: str,
sync_jax: bool, sync_jax: bool,
) -> int: ) -> int:
env = dict(os.environ) env = dict(os.environ)
@@ -53,6 +67,12 @@ def _train_on_node(
env["JAX_PLATFORMS"] = requested_platform env["JAX_PLATFORMS"] = requested_platform
# Keep each train process in single-host mode to avoid accidental global stalls. # Keep each train process in single-host mode to avoid accidental global stalls.
env["CLOUD_TPU_TASK_ID"] = "0" env["CLOUD_TPU_TASK_ID"] = "0"
if run_kind == "benchmark":
env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0"
if wandb_entity:
env["WANDB_ENTITY"] = wandb_entity
if wandb_project:
env["WANDB_PROJECT"] = wandb_project
cwd = str(Path(root)) cwd = str(Path(root))
@@ -75,30 +95,62 @@ def _train_on_node(
[sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True [sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True
) )
tokens = shlex.split(train_args) tokens = _entry_tokens(run_kind, entry_args)
seed = int(base_seed + rank)
if not _has_flag(tokens, "--seed"): if not _has_flag(tokens, "--seed"):
tokens.extend(["--seed", str(base_seed + rank)]) tokens.extend(["--seed", str(seed)])
if not _has_flag(tokens, "--group"):
if run_kind == "train" and not _has_flag(tokens, "--group"):
tokens.extend(["--group", run_group]) tokens.extend(["--group", run_group])
if (
run_kind == "benchmark"
and output_root
and not _has_flag(tokens, "--output-dir")
):
out_dir = Path(output_root) / f"rank_{rank}" / f"seed_{seed}"
out_dir.parent.mkdir(parents=True, exist_ok=True)
tokens.extend(["--output-dir", str(out_dir)])
cmd = [sys.executable, "-m", "engine.train", *tokens] cmd = [sys.executable, "-m", "engine.train", *tokens]
print(
{
"rank": int(rank),
"run_kind": run_kind,
"seed": int(seed),
"compare_robust": bool(compare_robust),
"wandb_entity": str(env.get("WANDB_ENTITY", "")),
"wandb_project": str(env.get("WANDB_PROJECT", "")),
"command": " ".join(cmd),
}
)
proc = subprocess.run(cmd, cwd=cwd, env=env) proc = subprocess.run(cmd, cwd=cwd, env=env)
return int(proc.returncode) return int(proc.returncode)
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Launch one train run per Ray TPU node" description="Launch one train/benchmark run per Ray TPU node"
) )
parser.add_argument("--train-args", type=str, required=True) parser.add_argument("--run-kind", choices=["train", "benchmark"], default="train")
parser.add_argument("--entry-args", type=str, default="")
parser.add_argument("--train-args", type=str, default="")
parser.add_argument("--num-nodes", type=int, default=0) parser.add_argument("--num-nodes", type=int, default=0)
parser.add_argument("--tpu-per-task", type=float, default=8.0) parser.add_argument("--tpu-per-task", type=float, default=8.0)
parser.add_argument("--base-seed", type=int, default=42) parser.add_argument("--base-seed", type=int, default=42)
parser.add_argument("--sync-jax", action="store_true") parser.add_argument("--sync-jax", action="store_true")
parser.add_argument("--coordinator-port", type=int, default=12355) parser.add_argument("--coordinator-port", type=int, default=12355)
parser.add_argument("--run-group", type=str, default="") parser.add_argument("--run-group", type=str, default="")
parser.add_argument("--compare-robust", action="store_true")
parser.add_argument("--output-root", type=str, default="")
parser.add_argument("--wandb-entity", type=str, default="")
parser.add_argument("--wandb-project", type=str, default="")
args = parser.parse_args() args = parser.parse_args()
entry_args = str(args.entry_args or args.train_args).strip()
if not entry_args:
raise ValueError("--entry-args (or legacy --train-args) is required")
ray.init(address="auto") ray.init(address="auto")
node_ips = _alive_node_ips() node_ips = _alive_node_ips()
@@ -118,8 +170,11 @@ def main() -> None:
"nodes": node_ips, "nodes": node_ips,
"world_size": world_size, "world_size": world_size,
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}", "coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
"train_args": args.train_args, "run_kind": str(args.run_kind),
"entry_args": entry_args,
"run_group": run_group, "run_group": run_group,
"compare_robust": bool(args.compare_robust),
"output_root": str(args.output_root),
} }
) )
@@ -130,14 +185,19 @@ def main() -> None:
futures.append( futures.append(
_train_on_node.options(resources=resources).remote( _train_on_node.options(resources=resources).remote(
root=root, root=root,
train_args=args.train_args, run_kind=str(args.run_kind),
entry_args=entry_args,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
coordinator_ip=coordinator_ip, coordinator_ip=coordinator_ip,
coordinator_port=int(args.coordinator_port), coordinator_port=int(args.coordinator_port),
base_seed=int(args.base_seed), base_seed=int(args.base_seed),
run_group=run_group, run_group=run_group,
sync_jax=bool(args.sync_jax), compare_robust=bool(args.compare_robust),
output_root=str(args.output_root),
wandb_entity=str(args.wandb_entity),
wandb_project=str(args.wandb_project),
sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"),
) )
) )

View File

@@ -3,6 +3,7 @@
# Modes: # Modes:
# RAY_MODE=single -> one run (default) # RAY_MODE=single -> one run (default)
# RAY_MODE=distributed -> one run per TPU node (experimental) # RAY_MODE=distributed -> one run per TPU node (experimental)
# RAY_MODE=benchmark -> one benchmark run per TPU node (overnight)
set -euo pipefail set -euo pipefail
@@ -27,6 +28,9 @@ env = dotenv_values(".env")
# Filter out empty/None values # Filter out empty/None values
env_vars = {k: v for k, v in env.items() if v} env_vars = {k: v for k, v in env.items() if v}
env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0")) env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0"))
for k in ("WANDB_ENTITY", "WANDB_PROJECT", "PHANTOM_BENCHMARK_COMPARE_ROBUST"):
if os.getenv(k):
env_vars[k] = os.getenv(k)
print(json.dumps({ print(json.dumps({
"pip": [ "pip": [
@@ -38,7 +42,8 @@ print(json.dumps({
"pandas", "pandas",
"pydantic", "pydantic",
"graphviz", "graphviz",
"huggingface_hub" "huggingface_hub",
"matplotlib"
], ],
"env_vars": env_vars "env_vars": env_vars
})) }))
@@ -46,12 +51,22 @@ print(json.dumps({
RAY_MODE="${RAY_MODE:-single}" RAY_MODE="${RAY_MODE:-single}"
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}" TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}"
BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}"
SUBMIT_ARGS=()
if [ "${RAY_NO_WAIT:-0}" = "1" ]; then
SUBMIT_ARGS+=(--no-wait)
fi
if [ -n "${SUBMISSION_ID:-}" ]; then
SUBMIT_ARGS+=(--submission-id "$SUBMISSION_ID")
fi
COMMON_ARGS=( COMMON_ARGS=(
job submit job submit
--address http://localhost:8265 --address http://localhost:8265
--working-dir "$ROOT" --working-dir "$ROOT"
--runtime-env-json "$RUNTIME_ENV_JSON" --runtime-env-json "$RUNTIME_ENV_JSON"
"${SUBMIT_ARGS[@]}"
-- --
) )
@@ -77,5 +92,25 @@ if [ "$RAY_MODE" = "distributed" ]; then
exit 0 exit 0
fi fi
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single' or 'distributed')." >&2 if [ "$RAY_MODE" = "benchmark" ]; then
DIST_ARGS=(
python
scripts/ray_distributed_train.py
--run-kind benchmark
--entry-args "$BENCHMARK_ARGS"
--num-nodes "${NUM_NODES:-4}"
--tpu-per-task "${TPU_PER_TASK:-8}"
--base-seed "${BASE_SEED:-42}"
--output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}"
--wandb-entity "${WANDB_ENTITY:-lusiana}"
--wandb-project "${WANDB_PROJECT:-capstone_tpu}"
)
if [ "${COMPARE_ROBUST:-1}" = "1" ]; then
DIST_ARGS+=(--compare-robust)
fi
"$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}"
exit 0
fi
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', or 'benchmark')." >&2
exit 1 exit 1