chore: bulk tpu reorchestration

This commit is contained in:
2026-03-15 21:14:41 +01:00
parent 52b4dcdce3
commit a9c091050c
10 changed files with 155 additions and 42 deletions

View File

@@ -92,9 +92,9 @@ def _truthy(value: str | bool | None) -> bool:
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _alive_nodes() -> list[tuple[str, str]]:
def _alive_nodes() -> list[tuple[str, str, bool, float]]:
seen: set[str] = set()
nodes: list[tuple[str, str]] = []
nodes: list[tuple[str, str, bool, float]] = []
for node in ray.nodes():
if not bool(node.get("Alive", False)):
continue
@@ -102,9 +102,50 @@ def _alive_nodes() -> list[tuple[str, str]]:
ip = str(node.get("NodeManagerAddress", "")).strip()
if not node_id or not ip or node_id in seen:
continue
resources = node.get("Resources", {}) or {}
is_head = bool(resources.get("node:__internal_head__", 0.0))
tpu = float(resources.get("TPU", 0.0))
seen.add(node_id)
nodes.append((node_id, ip))
return sorted(nodes, key=lambda item: (item[1], item[0]))
nodes.append((node_id, ip, is_head, tpu))
return sorted(nodes, key=lambda item: (item[1], item[2], -item[3], item[0]))
def _dedupe_nodes_for_tpu(
nodes: list[tuple[str, str, bool, float]],
) -> tuple[list[tuple[str, str]], list[dict[str, str | float | bool]]]:
selected: dict[str, tuple[str, str, bool, float]] = {}
dropped: list[dict[str, str | float | bool]] = []
def _score(item: tuple[str, str, bool, float]) -> tuple[int, float, str]:
node_id, _ip, is_head, tpu = item
return (1 if bool(is_head) else 0, -float(tpu), str(node_id))
for item in nodes:
node_id, ip, is_head, tpu = item
existing = selected.get(ip)
if existing is None:
selected[ip] = item
continue
keep, drop = (
(item, existing) if _score(item) < _score(existing) else (existing, item)
)
selected[ip] = keep
dropped.append(
{
"ip": str(ip),
"dropped_node_id": str(drop[0]),
"dropped_is_head": bool(drop[2]),
"dropped_tpu": float(drop[3]),
"kept_node_id": str(keep[0]),
"kept_is_head": bool(keep[2]),
"kept_tpu": float(keep[3]),
}
)
entries = [(node_id, ip) for ip, (node_id, _ip, _is_head, _tpu) in selected.items()]
entries.sort(key=lambda item: (item[1], item[0]))
return entries, dropped
def _benchmark_cells(
@@ -369,8 +410,19 @@ def _train_on_node(
env["JAX_PLATFORM_NAME"] = "cpu"
else:
env.pop("JAX_PLATFORM_NAME", None)
# Keep each train process in single-host mode to avoid accidental global stalls.
env["CLOUD_TPU_TASK_ID"] = "0"
if requested_platform == "tpu" and world_size > 1 and allow_multi_node_tpu:
env["CLOUD_TPU_TASK_ID"] = str(int(rank))
print(
{
"rank": int(rank),
"node_ip": str(node_ip),
"jax_platform": "tpu",
"cloud_tpu_task_id": str(env["CLOUD_TPU_TASK_ID"]),
}
)
else:
# Keep each process in single-host mode when TPU multi-host is disabled.
env["CLOUD_TPU_TASK_ID"] = "0"
if run_kind == "benchmark":
env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0"
if wandb_entity:
@@ -508,12 +560,36 @@ def main() -> None:
ray.init(address="auto")
node_entries = _alive_nodes()
if not node_entries:
node_records = _alive_nodes()
if not node_records:
raise RuntimeError("No alive Ray nodes found")
if float(args.tpu_per_task) > 0.0:
node_entries, dropped = _dedupe_nodes_for_tpu(node_records)
if dropped:
print(
{
"tpu_host_dedupe": True,
"alive_ray_nodes": len(node_records),
"unique_tpu_hosts": len(node_entries),
"dropped": dropped,
}
)
else:
node_entries = [
(node_id, node_ip) for node_id, node_ip, _is_head, _tpu in node_records
]
requested = int(args.num_nodes)
if requested > 0:
if requested > len(node_entries):
print(
{
"requested_nodes": int(requested),
"available_nodes": int(len(node_entries)),
"note": "requested nodes exceed available hosts; capping",
}
)
node_entries = node_entries[:requested]
world_size = len(node_entries)