mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: bulk tpu reorchestration
This commit is contained in:
@@ -92,9 +92,9 @@ def _truthy(value: str | bool | None) -> bool:
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _alive_nodes() -> list[tuple[str, str]]:
|
||||
def _alive_nodes() -> list[tuple[str, str, bool, float]]:
|
||||
seen: set[str] = set()
|
||||
nodes: list[tuple[str, str]] = []
|
||||
nodes: list[tuple[str, str, bool, float]] = []
|
||||
for node in ray.nodes():
|
||||
if not bool(node.get("Alive", False)):
|
||||
continue
|
||||
@@ -102,9 +102,50 @@ def _alive_nodes() -> list[tuple[str, str]]:
|
||||
ip = str(node.get("NodeManagerAddress", "")).strip()
|
||||
if not node_id or not ip or node_id in seen:
|
||||
continue
|
||||
resources = node.get("Resources", {}) or {}
|
||||
is_head = bool(resources.get("node:__internal_head__", 0.0))
|
||||
tpu = float(resources.get("TPU", 0.0))
|
||||
seen.add(node_id)
|
||||
nodes.append((node_id, ip))
|
||||
return sorted(nodes, key=lambda item: (item[1], item[0]))
|
||||
nodes.append((node_id, ip, is_head, tpu))
|
||||
return sorted(nodes, key=lambda item: (item[1], item[2], -item[3], item[0]))
|
||||
|
||||
|
||||
def _dedupe_nodes_for_tpu(
|
||||
nodes: list[tuple[str, str, bool, float]],
|
||||
) -> tuple[list[tuple[str, str]], list[dict[str, str | float | bool]]]:
|
||||
selected: dict[str, tuple[str, str, bool, float]] = {}
|
||||
dropped: list[dict[str, str | float | bool]] = []
|
||||
|
||||
def _score(item: tuple[str, str, bool, float]) -> tuple[int, float, str]:
|
||||
node_id, _ip, is_head, tpu = item
|
||||
return (1 if bool(is_head) else 0, -float(tpu), str(node_id))
|
||||
|
||||
for item in nodes:
|
||||
node_id, ip, is_head, tpu = item
|
||||
existing = selected.get(ip)
|
||||
if existing is None:
|
||||
selected[ip] = item
|
||||
continue
|
||||
|
||||
keep, drop = (
|
||||
(item, existing) if _score(item) < _score(existing) else (existing, item)
|
||||
)
|
||||
selected[ip] = keep
|
||||
dropped.append(
|
||||
{
|
||||
"ip": str(ip),
|
||||
"dropped_node_id": str(drop[0]),
|
||||
"dropped_is_head": bool(drop[2]),
|
||||
"dropped_tpu": float(drop[3]),
|
||||
"kept_node_id": str(keep[0]),
|
||||
"kept_is_head": bool(keep[2]),
|
||||
"kept_tpu": float(keep[3]),
|
||||
}
|
||||
)
|
||||
|
||||
entries = [(node_id, ip) for ip, (node_id, _ip, _is_head, _tpu) in selected.items()]
|
||||
entries.sort(key=lambda item: (item[1], item[0]))
|
||||
return entries, dropped
|
||||
|
||||
|
||||
def _benchmark_cells(
|
||||
@@ -369,8 +410,19 @@ def _train_on_node(
|
||||
env["JAX_PLATFORM_NAME"] = "cpu"
|
||||
else:
|
||||
env.pop("JAX_PLATFORM_NAME", None)
|
||||
# Keep each train process in single-host mode to avoid accidental global stalls.
|
||||
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||
if requested_platform == "tpu" and world_size > 1 and allow_multi_node_tpu:
|
||||
env["CLOUD_TPU_TASK_ID"] = str(int(rank))
|
||||
print(
|
||||
{
|
||||
"rank": int(rank),
|
||||
"node_ip": str(node_ip),
|
||||
"jax_platform": "tpu",
|
||||
"cloud_tpu_task_id": str(env["CLOUD_TPU_TASK_ID"]),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Keep each process in single-host mode when TPU multi-host is disabled.
|
||||
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||
if run_kind == "benchmark":
|
||||
env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0"
|
||||
if wandb_entity:
|
||||
@@ -508,12 +560,36 @@ def main() -> None:
|
||||
|
||||
ray.init(address="auto")
|
||||
|
||||
node_entries = _alive_nodes()
|
||||
if not node_entries:
|
||||
node_records = _alive_nodes()
|
||||
if not node_records:
|
||||
raise RuntimeError("No alive Ray nodes found")
|
||||
|
||||
if float(args.tpu_per_task) > 0.0:
|
||||
node_entries, dropped = _dedupe_nodes_for_tpu(node_records)
|
||||
if dropped:
|
||||
print(
|
||||
{
|
||||
"tpu_host_dedupe": True,
|
||||
"alive_ray_nodes": len(node_records),
|
||||
"unique_tpu_hosts": len(node_entries),
|
||||
"dropped": dropped,
|
||||
}
|
||||
)
|
||||
else:
|
||||
node_entries = [
|
||||
(node_id, node_ip) for node_id, node_ip, _is_head, _tpu in node_records
|
||||
]
|
||||
|
||||
requested = int(args.num_nodes)
|
||||
if requested > 0:
|
||||
if requested > len(node_entries):
|
||||
print(
|
||||
{
|
||||
"requested_nodes": int(requested),
|
||||
"available_nodes": int(len(node_entries)),
|
||||
"note": "requested nodes exceed available hosts; capping",
|
||||
}
|
||||
)
|
||||
node_entries = node_entries[:requested]
|
||||
|
||||
world_size = len(node_entries)
|
||||
|
||||
Reference in New Issue
Block a user