From 88155d22a75fd23f2ba415afe8de3066a4cfd377 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Thu, 12 Mar 2026 12:48:52 +0100 Subject: [PATCH] chore: refactor for sweeps and IP configs --- docker-compose.yml | 4 + docker/TPUWatchdog.dockerfile | 48 +++- scripts/ray_distributed_train.py | 418 +++++++++++++++++++++++++++++-- submit_ray_job.sh | 117 ++++++++- tpu_orchestration/watchdog.sh | 13 +- 5 files changed, 566 insertions(+), 34 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index c00f4e1..24961c5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,9 +10,13 @@ services: - HF_TOKEN=${HF_TOKEN} - WANDB_API_KEY=${WANDB_API_KEY} - GITHUB_TOKEN=${GITHUB_TOKEN} + - GOOGLE_APPLICATION_CREDENTIALS=/secrets/gcp-sa.json + - GCP_ACCOUNT=${GCP_ACCOUNT:-} + - WATCHDOG_CONFIG_PATTERN=${WATCHDOG_CONFIG_PATTERN:-v6e_*.conf} - CLOUDSDK_CONFIG=/.config/gcloud volumes: - ~/.config/gcloud:/.config/gcloud:rw + - ./secrets/gcp-sa.json:/secrets/gcp-sa.json:ro tensorboard-rl: image: tensorflow/tensorflow:latest diff --git a/docker/TPUWatchdog.dockerfile b/docker/TPUWatchdog.dockerfile index 8299171..66c0c3f 100644 --- a/docker/TPUWatchdog.dockerfile +++ b/docker/TPUWatchdog.dockerfile @@ -35,25 +35,55 @@ if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENT if [ "$CRED_TYPE" = "service_account" ]; then echo "Authenticating gcloud using service account key..." gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS" - - # Extract project ID from the key file - PROJECT_ID=$(jq -r '.project_id' "$GOOGLE_APPLICATION_CREDENTIALS") - if [ -n "$PROJECT_ID" ] && [ "$PROJECT_ID" != "null" ]; then - gcloud config set project "$PROJECT_ID" - echo "Set project to $PROJECT_ID" + + if [ -z "$PROJECT_ID" ]; then + PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS") fi + elif [ "$CRED_TYPE" = "authorized_user" ]; then + echo "Authenticating gcloud using authorized_user refresh token..." + + AUTH_ACCOUNT="$GCP_ACCOUNT" + if [ -z "$AUTH_ACCOUNT" ]; then + AUTH_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS") + fi + if [ -z "$AUTH_ACCOUNT" ]; then + AUTH_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true) + fi + + REFRESH_TOKEN=$(jq -r '.refresh_token // empty' "$GOOGLE_APPLICATION_CREDENTIALS") + if [ -z "$AUTH_ACCOUNT" ] || [ -z "$REFRESH_TOKEN" ]; then + echo "Error: authorized_user credentials require GCP_ACCOUNT (or embedded account) and refresh_token." + exit 1 + fi + + gcloud auth activate-refresh-token "$AUTH_ACCOUNT" "$REFRESH_TOKEN" else - echo "Note: Using application default credentials or mounted gcloud config..." + echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config." fi else echo "Note: Assuming gcloud config is mounted from host." fi +if [ -n "$PROJECT_ID" ]; then + gcloud config set project "$PROJECT_ID" + echo "Set project to $PROJECT_ID" +fi + # Run the watchdogs in the background using bash instead of tmux # Tmux needs a TTY to attach properly which we might not have in docker # Stagger startups by 15s to prevent simultaneous TPU creation quota hits +CONFIG_PATTERN=${WATCHDOG_CONFIG_PATTERN:-"*.conf"} +shopt -s nullglob +CONFIGS=(/app/tpu_orchestration/configs/$CONFIG_PATTERN) + +if [ ${#CONFIGS[@]} -eq 0 ]; then + echo "Error: no watchdog configs matched pattern '$CONFIG_PATTERN'." + exit 1 +fi + +echo "Using watchdog config pattern: $CONFIG_PATTERN" DELAY=0 -for conf in /app/tpu_orchestration/configs/*.conf; do +for conf in "${CONFIGS[@]}"; do echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)" (sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") & DELAY=$((DELAY + 15)) @@ -67,4 +97,4 @@ EOF RUN chmod +x /app/entrypoint.sh -CMD ["/app/entrypoint.sh"] \ No newline at end of file +CMD ["/app/entrypoint.sh"] diff --git a/scripts/ray_distributed_train.py b/scripts/ray_distributed_train.py index 7e2fc23..3395a8f 100644 --- a/scripts/ray_distributed_train.py +++ b/scripts/ray_distributed_train.py @@ -1,14 +1,18 @@ from __future__ import annotations import argparse +import contextlib +import concurrent.futures import os import shlex import subprocess import sys +import threading import time from pathlib import Path import ray +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy def _has_flag(tokens: list[str], name: str) -> bool: @@ -24,18 +28,301 @@ def _entry_tokens(run_kind: str, entry_args: str) -> list[str]: return tokens -def _alive_node_ips() -> list[str]: +def _get_flag_value(tokens: list[str], name: str, default: str = "") -> str: + for idx, tok in enumerate(tokens): + if tok == name and idx + 1 < len(tokens): + return str(tokens[idx + 1]) + if tok.startswith(f"{name}="): + return str(tok.split("=", 1)[1]) + return str(default) + + +def _set_flag_value(tokens: list[str], name: str, value: str) -> list[str]: + updated: list[str] = [] + replaced = False + idx = 0 + while idx < len(tokens): + tok = tokens[idx] + if tok == name: + replaced = True + updated.extend([name, str(value)]) + idx += 2 + continue + if tok.startswith(f"{name}="): + replaced = True + updated.append(f"{name}={value}") + idx += 1 + continue + updated.append(tok) + idx += 1 + if not replaced: + updated.extend([name, str(value)]) + return updated + + +def _remove_flag(tokens: list[str], name: str) -> list[str]: + updated: list[str] = [] + idx = 0 + while idx < len(tokens): + tok = tokens[idx] + if tok == name: + idx += 1 + continue + if tok.startswith(f"{name}="): + idx += 1 + continue + updated.append(tok) + idx += 1 + return updated + + +def _csv_values(raw: str) -> list[str]: + return [piece.strip() for piece in str(raw).split(",") if piece.strip()] + + +def _alpha_token(alpha: str) -> str: + return str(alpha).replace(".", "p").replace("-", "m") + + +def _truthy(value: str | bool | None) -> bool: + if isinstance(value, bool): + return value + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def _alive_nodes() -> list[tuple[str, str]]: seen: set[str] = set() - ips: list[str] = [] + nodes: list[tuple[str, str]] = [] for node in ray.nodes(): if not bool(node.get("Alive", False)): continue + node_id = str(node.get("NodeID", "")).strip() ip = str(node.get("NodeManagerAddress", "")).strip() - if not ip or ip in seen: + if not node_id or not ip or node_id in seen: continue - seen.add(ip) - ips.append(ip) - return sorted(ips) + seen.add(node_id) + nodes.append((node_id, ip)) + return sorted(nodes, key=lambda item: (item[1], item[0])) + + +def _benchmark_cells( + tokens: list[str], *, compare_robust: bool +) -> list[tuple[str, str, str, bool]]: + tiers = _csv_values( + _get_flag_value(tokens, "--tiers", "static,surge,linear,qtable,ppo") + ) + alphas = _csv_values(_get_flag_value(tokens, "--alpha-values", "0.0,0.3,0.6")) + base_no_robust = _has_flag(tokens, "--no-robust") + if compare_robust: + modes = [("robust", False), ("no_robust", True)] + else: + modes = [("no_robust", True)] if base_no_robust else [("robust", False)] + return [ + (tier, alpha, mode_label, no_robust) + for tier in tiers + for alpha in alphas + for mode_label, no_robust in modes + ] + + +def _thread_limited_env(env: dict[str, str], threads: int) -> dict[str, str]: + bounded = dict(env) + n = str(max(1, int(threads))) + for key in ( + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "VECLIB_MAXIMUM_THREADS", + "BLIS_NUM_THREADS", + ): + bounded[key] = n + return bounded + + +@contextlib.contextmanager +def _semaphore_guard(semaphore: threading.Semaphore | None): + if semaphore is None: + yield + return + semaphore.acquire() + try: + yield + finally: + semaphore.release() + + +def _run_benchmark_cells_parallel( + *, + root: str, + env: dict[str, str], + base_tokens: list[str], + compare_robust: bool, + inner_workers: int, + inner_threads: int, + max_heavy_workers: int, + rank: int, +) -> int: + cells = _benchmark_cells(base_tokens, compare_robust=compare_robust) + if not cells: + return 0 + + cwd = str(Path(root)) + base_out = _get_flag_value(base_tokens, "--output-dir", "engine/studies/results") + max_workers = max(1, min(int(inner_workers), len(cells))) + heavy_tiers = {"ppo", "a2c", "dqn"} + heavy_limit = max(1, int(max_heavy_workers)) + heavy_sem = threading.Semaphore(heavy_limit) + print( + { + "rank": int(rank), + "benchmark_cells": len(cells), + "inner_workers": int(max_workers), + "inner_threads": int(max(1, int(inner_threads))), + "heavy_limit": int(heavy_limit), + } + ) + + def _run_cell( + index: int, + total: int, + tier: str, + alpha: str, + mode_label: str, + no_robust: bool, + ) -> tuple[str, str, str, int]: + tokens = list(base_tokens) + tokens = _set_flag_value(tokens, "--tiers", tier) + tokens = _set_flag_value(tokens, "--alpha-values", alpha) + if no_robust: + if not _has_flag(tokens, "--no-robust"): + tokens.append("--no-robust") + else: + tokens = _remove_flag(tokens, "--no-robust") + + cell_out = ( + Path(base_out) + / f"tier_{tier}" + / f"mode_{mode_label}" + / f"alpha_{_alpha_token(alpha)}" + ) + tokens = _set_flag_value(tokens, "--output-dir", str(cell_out)) + cmd = [sys.executable, "-m", "engine.train", *tokens] + cell_env = _thread_limited_env(env, int(inner_threads)) + cell_env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "0" + print( + { + "rank": int(rank), + "cell": f"{index}/{total}", + "tier": tier, + "mode": mode_label, + "alpha": alpha, + "command": " ".join(cmd), + } + ) + heavy_guard = heavy_sem if str(tier).lower() in heavy_tiers else None + with _semaphore_guard(heavy_guard): + proc = subprocess.run(cmd, cwd=cwd, env=cell_env) + return tier, alpha, mode_label, int(proc.returncode) + + failures: list[tuple[str, str, str, int]] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [ + pool.submit(_run_cell, idx, len(cells), tier, alpha, mode_label, no_robust) + for idx, (tier, alpha, mode_label, no_robust) in enumerate(cells, start=1) + ] + for fut in concurrent.futures.as_completed(futures): + tier, alpha, mode_label, code = fut.result() + if code != 0: + failures.append((tier, alpha, mode_label, code)) + + if failures: + print({"rank": int(rank), "benchmark_failures": failures}) + return 1 + return 0 + + +def _run_sweep_agents_parallel( + *, + root: str, + env: dict[str, str], + base_tokens: list[str], + run_kind: str, + rank: int, + agents_per_node: int, + agent_count: int, + inner_threads: int, + tpu_agent_slots: int, +) -> int: + total = max(1, int(agents_per_node)) + cwd = str(Path(root)) + wants_tpu = str(env.get("JAX_PLATFORMS", "")).strip().lower() == "tpu" + tpu_slots = max(0, int(tpu_agent_slots)) + print( + { + "rank": int(rank), + "sweep_agents": int(total), + "agent_count": int(agent_count), + "inner_threads": int(max(1, int(inner_threads))), + "jax_platform": str(env.get("JAX_PLATFORMS", "")), + "tpu_agent_slots": int(tpu_slots), + } + ) + + def _run_agent(slot: int) -> int: + tokens = list(base_tokens) + if int(agent_count) > 0 and not _has_flag(tokens, "--count"): + tokens.extend(["--count", str(int(agent_count))]) + + if _has_flag(tokens, "--group"): + base_group = _get_flag_value(tokens, "--group", "ray-sweep") + tokens = _set_flag_value(tokens, "--group", f"{base_group}-a{slot}") + + if run_kind == "benchmark": + out_dir = _get_flag_value(tokens, "--output-dir", "engine/studies/results") + tokens = _set_flag_value( + tokens, "--output-dir", str(Path(out_dir) / f"agent_{slot}") + ) + if run_kind == "train": + model_dir = _get_flag_value(tokens, "--model-dir", "engine/models") + tokens = _set_flag_value( + tokens, "--model-dir", str(Path(model_dir) / f"agent_{slot}") + ) + + cmd = [sys.executable, "-m", "engine.train", *tokens] + agent_env = _thread_limited_env(env, int(inner_threads)) + if wants_tpu and tpu_slots > 0 and int(slot) > tpu_slots: + agent_env["JAX_PLATFORMS"] = "cpu" + agent_env["JAX_PLATFORM_NAME"] = "cpu" + agent_env["PHANTOM_SWEEP_AGENT_SLOT"] = str(int(slot)) + print( + { + "rank": int(rank), + "agent_slot": int(slot), + "jax_platform": str(agent_env.get("JAX_PLATFORMS", "")), + "command": " ".join(cmd), + } + ) + proc = subprocess.run(cmd, cwd=cwd, env=agent_env) + return int(proc.returncode) + + failures: list[tuple[int, int]] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=total) as pool: + future_map = { + pool.submit(_run_agent, slot): slot for slot in range(1, total + 1) + } + for future in concurrent.futures.as_completed(future_map): + slot = int(future_map[future]) + code = int(future.result()) + if code != 0: + failures.append((slot, code)) + + if failures: + print({"rank": int(rank), "sweep_failures": failures}) + return 1 + return 0 @ray.remote(max_retries=0) @@ -44,6 +331,8 @@ def _train_on_node( root: str, run_kind: str, entry_args: str, + node_id: str, + node_ip: str, rank: int, world_size: int, coordinator_ip: str, @@ -54,17 +343,32 @@ def _train_on_node( output_root: str, wandb_entity: str, wandb_project: str, + agents_per_node: int, + agent_count: int, + inner_workers: int, + inner_threads: int, + max_heavy_workers: int, sync_jax: bool, ) -> int: env = dict(os.environ) env["PYTHONUNBUFFERED"] = "1" requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower() - if world_size > 1 and requested_platform == "tpu": + allow_multi_node_tpu = _truthy(env.get("PHANTOM_ALLOW_MULTI_NODE_TPU")) + if world_size > 1 and requested_platform == "tpu" and not allow_multi_node_tpu: requested_platform = "cpu" print( - "PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs" + "PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs " + "(set PHANTOM_ALLOW_MULTI_NODE_TPU=1 to keep TPU for JAX workloads)" + ) + elif world_size > 1 and requested_platform == "tpu" and allow_multi_node_tpu: + print( + "PHANTOM_DISTRIBUTED_NOTE: keeping JAX_PLATFORMS=tpu in multi-node mixed mode" ) env["JAX_PLATFORMS"] = requested_platform + if requested_platform == "cpu": + env["JAX_PLATFORM_NAME"] = "cpu" + else: + env.pop("JAX_PLATFORM_NAME", None) # Keep each train process in single-host mode to avoid accidental global stalls. env["CLOUD_TPU_TASK_ID"] = "0" if run_kind == "benchmark": @@ -96,13 +400,29 @@ def _train_on_node( ) tokens = _entry_tokens(run_kind, entry_args) + is_sweep_agent = _has_flag(tokens, "--sweep-agent") seed = int(base_seed + rank) - if not _has_flag(tokens, "--seed"): + if not is_sweep_agent and not _has_flag(tokens, "--seed"): tokens.extend(["--seed", str(seed)]) if run_kind == "train" and not _has_flag(tokens, "--group"): tokens.extend(["--group", run_group]) + if is_sweep_agent and int(agent_count) > 0 and not _has_flag(tokens, "--count"): + tokens.extend(["--count", str(int(agent_count))]) + + try: + tpu_agent_slots = int( + str( + env.get( + "PHANTOM_TPU_AGENT_SLOTS", + "1" if requested_platform == "tpu" else "0", + ) + ).strip() + ) + except ValueError: + tpu_agent_slots = 1 if requested_platform == "tpu" else 0 + if ( run_kind == "benchmark" and output_root @@ -112,9 +432,36 @@ def _train_on_node( out_dir.parent.mkdir(parents=True, exist_ok=True) tokens.extend(["--output-dir", str(out_dir)]) + if is_sweep_agent and int(agents_per_node) > 1: + return _run_sweep_agents_parallel( + root=root, + env=env, + base_tokens=tokens, + run_kind=run_kind, + rank=rank, + agents_per_node=int(agents_per_node), + agent_count=int(agent_count), + inner_threads=int(inner_threads), + tpu_agent_slots=int(max(0, tpu_agent_slots)), + ) + + if run_kind == "benchmark" and int(inner_workers) > 1 and not is_sweep_agent: + return _run_benchmark_cells_parallel( + root=root, + env=env, + base_tokens=tokens, + compare_robust=bool(compare_robust), + inner_workers=int(inner_workers), + inner_threads=int(inner_threads), + max_heavy_workers=int(max_heavy_workers), + rank=rank, + ) + cmd = [sys.executable, "-m", "engine.train", *tokens] print( { + "node_id": node_id, + "node_ip": node_ip, "rank": int(rank), "run_kind": run_kind, "seed": int(seed), @@ -124,7 +471,9 @@ def _train_on_node( "command": " ".join(cmd), } ) - proc = subprocess.run(cmd, cwd=cwd, env=env) + proc = subprocess.run( + cmd, cwd=cwd, env=_thread_limited_env(env, int(inner_threads)) + ) return int(proc.returncode) @@ -145,6 +494,12 @@ def main() -> None: parser.add_argument("--output-root", type=str, default="") parser.add_argument("--wandb-entity", type=str, default="") parser.add_argument("--wandb-project", type=str, default="") + parser.add_argument("--agents-per-node", type=int, default=1) + parser.add_argument("--agent-count", type=int, default=0) + parser.add_argument("--inner-workers", type=int, default=1) + parser.add_argument("--inner-threads", type=int, default=1) + parser.add_argument("--max-heavy-workers", type=int, default=2) + parser.add_argument("--worker-cpus", type=float, default=1.0) args = parser.parse_args() entry_args = str(args.entry_args or args.train_args).strip() @@ -153,21 +508,24 @@ def main() -> None: ray.init(address="auto") - node_ips = _alive_node_ips() - if not node_ips: + node_entries = _alive_nodes() + if not node_entries: raise RuntimeError("No alive Ray nodes found") requested = int(args.num_nodes) if requested > 0: - node_ips = node_ips[:requested] + node_entries = node_entries[:requested] - world_size = len(node_ips) - coordinator_ip = node_ips[0] + world_size = len(node_entries) + coordinator_ip = node_entries[0][1] run_group = args.run_group or f"ray-dist-{int(time.time())}" print( { - "nodes": node_ips, + "nodes": [ + {"node_id": node_id, "node_ip": node_ip} + for node_id, node_ip in node_entries + ], "world_size": world_size, "coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}", "run_kind": str(args.run_kind), @@ -175,18 +533,35 @@ def main() -> None: "run_group": run_group, "compare_robust": bool(args.compare_robust), "output_root": str(args.output_root), + "agents_per_node": int(args.agents_per_node), + "agent_count": int(args.agent_count), + "inner_workers": int(args.inner_workers), + "inner_threads": int(args.inner_threads), + "max_heavy_workers": int(args.max_heavy_workers), } ) futures = [] root = str(Path(__file__).resolve().parents[1]) - for rank, node_ip in enumerate(node_ips): - resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)} + for rank, (node_id, node_ip) in enumerate(node_entries): + resources: dict[str, float] = {} + tpu_per_task = float(args.tpu_per_task) + if tpu_per_task > 0.0: + resources["TPU"] = tpu_per_task futures.append( - _train_on_node.options(resources=resources).remote( + _train_on_node.options( + resources=resources, + num_cpus=float(args.worker_cpus), + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=node_id, + soft=False, + ), + ).remote( root=root, run_kind=str(args.run_kind), entry_args=entry_args, + node_id=node_id, + node_ip=node_ip, rank=rank, world_size=world_size, coordinator_ip=coordinator_ip, @@ -197,6 +572,11 @@ def main() -> None: output_root=str(args.output_root), wandb_entity=str(args.wandb_entity), wandb_project=str(args.wandb_project), + agents_per_node=int(args.agents_per_node), + agent_count=int(args.agent_count), + inner_workers=int(args.inner_workers), + inner_threads=int(args.inner_threads), + max_heavy_workers=int(args.max_heavy_workers), sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"), ) ) diff --git a/submit_ray_job.sh b/submit_ray_job.sh index c2f9709..a6065ec 100755 --- a/submit_ray_job.sh +++ b/submit_ray_job.sh @@ -4,6 +4,7 @@ # RAY_MODE=single -> one run (default) # RAY_MODE=distributed -> one run per TPU node (experimental) # RAY_MODE=benchmark -> one benchmark run per TPU node (overnight) +# RAY_MODE=sweep -> distributed W&B sweep agents set -euo pipefail @@ -28,7 +29,14 @@ env = dotenv_values(".env") # Filter out empty/None values env_vars = {k: v for k, v in env.items() if v} env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0")) -for k in ("WANDB_ENTITY", "WANDB_PROJECT", "PHANTOM_BENCHMARK_COMPARE_ROBUST"): +for k in ( + "WANDB_ENTITY", + "WANDB_PROJECT", + "PHANTOM_BENCHMARK_COMPARE_ROBUST", + "PHANTOM_JAX_PLATFORM", + "PHANTOM_ALLOW_MULTI_NODE_TPU", + "PHANTOM_TPU_AGENT_SLOTS", +): if os.getenv(k): env_vars[k] = os.getenv(k) @@ -52,6 +60,15 @@ print(json.dumps({ RAY_MODE="${RAY_MODE:-single}" TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}" BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}" +INNER_WORKERS="${INNER_WORKERS:-16}" +INNER_THREADS="${INNER_THREADS:-1}" +MAX_HEAVY_WORKERS="${MAX_HEAVY_WORKERS:-3}" +WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}" +SWEEP_KIND="${SWEEP_KIND:-benchmark}" +SWEEP_METHOD="${SWEEP_METHOD:-random}" +SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}" +AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}" +AGENT_COUNT="${AGENT_COUNT:-0}" SUBMIT_ARGS=() if [ "${RAY_NO_WAIT:-0}" = "1" ]; then @@ -104,6 +121,10 @@ if [ "$RAY_MODE" = "benchmark" ]; then --output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}" --wandb-entity "${WANDB_ENTITY:-lusiana}" --wandb-project "${WANDB_PROJECT:-capstone_tpu}" + --inner-workers "${INNER_WORKERS}" + --inner-threads "${INNER_THREADS}" + --max-heavy-workers "${MAX_HEAVY_WORKERS}" + --worker-cpus "${WORKER_CPUS}" ) if [ "${COMPARE_ROBUST:-1}" = "1" ]; then DIST_ARGS+=(--compare-robust) @@ -112,5 +133,97 @@ if [ "$RAY_MODE" = "benchmark" ]; then exit 0 fi -echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', or 'benchmark')." >&2 +if [ "$RAY_MODE" = "sweep" ]; then + SWEEP_PROJECT="${WANDB_PROJECT:-capstone_tpu}" + SWEEP_ENTITY="${WANDB_ENTITY:-lusiana}" + SWEEP_ID_VALUE="${SWEEP_ID:-}" + SWEEP_NUM_NODES="${NUM_NODES:-5}" + PY_SWEEP_BIN="${PY_SWEEP_BIN:-}" + if [ -z "$PY_SWEEP_BIN" ]; then + for cand in "$ROOT/.venv/bin/python" "$ROOT/.venv-ray/bin/python" python3 python; do + if [ "$cand" = "python3" ] || [ "$cand" = "python" ]; then + command -v "$cand" >/dev/null 2>&1 || continue + elif [ ! -x "$cand" ]; then + continue + fi + if "$cand" - <<'PY' >/dev/null 2>&1 +import sys +from pathlib import Path +cwd = str(Path.cwd()) +sys.path = [p for p in sys.path if p not in {'', cwd}] +import wandb +print(wandb.__name__) +PY + then + PY_SWEEP_BIN="$cand" + break + fi + done + fi + if [ -z "$PY_SWEEP_BIN" ]; then + echo "No python interpreter with wandb is available for sweep creation." >&2 + exit 1 + fi + + if [ -z "$SWEEP_ID_VALUE" ]; then + if [ -z "${WANDB_API_KEY:-}" ]; then + export WANDB_API_KEY + WANDB_API_KEY="$($PY_SWEEP_BIN - <<'PY' +from dotenv import dotenv_values +print(dotenv_values('.env').get('WANDB_API_KEY', '').strip()) +PY +)" + fi + if [ -z "${WANDB_API_KEY:-}" ]; then + echo "WANDB_API_KEY is required to create a sweep." >&2 + exit 1 + fi + SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \ + --kind "$SWEEP_KIND" \ + --project "$SWEEP_PROJECT" \ + --entity "$SWEEP_ENTITY" \ + --method "$SWEEP_METHOD" \ + --run-cap "$SWEEP_RUN_CAP")" + fi + + SWEEP_ENTRY_ARGS="${SWEEP_ENTRY_ARGS:-}" + if [ -z "$SWEEP_ENTRY_ARGS" ]; then + SWEEP_ENTRY_ARGS="--sweep-agent --sweep-id $SWEEP_ID_VALUE --project $SWEEP_PROJECT --device cpu" + fi + + if [ "$AGENT_COUNT" = "0" ] && [ "${SWEEP_RUN_CAP:-0}" -gt 0 ]; then + TOTAL_AGENTS=$((SWEEP_NUM_NODES * AGENTS_PER_NODE)) + if [ "$TOTAL_AGENTS" -gt 0 ]; then + AGENT_COUNT=$(((SWEEP_RUN_CAP + TOTAL_AGENTS - 1) / TOTAL_AGENTS)) + echo "Derived AGENT_COUNT=$AGENT_COUNT from SWEEP_RUN_CAP=$SWEEP_RUN_CAP across $TOTAL_AGENTS agents" + fi + fi + + DIST_ARGS=( + python + scripts/ray_distributed_train.py + --run-kind "$SWEEP_KIND" + --entry-args "$SWEEP_ENTRY_ARGS" + --num-nodes "${SWEEP_NUM_NODES}" + --tpu-per-task "${TPU_PER_TASK:-0}" + --base-seed "${BASE_SEED:-42}" + --wandb-entity "$SWEEP_ENTITY" + --wandb-project "$SWEEP_PROJECT" + --agents-per-node "$AGENTS_PER_NODE" + --agent-count "$AGENT_COUNT" + --inner-threads "$INNER_THREADS" + --worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}" + ) + if [ "$SWEEP_KIND" = "benchmark" ]; then + DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}") + fi + if [ "${COMPARE_ROBUST:-0}" = "1" ]; then + DIST_ARGS+=(--compare-robust) + fi + echo "SWEEP_ID=$SWEEP_ID_VALUE" + "$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}" + exit 0 +fi + +echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', 'benchmark', or 'sweep')." >&2 exit 1 diff --git a/tpu_orchestration/watchdog.sh b/tpu_orchestration/watchdog.sh index 4c32562..7e7a0fc 100755 --- a/tpu_orchestration/watchdog.sh +++ b/tpu_orchestration/watchdog.sh @@ -97,6 +97,8 @@ while true; do # Determine runtime version RT_VERSION=${RUNTIME_VERSION:-"tpu-ubuntu2204-base"} + CREATE_LOG="/tmp/tpu_create_${QR_NAME}.log" + gcloud compute tpus queued-resources create $QR_NAME \ --project=$PROJECT_ID \ --node-id=$QR_NAME \ @@ -104,20 +106,23 @@ while true; do --accelerator-type=$ACCEL_TYPE \ --runtime-version=$RT_VERSION \ $SPOT_FLAG \ + --internal-ips \ --metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \ - --metadata "$METADATA" 2>&1 | tee /tmp/tpu_create_${QR_NAME}.log + --metadata "$METADATA" 2>&1 | tee "$CREATE_LOG" + + CREATE_EXIT=${PIPESTATUS[0]} - if [ $? -eq 0 ]; then + if [ $CREATE_EXIT -eq 0 ]; then echo "[$(date)] Successfully queued $QR_NAME." RETRY_DELAY=60 - elif grep -q "IN_USE_ADDRESSES" /tmp/tpu_create_${QR_NAME}.log 2>/dev/null; then + elif grep -q "IN_USE_ADDRESSES" "$CREATE_LOG" 2>/dev/null; then echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s" sleep $RETRY_DELAY RETRY_DELAY=$((RETRY_DELAY * 2)) [ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY continue else - echo "[$(date)] Failed to queue $QR_NAME." + echo "[$(date)] Failed to queue $QR_NAME (exit=$CREATE_EXIT)." RETRY_DELAY=60 fi else