From 9caad4de4e0f46b6e6a80546921b006e6513b17f Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Thu, 12 Mar 2026 00:22:09 +0100 Subject: [PATCH] setup for tpu orchestarion properly --- docker/TPUWatchdog.dockerfile | 70 ++++++++++++++ engine/__init__.py | 0 scripts/ray_distributed_train.py | 151 +++++++++++++++++++++++++++++++ submit_ray_job.sh | 81 +++++++++++++++++ 4 files changed, 302 insertions(+) create mode 100644 docker/TPUWatchdog.dockerfile create mode 100644 engine/__init__.py create mode 100644 scripts/ray_distributed_train.py create mode 100755 submit_ray_job.sh diff --git a/docker/TPUWatchdog.dockerfile b/docker/TPUWatchdog.dockerfile new file mode 100644 index 0000000..8299171 --- /dev/null +++ b/docker/TPUWatchdog.dockerfile @@ -0,0 +1,70 @@ +FROM google/cloud-sdk:slim + +# Install tmux to manage multiple watchdogs and jq for json parsing +RUN apt-get update && \ + apt-get install -y tmux jq && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy the orchestration scripts and configs +COPY tpu_orchestration/ /app/tpu_orchestration/ + +# Make sure scripts are executable +RUN chmod +x /app/tpu_orchestration/watchdog.sh +RUN chmod +x /app/tpu_orchestration/tpu_startup.sh + +# Create an entrypoint script that launches a watchdog for each config +COPY <<-'EOF' /app/entrypoint.sh +#!/bin/bash +set -e + +# Make sure required variables are set +if [ -z "$HF_TOKEN" ]; then + echo "Error: HF_TOKEN environment variable is required." + exit 1 +fi + +if [ -z "$WANDB_API_KEY" ]; then + echo "Warning: WANDB_API_KEY environment variable is not set. Wandb logging may fail on TPUs." +fi + +# Authenticate gcloud if credentials are provided +if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENTIALS" ]; then + CRED_TYPE=$(jq -r '.type' "$GOOGLE_APPLICATION_CREDENTIALS" 2>/dev/null || echo "unknown") + if [ "$CRED_TYPE" = "service_account" ]; then + echo "Authenticating gcloud using service account key..." + gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS" + + # Extract project ID from the key file + PROJECT_ID=$(jq -r '.project_id' "$GOOGLE_APPLICATION_CREDENTIALS") + if [ -n "$PROJECT_ID" ] && [ "$PROJECT_ID" != "null" ]; then + gcloud config set project "$PROJECT_ID" + echo "Set project to $PROJECT_ID" + fi + else + echo "Note: Using application default credentials or mounted gcloud config..." + fi +else + echo "Note: Assuming gcloud config is mounted from host." +fi + +# Run the watchdogs in the background using bash instead of tmux +# Tmux needs a TTY to attach properly which we might not have in docker +# Stagger startups by 15s to prevent simultaneous TPU creation quota hits +DELAY=0 +for conf in /app/tpu_orchestration/configs/*.conf; do + echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)" + (sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") & + DELAY=$((DELAY + 15)) +done + +echo "All watchdogs queued with staggered startup." + +# Keep the container running +wait +EOF + +RUN chmod +x /app/entrypoint.sh + +CMD ["/app/entrypoint.sh"] \ No newline at end of file diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/ray_distributed_train.py b/scripts/ray_distributed_train.py new file mode 100644 index 0000000..f918f33 --- /dev/null +++ b/scripts/ray_distributed_train.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import argparse +import os +import shlex +import subprocess +import sys +import time +from pathlib import Path + +import ray + + +def _has_flag(tokens: list[str], name: str) -> bool: + return any(tok == name or tok.startswith(f"{name}=") for tok in tokens) + + +def _alive_node_ips() -> list[str]: + seen: set[str] = set() + ips: list[str] = [] + for node in ray.nodes(): + if not bool(node.get("Alive", False)): + continue + ip = str(node.get("NodeManagerAddress", "")).strip() + if not ip or ip in seen: + continue + seen.add(ip) + ips.append(ip) + return sorted(ips) + + +@ray.remote(max_retries=0) +def _train_on_node( + *, + root: str, + train_args: str, + rank: int, + world_size: int, + coordinator_ip: str, + coordinator_port: int, + base_seed: int, + run_group: str, + sync_jax: bool, +) -> int: + env = dict(os.environ) + env["PYTHONUNBUFFERED"] = "1" + requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower() + if world_size > 1 and requested_platform == "tpu": + requested_platform = "cpu" + print( + "PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs" + ) + env["JAX_PLATFORMS"] = requested_platform + # Keep each train process in single-host mode to avoid accidental global stalls. + env["CLOUD_TPU_TASK_ID"] = "0" + + cwd = str(Path(root)) + + try: + subprocess.run(["make", "data.pull"], cwd=cwd, env=env, check=True) + except (subprocess.SubprocessError, OSError): + pull_cmd = [sys.executable, "scripts/hf_data.py", "pull"] + subprocess.run(pull_cmd, cwd=cwd, env=env, check=True) + + if sync_jax and requested_platform == "tpu": + env_probe = dict(env) + env_probe["CLOUD_TPU_TASK_ID"] = str(rank) + probe = ( + "import jax; " + f"jax.distributed.initialize(coordinator_address='{coordinator_ip}:{coordinator_port}', " + f"num_processes={world_size}, process_id={rank}); " + "print('JAX_SYNC', jax.process_index(), jax.device_count(), jax.local_device_count())" + ) + subprocess.run( + [sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True + ) + + tokens = shlex.split(train_args) + if not _has_flag(tokens, "--seed"): + tokens.extend(["--seed", str(base_seed + rank)]) + if not _has_flag(tokens, "--group"): + tokens.extend(["--group", run_group]) + + cmd = [sys.executable, "-m", "engine.train", *tokens] + proc = subprocess.run(cmd, cwd=cwd, env=env) + return int(proc.returncode) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Launch one train run per Ray TPU node" + ) + parser.add_argument("--train-args", type=str, required=True) + parser.add_argument("--num-nodes", type=int, default=0) + parser.add_argument("--tpu-per-task", type=float, default=8.0) + parser.add_argument("--base-seed", type=int, default=42) + parser.add_argument("--sync-jax", action="store_true") + parser.add_argument("--coordinator-port", type=int, default=12355) + parser.add_argument("--run-group", type=str, default="") + args = parser.parse_args() + + ray.init(address="auto") + + node_ips = _alive_node_ips() + if not node_ips: + raise RuntimeError("No alive Ray nodes found") + + requested = int(args.num_nodes) + if requested > 0: + node_ips = node_ips[:requested] + + world_size = len(node_ips) + coordinator_ip = node_ips[0] + run_group = args.run_group or f"ray-dist-{int(time.time())}" + + print( + { + "nodes": node_ips, + "world_size": world_size, + "coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}", + "train_args": args.train_args, + "run_group": run_group, + } + ) + + futures = [] + root = str(Path(__file__).resolve().parents[1]) + for rank, node_ip in enumerate(node_ips): + resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)} + futures.append( + _train_on_node.options(resources=resources).remote( + root=root, + train_args=args.train_args, + rank=rank, + world_size=world_size, + coordinator_ip=coordinator_ip, + coordinator_port=int(args.coordinator_port), + base_seed=int(args.base_seed), + run_group=run_group, + sync_jax=bool(args.sync_jax), + ) + ) + + results = ray.get(futures) + failed = [code for code in results if int(code) != 0] + if failed: + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/submit_ray_job.sh b/submit_ray_job.sh new file mode 100755 index 0000000..11775d6 --- /dev/null +++ b/submit_ray_job.sh @@ -0,0 +1,81 @@ +#!/bin/bash +# Submits PHANTOM training to a Ray cluster with .env injection. +# Modes: +# RAY_MODE=single -> one run (default) +# RAY_MODE=distributed -> one run per TPU node (experimental) + +set -euo pipefail + +ROOT="/home/velocitatem/Documents/Projects/PHANTOM" +RAY_BIN="${RAY_BIN:-ray}" +if ! command -v "$RAY_BIN" >/dev/null 2>&1; then + if [ -x "$ROOT/.venv-ray/bin/ray" ]; then + RAY_BIN="$ROOT/.venv-ray/bin/ray" + else + echo "ray CLI not found. Activate .venv-ray or set RAY_BIN." >&2 + exit 1 + fi +fi + +# 1. Parse .env and generate the JSON payload for Ray +export RUNTIME_ENV_JSON=$(python -c ' +import json +import os +from dotenv import dotenv_values + +env = dotenv_values(".env") +# Filter out empty/None values +env_vars = {k: v for k, v in env.items() if v} +env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0")) + +print(json.dumps({ + "pip": [ + "stable-baselines3>=2.2.0", + "gymnasium>=0.29.0", + "wandb", + "tensorboard", + "python-dotenv", + "pandas", + "pydantic", + "graphviz", + "huggingface_hub" + ], + "env_vars": env_vars +})) +') + +RAY_MODE="${RAY_MODE:-single}" +TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}" + +COMMON_ARGS=( + job submit + --address http://localhost:8265 + --working-dir "$ROOT" + --runtime-env-json "$RUNTIME_ENV_JSON" + -- +) + +if [ "$RAY_MODE" = "single" ]; then + read -r -a TRAIN_TOKENS <<< "$TRAIN_ARGS" + "$RAY_BIN" "${COMMON_ARGS[@]}" python -m engine.train "${TRAIN_TOKENS[@]}" + exit 0 +fi + +if [ "$RAY_MODE" = "distributed" ]; then + DIST_ARGS=( + python + scripts/ray_distributed_train.py + --train-args "$TRAIN_ARGS" + --num-nodes "${NUM_NODES:-4}" + --tpu-per-task "${TPU_PER_TASK:-8}" + --base-seed "${BASE_SEED:-42}" + ) + if [ "${SYNC_JAX:-0}" = "1" ]; then + DIST_ARGS+=(--sync-jax) + fi + "$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}" + exit 0 +fi + +echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single' or 'distributed')." >&2 +exit 1