mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
setup for tpu orchestarion properly
This commit is contained in:
151
scripts/ray_distributed_train.py
Normal file
151
scripts/ray_distributed_train.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def _has_flag(tokens: list[str], name: str) -> bool:
|
||||
return any(tok == name or tok.startswith(f"{name}=") for tok in tokens)
|
||||
|
||||
|
||||
def _alive_node_ips() -> list[str]:
|
||||
seen: set[str] = set()
|
||||
ips: list[str] = []
|
||||
for node in ray.nodes():
|
||||
if not bool(node.get("Alive", False)):
|
||||
continue
|
||||
ip = str(node.get("NodeManagerAddress", "")).strip()
|
||||
if not ip or ip in seen:
|
||||
continue
|
||||
seen.add(ip)
|
||||
ips.append(ip)
|
||||
return sorted(ips)
|
||||
|
||||
|
||||
@ray.remote(max_retries=0)
|
||||
def _train_on_node(
|
||||
*,
|
||||
root: str,
|
||||
train_args: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
coordinator_ip: str,
|
||||
coordinator_port: int,
|
||||
base_seed: int,
|
||||
run_group: str,
|
||||
sync_jax: bool,
|
||||
) -> int:
|
||||
env = dict(os.environ)
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower()
|
||||
if world_size > 1 and requested_platform == "tpu":
|
||||
requested_platform = "cpu"
|
||||
print(
|
||||
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs"
|
||||
)
|
||||
env["JAX_PLATFORMS"] = requested_platform
|
||||
# Keep each train process in single-host mode to avoid accidental global stalls.
|
||||
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||
|
||||
cwd = str(Path(root))
|
||||
|
||||
try:
|
||||
subprocess.run(["make", "data.pull"], cwd=cwd, env=env, check=True)
|
||||
except (subprocess.SubprocessError, OSError):
|
||||
pull_cmd = [sys.executable, "scripts/hf_data.py", "pull"]
|
||||
subprocess.run(pull_cmd, cwd=cwd, env=env, check=True)
|
||||
|
||||
if sync_jax and requested_platform == "tpu":
|
||||
env_probe = dict(env)
|
||||
env_probe["CLOUD_TPU_TASK_ID"] = str(rank)
|
||||
probe = (
|
||||
"import jax; "
|
||||
f"jax.distributed.initialize(coordinator_address='{coordinator_ip}:{coordinator_port}', "
|
||||
f"num_processes={world_size}, process_id={rank}); "
|
||||
"print('JAX_SYNC', jax.process_index(), jax.device_count(), jax.local_device_count())"
|
||||
)
|
||||
subprocess.run(
|
||||
[sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True
|
||||
)
|
||||
|
||||
tokens = shlex.split(train_args)
|
||||
if not _has_flag(tokens, "--seed"):
|
||||
tokens.extend(["--seed", str(base_seed + rank)])
|
||||
if not _has_flag(tokens, "--group"):
|
||||
tokens.extend(["--group", run_group])
|
||||
|
||||
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
||||
proc = subprocess.run(cmd, cwd=cwd, env=env)
|
||||
return int(proc.returncode)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Launch one train run per Ray TPU node"
|
||||
)
|
||||
parser.add_argument("--train-args", type=str, required=True)
|
||||
parser.add_argument("--num-nodes", type=int, default=0)
|
||||
parser.add_argument("--tpu-per-task", type=float, default=8.0)
|
||||
parser.add_argument("--base-seed", type=int, default=42)
|
||||
parser.add_argument("--sync-jax", action="store_true")
|
||||
parser.add_argument("--coordinator-port", type=int, default=12355)
|
||||
parser.add_argument("--run-group", type=str, default="")
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init(address="auto")
|
||||
|
||||
node_ips = _alive_node_ips()
|
||||
if not node_ips:
|
||||
raise RuntimeError("No alive Ray nodes found")
|
||||
|
||||
requested = int(args.num_nodes)
|
||||
if requested > 0:
|
||||
node_ips = node_ips[:requested]
|
||||
|
||||
world_size = len(node_ips)
|
||||
coordinator_ip = node_ips[0]
|
||||
run_group = args.run_group or f"ray-dist-{int(time.time())}"
|
||||
|
||||
print(
|
||||
{
|
||||
"nodes": node_ips,
|
||||
"world_size": world_size,
|
||||
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
|
||||
"train_args": args.train_args,
|
||||
"run_group": run_group,
|
||||
}
|
||||
)
|
||||
|
||||
futures = []
|
||||
root = str(Path(__file__).resolve().parents[1])
|
||||
for rank, node_ip in enumerate(node_ips):
|
||||
resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)}
|
||||
futures.append(
|
||||
_train_on_node.options(resources=resources).remote(
|
||||
root=root,
|
||||
train_args=args.train_args,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
coordinator_ip=coordinator_ip,
|
||||
coordinator_port=int(args.coordinator_port),
|
||||
base_seed=int(args.base_seed),
|
||||
run_group=run_group,
|
||||
sync_jax=bool(args.sync_jax),
|
||||
)
|
||||
)
|
||||
|
||||
results = ray.get(futures)
|
||||
failed = [code for code in results if int(code) != 0]
|
||||
if failed:
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user