chore: refactor for sweeps and IP configs

This commit is contained in:
2026-03-12 12:48:52 +01:00
parent b1f583be39
commit 88155d22a7
5 changed files with 566 additions and 34 deletions

View File

@@ -1,14 +1,18 @@
from __future__ import annotations
import argparse
import contextlib
import concurrent.futures
import os
import shlex
import subprocess
import sys
import threading
import time
from pathlib import Path
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
def _has_flag(tokens: list[str], name: str) -> bool:
@@ -24,18 +28,301 @@ def _entry_tokens(run_kind: str, entry_args: str) -> list[str]:
return tokens
def _alive_node_ips() -> list[str]:
def _get_flag_value(tokens: list[str], name: str, default: str = "") -> str:
for idx, tok in enumerate(tokens):
if tok == name and idx + 1 < len(tokens):
return str(tokens[idx + 1])
if tok.startswith(f"{name}="):
return str(tok.split("=", 1)[1])
return str(default)
def _set_flag_value(tokens: list[str], name: str, value: str) -> list[str]:
updated: list[str] = []
replaced = False
idx = 0
while idx < len(tokens):
tok = tokens[idx]
if tok == name:
replaced = True
updated.extend([name, str(value)])
idx += 2
continue
if tok.startswith(f"{name}="):
replaced = True
updated.append(f"{name}={value}")
idx += 1
continue
updated.append(tok)
idx += 1
if not replaced:
updated.extend([name, str(value)])
return updated
def _remove_flag(tokens: list[str], name: str) -> list[str]:
updated: list[str] = []
idx = 0
while idx < len(tokens):
tok = tokens[idx]
if tok == name:
idx += 1
continue
if tok.startswith(f"{name}="):
idx += 1
continue
updated.append(tok)
idx += 1
return updated
def _csv_values(raw: str) -> list[str]:
return [piece.strip() for piece in str(raw).split(",") if piece.strip()]
def _alpha_token(alpha: str) -> str:
return str(alpha).replace(".", "p").replace("-", "m")
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _alive_nodes() -> list[tuple[str, str]]:
seen: set[str] = set()
ips: list[str] = []
nodes: list[tuple[str, str]] = []
for node in ray.nodes():
if not bool(node.get("Alive", False)):
continue
node_id = str(node.get("NodeID", "")).strip()
ip = str(node.get("NodeManagerAddress", "")).strip()
if not ip or ip in seen:
if not node_id or not ip or node_id in seen:
continue
seen.add(ip)
ips.append(ip)
return sorted(ips)
seen.add(node_id)
nodes.append((node_id, ip))
return sorted(nodes, key=lambda item: (item[1], item[0]))
def _benchmark_cells(
tokens: list[str], *, compare_robust: bool
) -> list[tuple[str, str, str, bool]]:
tiers = _csv_values(
_get_flag_value(tokens, "--tiers", "static,surge,linear,qtable,ppo")
)
alphas = _csv_values(_get_flag_value(tokens, "--alpha-values", "0.0,0.3,0.6"))
base_no_robust = _has_flag(tokens, "--no-robust")
if compare_robust:
modes = [("robust", False), ("no_robust", True)]
else:
modes = [("no_robust", True)] if base_no_robust else [("robust", False)]
return [
(tier, alpha, mode_label, no_robust)
for tier in tiers
for alpha in alphas
for mode_label, no_robust in modes
]
def _thread_limited_env(env: dict[str, str], threads: int) -> dict[str, str]:
bounded = dict(env)
n = str(max(1, int(threads)))
for key in (
"OMP_NUM_THREADS",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
"NUMEXPR_NUM_THREADS",
"VECLIB_MAXIMUM_THREADS",
"BLIS_NUM_THREADS",
):
bounded[key] = n
return bounded
@contextlib.contextmanager
def _semaphore_guard(semaphore: threading.Semaphore | None):
if semaphore is None:
yield
return
semaphore.acquire()
try:
yield
finally:
semaphore.release()
def _run_benchmark_cells_parallel(
*,
root: str,
env: dict[str, str],
base_tokens: list[str],
compare_robust: bool,
inner_workers: int,
inner_threads: int,
max_heavy_workers: int,
rank: int,
) -> int:
cells = _benchmark_cells(base_tokens, compare_robust=compare_robust)
if not cells:
return 0
cwd = str(Path(root))
base_out = _get_flag_value(base_tokens, "--output-dir", "engine/studies/results")
max_workers = max(1, min(int(inner_workers), len(cells)))
heavy_tiers = {"ppo", "a2c", "dqn"}
heavy_limit = max(1, int(max_heavy_workers))
heavy_sem = threading.Semaphore(heavy_limit)
print(
{
"rank": int(rank),
"benchmark_cells": len(cells),
"inner_workers": int(max_workers),
"inner_threads": int(max(1, int(inner_threads))),
"heavy_limit": int(heavy_limit),
}
)
def _run_cell(
index: int,
total: int,
tier: str,
alpha: str,
mode_label: str,
no_robust: bool,
) -> tuple[str, str, str, int]:
tokens = list(base_tokens)
tokens = _set_flag_value(tokens, "--tiers", tier)
tokens = _set_flag_value(tokens, "--alpha-values", alpha)
if no_robust:
if not _has_flag(tokens, "--no-robust"):
tokens.append("--no-robust")
else:
tokens = _remove_flag(tokens, "--no-robust")
cell_out = (
Path(base_out)
/ f"tier_{tier}"
/ f"mode_{mode_label}"
/ f"alpha_{_alpha_token(alpha)}"
)
tokens = _set_flag_value(tokens, "--output-dir", str(cell_out))
cmd = [sys.executable, "-m", "engine.train", *tokens]
cell_env = _thread_limited_env(env, int(inner_threads))
cell_env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "0"
print(
{
"rank": int(rank),
"cell": f"{index}/{total}",
"tier": tier,
"mode": mode_label,
"alpha": alpha,
"command": " ".join(cmd),
}
)
heavy_guard = heavy_sem if str(tier).lower() in heavy_tiers else None
with _semaphore_guard(heavy_guard):
proc = subprocess.run(cmd, cwd=cwd, env=cell_env)
return tier, alpha, mode_label, int(proc.returncode)
failures: list[tuple[str, str, str, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [
pool.submit(_run_cell, idx, len(cells), tier, alpha, mode_label, no_robust)
for idx, (tier, alpha, mode_label, no_robust) in enumerate(cells, start=1)
]
for fut in concurrent.futures.as_completed(futures):
tier, alpha, mode_label, code = fut.result()
if code != 0:
failures.append((tier, alpha, mode_label, code))
if failures:
print({"rank": int(rank), "benchmark_failures": failures})
return 1
return 0
def _run_sweep_agents_parallel(
*,
root: str,
env: dict[str, str],
base_tokens: list[str],
run_kind: str,
rank: int,
agents_per_node: int,
agent_count: int,
inner_threads: int,
tpu_agent_slots: int,
) -> int:
total = max(1, int(agents_per_node))
cwd = str(Path(root))
wants_tpu = str(env.get("JAX_PLATFORMS", "")).strip().lower() == "tpu"
tpu_slots = max(0, int(tpu_agent_slots))
print(
{
"rank": int(rank),
"sweep_agents": int(total),
"agent_count": int(agent_count),
"inner_threads": int(max(1, int(inner_threads))),
"jax_platform": str(env.get("JAX_PLATFORMS", "")),
"tpu_agent_slots": int(tpu_slots),
}
)
def _run_agent(slot: int) -> int:
tokens = list(base_tokens)
if int(agent_count) > 0 and not _has_flag(tokens, "--count"):
tokens.extend(["--count", str(int(agent_count))])
if _has_flag(tokens, "--group"):
base_group = _get_flag_value(tokens, "--group", "ray-sweep")
tokens = _set_flag_value(tokens, "--group", f"{base_group}-a{slot}")
if run_kind == "benchmark":
out_dir = _get_flag_value(tokens, "--output-dir", "engine/studies/results")
tokens = _set_flag_value(
tokens, "--output-dir", str(Path(out_dir) / f"agent_{slot}")
)
if run_kind == "train":
model_dir = _get_flag_value(tokens, "--model-dir", "engine/models")
tokens = _set_flag_value(
tokens, "--model-dir", str(Path(model_dir) / f"agent_{slot}")
)
cmd = [sys.executable, "-m", "engine.train", *tokens]
agent_env = _thread_limited_env(env, int(inner_threads))
if wants_tpu and tpu_slots > 0 and int(slot) > tpu_slots:
agent_env["JAX_PLATFORMS"] = "cpu"
agent_env["JAX_PLATFORM_NAME"] = "cpu"
agent_env["PHANTOM_SWEEP_AGENT_SLOT"] = str(int(slot))
print(
{
"rank": int(rank),
"agent_slot": int(slot),
"jax_platform": str(agent_env.get("JAX_PLATFORMS", "")),
"command": " ".join(cmd),
}
)
proc = subprocess.run(cmd, cwd=cwd, env=agent_env)
return int(proc.returncode)
failures: list[tuple[int, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=total) as pool:
future_map = {
pool.submit(_run_agent, slot): slot for slot in range(1, total + 1)
}
for future in concurrent.futures.as_completed(future_map):
slot = int(future_map[future])
code = int(future.result())
if code != 0:
failures.append((slot, code))
if failures:
print({"rank": int(rank), "sweep_failures": failures})
return 1
return 0
@ray.remote(max_retries=0)
@@ -44,6 +331,8 @@ def _train_on_node(
root: str,
run_kind: str,
entry_args: str,
node_id: str,
node_ip: str,
rank: int,
world_size: int,
coordinator_ip: str,
@@ -54,17 +343,32 @@ def _train_on_node(
output_root: str,
wandb_entity: str,
wandb_project: str,
agents_per_node: int,
agent_count: int,
inner_workers: int,
inner_threads: int,
max_heavy_workers: int,
sync_jax: bool,
) -> int:
env = dict(os.environ)
env["PYTHONUNBUFFERED"] = "1"
requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower()
if world_size > 1 and requested_platform == "tpu":
allow_multi_node_tpu = _truthy(env.get("PHANTOM_ALLOW_MULTI_NODE_TPU"))
if world_size > 1 and requested_platform == "tpu" and not allow_multi_node_tpu:
requested_platform = "cpu"
print(
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs"
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs "
"(set PHANTOM_ALLOW_MULTI_NODE_TPU=1 to keep TPU for JAX workloads)"
)
elif world_size > 1 and requested_platform == "tpu" and allow_multi_node_tpu:
print(
"PHANTOM_DISTRIBUTED_NOTE: keeping JAX_PLATFORMS=tpu in multi-node mixed mode"
)
env["JAX_PLATFORMS"] = requested_platform
if requested_platform == "cpu":
env["JAX_PLATFORM_NAME"] = "cpu"
else:
env.pop("JAX_PLATFORM_NAME", None)
# Keep each train process in single-host mode to avoid accidental global stalls.
env["CLOUD_TPU_TASK_ID"] = "0"
if run_kind == "benchmark":
@@ -96,13 +400,29 @@ def _train_on_node(
)
tokens = _entry_tokens(run_kind, entry_args)
is_sweep_agent = _has_flag(tokens, "--sweep-agent")
seed = int(base_seed + rank)
if not _has_flag(tokens, "--seed"):
if not is_sweep_agent and not _has_flag(tokens, "--seed"):
tokens.extend(["--seed", str(seed)])
if run_kind == "train" and not _has_flag(tokens, "--group"):
tokens.extend(["--group", run_group])
if is_sweep_agent and int(agent_count) > 0 and not _has_flag(tokens, "--count"):
tokens.extend(["--count", str(int(agent_count))])
try:
tpu_agent_slots = int(
str(
env.get(
"PHANTOM_TPU_AGENT_SLOTS",
"1" if requested_platform == "tpu" else "0",
)
).strip()
)
except ValueError:
tpu_agent_slots = 1 if requested_platform == "tpu" else 0
if (
run_kind == "benchmark"
and output_root
@@ -112,9 +432,36 @@ def _train_on_node(
out_dir.parent.mkdir(parents=True, exist_ok=True)
tokens.extend(["--output-dir", str(out_dir)])
if is_sweep_agent and int(agents_per_node) > 1:
return _run_sweep_agents_parallel(
root=root,
env=env,
base_tokens=tokens,
run_kind=run_kind,
rank=rank,
agents_per_node=int(agents_per_node),
agent_count=int(agent_count),
inner_threads=int(inner_threads),
tpu_agent_slots=int(max(0, tpu_agent_slots)),
)
if run_kind == "benchmark" and int(inner_workers) > 1 and not is_sweep_agent:
return _run_benchmark_cells_parallel(
root=root,
env=env,
base_tokens=tokens,
compare_robust=bool(compare_robust),
inner_workers=int(inner_workers),
inner_threads=int(inner_threads),
max_heavy_workers=int(max_heavy_workers),
rank=rank,
)
cmd = [sys.executable, "-m", "engine.train", *tokens]
print(
{
"node_id": node_id,
"node_ip": node_ip,
"rank": int(rank),
"run_kind": run_kind,
"seed": int(seed),
@@ -124,7 +471,9 @@ def _train_on_node(
"command": " ".join(cmd),
}
)
proc = subprocess.run(cmd, cwd=cwd, env=env)
proc = subprocess.run(
cmd, cwd=cwd, env=_thread_limited_env(env, int(inner_threads))
)
return int(proc.returncode)
@@ -145,6 +494,12 @@ def main() -> None:
parser.add_argument("--output-root", type=str, default="")
parser.add_argument("--wandb-entity", type=str, default="")
parser.add_argument("--wandb-project", type=str, default="")
parser.add_argument("--agents-per-node", type=int, default=1)
parser.add_argument("--agent-count", type=int, default=0)
parser.add_argument("--inner-workers", type=int, default=1)
parser.add_argument("--inner-threads", type=int, default=1)
parser.add_argument("--max-heavy-workers", type=int, default=2)
parser.add_argument("--worker-cpus", type=float, default=1.0)
args = parser.parse_args()
entry_args = str(args.entry_args or args.train_args).strip()
@@ -153,21 +508,24 @@ def main() -> None:
ray.init(address="auto")
node_ips = _alive_node_ips()
if not node_ips:
node_entries = _alive_nodes()
if not node_entries:
raise RuntimeError("No alive Ray nodes found")
requested = int(args.num_nodes)
if requested > 0:
node_ips = node_ips[:requested]
node_entries = node_entries[:requested]
world_size = len(node_ips)
coordinator_ip = node_ips[0]
world_size = len(node_entries)
coordinator_ip = node_entries[0][1]
run_group = args.run_group or f"ray-dist-{int(time.time())}"
print(
{
"nodes": node_ips,
"nodes": [
{"node_id": node_id, "node_ip": node_ip}
for node_id, node_ip in node_entries
],
"world_size": world_size,
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
"run_kind": str(args.run_kind),
@@ -175,18 +533,35 @@ def main() -> None:
"run_group": run_group,
"compare_robust": bool(args.compare_robust),
"output_root": str(args.output_root),
"agents_per_node": int(args.agents_per_node),
"agent_count": int(args.agent_count),
"inner_workers": int(args.inner_workers),
"inner_threads": int(args.inner_threads),
"max_heavy_workers": int(args.max_heavy_workers),
}
)
futures = []
root = str(Path(__file__).resolve().parents[1])
for rank, node_ip in enumerate(node_ips):
resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)}
for rank, (node_id, node_ip) in enumerate(node_entries):
resources: dict[str, float] = {}
tpu_per_task = float(args.tpu_per_task)
if tpu_per_task > 0.0:
resources["TPU"] = tpu_per_task
futures.append(
_train_on_node.options(resources=resources).remote(
_train_on_node.options(
resources=resources,
num_cpus=float(args.worker_cpus),
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
root=root,
run_kind=str(args.run_kind),
entry_args=entry_args,
node_id=node_id,
node_ip=node_ip,
rank=rank,
world_size=world_size,
coordinator_ip=coordinator_ip,
@@ -197,6 +572,11 @@ def main() -> None:
output_root=str(args.output_root),
wandb_entity=str(args.wandb_entity),
wandb_project=str(args.wandb_project),
agents_per_node=int(args.agents_per_node),
agent_count=int(args.agent_count),
inner_workers=int(args.inner_workers),
inner_threads=int(args.inner_threads),
max_heavy_workers=int(args.max_heavy_workers),
sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"),
)
)