mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
setup for tpu orchestarion properly
This commit is contained in:
70
docker/TPUWatchdog.dockerfile
Normal file
70
docker/TPUWatchdog.dockerfile
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
FROM google/cloud-sdk:slim
|
||||||
|
|
||||||
|
# Install tmux to manage multiple watchdogs and jq for json parsing
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y tmux jq && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy the orchestration scripts and configs
|
||||||
|
COPY tpu_orchestration/ /app/tpu_orchestration/
|
||||||
|
|
||||||
|
# Make sure scripts are executable
|
||||||
|
RUN chmod +x /app/tpu_orchestration/watchdog.sh
|
||||||
|
RUN chmod +x /app/tpu_orchestration/tpu_startup.sh
|
||||||
|
|
||||||
|
# Create an entrypoint script that launches a watchdog for each config
|
||||||
|
COPY <<-'EOF' /app/entrypoint.sh
|
||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Make sure required variables are set
|
||||||
|
if [ -z "$HF_TOKEN" ]; then
|
||||||
|
echo "Error: HF_TOKEN environment variable is required."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$WANDB_API_KEY" ]; then
|
||||||
|
echo "Warning: WANDB_API_KEY environment variable is not set. Wandb logging may fail on TPUs."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Authenticate gcloud if credentials are provided
|
||||||
|
if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
|
||||||
|
CRED_TYPE=$(jq -r '.type' "$GOOGLE_APPLICATION_CREDENTIALS" 2>/dev/null || echo "unknown")
|
||||||
|
if [ "$CRED_TYPE" = "service_account" ]; then
|
||||||
|
echo "Authenticating gcloud using service account key..."
|
||||||
|
gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS"
|
||||||
|
|
||||||
|
# Extract project ID from the key file
|
||||||
|
PROJECT_ID=$(jq -r '.project_id' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||||
|
if [ -n "$PROJECT_ID" ] && [ "$PROJECT_ID" != "null" ]; then
|
||||||
|
gcloud config set project "$PROJECT_ID"
|
||||||
|
echo "Set project to $PROJECT_ID"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Note: Using application default credentials or mounted gcloud config..."
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Note: Assuming gcloud config is mounted from host."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run the watchdogs in the background using bash instead of tmux
|
||||||
|
# Tmux needs a TTY to attach properly which we might not have in docker
|
||||||
|
# Stagger startups by 15s to prevent simultaneous TPU creation quota hits
|
||||||
|
DELAY=0
|
||||||
|
for conf in /app/tpu_orchestration/configs/*.conf; do
|
||||||
|
echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)"
|
||||||
|
(sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") &
|
||||||
|
DELAY=$((DELAY + 15))
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "All watchdogs queued with staggered startup."
|
||||||
|
|
||||||
|
# Keep the container running
|
||||||
|
wait
|
||||||
|
EOF
|
||||||
|
|
||||||
|
RUN chmod +x /app/entrypoint.sh
|
||||||
|
|
||||||
|
CMD ["/app/entrypoint.sh"]
|
||||||
0
engine/__init__.py
Normal file
0
engine/__init__.py
Normal file
151
scripts/ray_distributed_train.py
Normal file
151
scripts/ray_distributed_train.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
|
||||||
|
def _has_flag(tokens: list[str], name: str) -> bool:
|
||||||
|
return any(tok == name or tok.startswith(f"{name}=") for tok in tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def _alive_node_ips() -> list[str]:
|
||||||
|
seen: set[str] = set()
|
||||||
|
ips: list[str] = []
|
||||||
|
for node in ray.nodes():
|
||||||
|
if not bool(node.get("Alive", False)):
|
||||||
|
continue
|
||||||
|
ip = str(node.get("NodeManagerAddress", "")).strip()
|
||||||
|
if not ip or ip in seen:
|
||||||
|
continue
|
||||||
|
seen.add(ip)
|
||||||
|
ips.append(ip)
|
||||||
|
return sorted(ips)
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(max_retries=0)
|
||||||
|
def _train_on_node(
|
||||||
|
*,
|
||||||
|
root: str,
|
||||||
|
train_args: str,
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
coordinator_ip: str,
|
||||||
|
coordinator_port: int,
|
||||||
|
base_seed: int,
|
||||||
|
run_group: str,
|
||||||
|
sync_jax: bool,
|
||||||
|
) -> int:
|
||||||
|
env = dict(os.environ)
|
||||||
|
env["PYTHONUNBUFFERED"] = "1"
|
||||||
|
requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower()
|
||||||
|
if world_size > 1 and requested_platform == "tpu":
|
||||||
|
requested_platform = "cpu"
|
||||||
|
print(
|
||||||
|
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs"
|
||||||
|
)
|
||||||
|
env["JAX_PLATFORMS"] = requested_platform
|
||||||
|
# Keep each train process in single-host mode to avoid accidental global stalls.
|
||||||
|
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||||
|
|
||||||
|
cwd = str(Path(root))
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(["make", "data.pull"], cwd=cwd, env=env, check=True)
|
||||||
|
except (subprocess.SubprocessError, OSError):
|
||||||
|
pull_cmd = [sys.executable, "scripts/hf_data.py", "pull"]
|
||||||
|
subprocess.run(pull_cmd, cwd=cwd, env=env, check=True)
|
||||||
|
|
||||||
|
if sync_jax and requested_platform == "tpu":
|
||||||
|
env_probe = dict(env)
|
||||||
|
env_probe["CLOUD_TPU_TASK_ID"] = str(rank)
|
||||||
|
probe = (
|
||||||
|
"import jax; "
|
||||||
|
f"jax.distributed.initialize(coordinator_address='{coordinator_ip}:{coordinator_port}', "
|
||||||
|
f"num_processes={world_size}, process_id={rank}); "
|
||||||
|
"print('JAX_SYNC', jax.process_index(), jax.device_count(), jax.local_device_count())"
|
||||||
|
)
|
||||||
|
subprocess.run(
|
||||||
|
[sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = shlex.split(train_args)
|
||||||
|
if not _has_flag(tokens, "--seed"):
|
||||||
|
tokens.extend(["--seed", str(base_seed + rank)])
|
||||||
|
if not _has_flag(tokens, "--group"):
|
||||||
|
tokens.extend(["--group", run_group])
|
||||||
|
|
||||||
|
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
||||||
|
proc = subprocess.run(cmd, cwd=cwd, env=env)
|
||||||
|
return int(proc.returncode)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Launch one train run per Ray TPU node"
|
||||||
|
)
|
||||||
|
parser.add_argument("--train-args", type=str, required=True)
|
||||||
|
parser.add_argument("--num-nodes", type=int, default=0)
|
||||||
|
parser.add_argument("--tpu-per-task", type=float, default=8.0)
|
||||||
|
parser.add_argument("--base-seed", type=int, default=42)
|
||||||
|
parser.add_argument("--sync-jax", action="store_true")
|
||||||
|
parser.add_argument("--coordinator-port", type=int, default=12355)
|
||||||
|
parser.add_argument("--run-group", type=str, default="")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ray.init(address="auto")
|
||||||
|
|
||||||
|
node_ips = _alive_node_ips()
|
||||||
|
if not node_ips:
|
||||||
|
raise RuntimeError("No alive Ray nodes found")
|
||||||
|
|
||||||
|
requested = int(args.num_nodes)
|
||||||
|
if requested > 0:
|
||||||
|
node_ips = node_ips[:requested]
|
||||||
|
|
||||||
|
world_size = len(node_ips)
|
||||||
|
coordinator_ip = node_ips[0]
|
||||||
|
run_group = args.run_group or f"ray-dist-{int(time.time())}"
|
||||||
|
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"nodes": node_ips,
|
||||||
|
"world_size": world_size,
|
||||||
|
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
|
||||||
|
"train_args": args.train_args,
|
||||||
|
"run_group": run_group,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
futures = []
|
||||||
|
root = str(Path(__file__).resolve().parents[1])
|
||||||
|
for rank, node_ip in enumerate(node_ips):
|
||||||
|
resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)}
|
||||||
|
futures.append(
|
||||||
|
_train_on_node.options(resources=resources).remote(
|
||||||
|
root=root,
|
||||||
|
train_args=args.train_args,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
coordinator_ip=coordinator_ip,
|
||||||
|
coordinator_port=int(args.coordinator_port),
|
||||||
|
base_seed=int(args.base_seed),
|
||||||
|
run_group=run_group,
|
||||||
|
sync_jax=bool(args.sync_jax),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = ray.get(futures)
|
||||||
|
failed = [code for code in results if int(code) != 0]
|
||||||
|
if failed:
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
81
submit_ray_job.sh
Executable file
81
submit_ray_job.sh
Executable file
@@ -0,0 +1,81 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Submits PHANTOM training to a Ray cluster with .env injection.
|
||||||
|
# Modes:
|
||||||
|
# RAY_MODE=single -> one run (default)
|
||||||
|
# RAY_MODE=distributed -> one run per TPU node (experimental)
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
ROOT="/home/velocitatem/Documents/Projects/PHANTOM"
|
||||||
|
RAY_BIN="${RAY_BIN:-ray}"
|
||||||
|
if ! command -v "$RAY_BIN" >/dev/null 2>&1; then
|
||||||
|
if [ -x "$ROOT/.venv-ray/bin/ray" ]; then
|
||||||
|
RAY_BIN="$ROOT/.venv-ray/bin/ray"
|
||||||
|
else
|
||||||
|
echo "ray CLI not found. Activate .venv-ray or set RAY_BIN." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 1. Parse .env and generate the JSON payload for Ray
|
||||||
|
export RUNTIME_ENV_JSON=$(python -c '
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dotenv import dotenv_values
|
||||||
|
|
||||||
|
env = dotenv_values(".env")
|
||||||
|
# Filter out empty/None values
|
||||||
|
env_vars = {k: v for k, v in env.items() if v}
|
||||||
|
env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0"))
|
||||||
|
|
||||||
|
print(json.dumps({
|
||||||
|
"pip": [
|
||||||
|
"stable-baselines3>=2.2.0",
|
||||||
|
"gymnasium>=0.29.0",
|
||||||
|
"wandb",
|
||||||
|
"tensorboard",
|
||||||
|
"python-dotenv",
|
||||||
|
"pandas",
|
||||||
|
"pydantic",
|
||||||
|
"graphviz",
|
||||||
|
"huggingface_hub"
|
||||||
|
],
|
||||||
|
"env_vars": env_vars
|
||||||
|
}))
|
||||||
|
')
|
||||||
|
|
||||||
|
RAY_MODE="${RAY_MODE:-single}"
|
||||||
|
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}"
|
||||||
|
|
||||||
|
COMMON_ARGS=(
|
||||||
|
job submit
|
||||||
|
--address http://localhost:8265
|
||||||
|
--working-dir "$ROOT"
|
||||||
|
--runtime-env-json "$RUNTIME_ENV_JSON"
|
||||||
|
--
|
||||||
|
)
|
||||||
|
|
||||||
|
if [ "$RAY_MODE" = "single" ]; then
|
||||||
|
read -r -a TRAIN_TOKENS <<< "$TRAIN_ARGS"
|
||||||
|
"$RAY_BIN" "${COMMON_ARGS[@]}" python -m engine.train "${TRAIN_TOKENS[@]}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$RAY_MODE" = "distributed" ]; then
|
||||||
|
DIST_ARGS=(
|
||||||
|
python
|
||||||
|
scripts/ray_distributed_train.py
|
||||||
|
--train-args "$TRAIN_ARGS"
|
||||||
|
--num-nodes "${NUM_NODES:-4}"
|
||||||
|
--tpu-per-task "${TPU_PER_TASK:-8}"
|
||||||
|
--base-seed "${BASE_SEED:-42}"
|
||||||
|
)
|
||||||
|
if [ "${SYNC_JAX:-0}" = "1" ]; then
|
||||||
|
DIST_ARGS+=(--sync-jax)
|
||||||
|
fi
|
||||||
|
"$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single' or 'distributed')." >&2
|
||||||
|
exit 1
|
||||||
Reference in New Issue
Block a user