chore: refactor for sweeps and IP configs

This commit is contained in:
2026-03-12 12:48:52 +01:00
parent b1f583be39
commit 88155d22a7
5 changed files with 566 additions and 34 deletions

View File

@@ -10,9 +10,13 @@ services:
- HF_TOKEN=${HF_TOKEN} - HF_TOKEN=${HF_TOKEN}
- WANDB_API_KEY=${WANDB_API_KEY} - WANDB_API_KEY=${WANDB_API_KEY}
- GITHUB_TOKEN=${GITHUB_TOKEN} - GITHUB_TOKEN=${GITHUB_TOKEN}
- GOOGLE_APPLICATION_CREDENTIALS=/secrets/gcp-sa.json
- GCP_ACCOUNT=${GCP_ACCOUNT:-}
- WATCHDOG_CONFIG_PATTERN=${WATCHDOG_CONFIG_PATTERN:-v6e_*.conf}
- CLOUDSDK_CONFIG=/.config/gcloud - CLOUDSDK_CONFIG=/.config/gcloud
volumes: volumes:
- ~/.config/gcloud:/.config/gcloud:rw - ~/.config/gcloud:/.config/gcloud:rw
- ./secrets/gcp-sa.json:/secrets/gcp-sa.json:ro
tensorboard-rl: tensorboard-rl:
image: tensorflow/tensorflow:latest image: tensorflow/tensorflow:latest

View File

@@ -35,25 +35,55 @@ if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENT
if [ "$CRED_TYPE" = "service_account" ]; then if [ "$CRED_TYPE" = "service_account" ]; then
echo "Authenticating gcloud using service account key..." echo "Authenticating gcloud using service account key..."
gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS" gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS"
# Extract project ID from the key file if [ -z "$PROJECT_ID" ]; then
PROJECT_ID=$(jq -r '.project_id' "$GOOGLE_APPLICATION_CREDENTIALS") PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
if [ -n "$PROJECT_ID" ] && [ "$PROJECT_ID" != "null" ]; then
gcloud config set project "$PROJECT_ID"
echo "Set project to $PROJECT_ID"
fi fi
elif [ "$CRED_TYPE" = "authorized_user" ]; then
echo "Authenticating gcloud using authorized_user refresh token..."
AUTH_ACCOUNT="$GCP_ACCOUNT"
if [ -z "$AUTH_ACCOUNT" ]; then
AUTH_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
fi
if [ -z "$AUTH_ACCOUNT" ]; then
AUTH_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
fi
REFRESH_TOKEN=$(jq -r '.refresh_token // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
if [ -z "$AUTH_ACCOUNT" ] || [ -z "$REFRESH_TOKEN" ]; then
echo "Error: authorized_user credentials require GCP_ACCOUNT (or embedded account) and refresh_token."
exit 1
fi
gcloud auth activate-refresh-token "$AUTH_ACCOUNT" "$REFRESH_TOKEN"
else else
echo "Note: Using application default credentials or mounted gcloud config..." echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config."
fi fi
else else
echo "Note: Assuming gcloud config is mounted from host." echo "Note: Assuming gcloud config is mounted from host."
fi fi
if [ -n "$PROJECT_ID" ]; then
gcloud config set project "$PROJECT_ID"
echo "Set project to $PROJECT_ID"
fi
# Run the watchdogs in the background using bash instead of tmux # Run the watchdogs in the background using bash instead of tmux
# Tmux needs a TTY to attach properly which we might not have in docker # Tmux needs a TTY to attach properly which we might not have in docker
# Stagger startups by 15s to prevent simultaneous TPU creation quota hits # Stagger startups by 15s to prevent simultaneous TPU creation quota hits
CONFIG_PATTERN=${WATCHDOG_CONFIG_PATTERN:-"*.conf"}
shopt -s nullglob
CONFIGS=(/app/tpu_orchestration/configs/$CONFIG_PATTERN)
if [ ${#CONFIGS[@]} -eq 0 ]; then
echo "Error: no watchdog configs matched pattern '$CONFIG_PATTERN'."
exit 1
fi
echo "Using watchdog config pattern: $CONFIG_PATTERN"
DELAY=0 DELAY=0
for conf in /app/tpu_orchestration/configs/*.conf; do for conf in "${CONFIGS[@]}"; do
echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)" echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)"
(sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") & (sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") &
DELAY=$((DELAY + 15)) DELAY=$((DELAY + 15))
@@ -67,4 +97,4 @@ EOF
RUN chmod +x /app/entrypoint.sh RUN chmod +x /app/entrypoint.sh
CMD ["/app/entrypoint.sh"] CMD ["/app/entrypoint.sh"]

View File

@@ -1,14 +1,18 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import contextlib
import concurrent.futures
import os import os
import shlex import shlex
import subprocess import subprocess
import sys import sys
import threading
import time import time
from pathlib import Path from pathlib import Path
import ray import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
def _has_flag(tokens: list[str], name: str) -> bool: def _has_flag(tokens: list[str], name: str) -> bool:
@@ -24,18 +28,301 @@ def _entry_tokens(run_kind: str, entry_args: str) -> list[str]:
return tokens return tokens
def _alive_node_ips() -> list[str]: def _get_flag_value(tokens: list[str], name: str, default: str = "") -> str:
for idx, tok in enumerate(tokens):
if tok == name and idx + 1 < len(tokens):
return str(tokens[idx + 1])
if tok.startswith(f"{name}="):
return str(tok.split("=", 1)[1])
return str(default)
def _set_flag_value(tokens: list[str], name: str, value: str) -> list[str]:
updated: list[str] = []
replaced = False
idx = 0
while idx < len(tokens):
tok = tokens[idx]
if tok == name:
replaced = True
updated.extend([name, str(value)])
idx += 2
continue
if tok.startswith(f"{name}="):
replaced = True
updated.append(f"{name}={value}")
idx += 1
continue
updated.append(tok)
idx += 1
if not replaced:
updated.extend([name, str(value)])
return updated
def _remove_flag(tokens: list[str], name: str) -> list[str]:
updated: list[str] = []
idx = 0
while idx < len(tokens):
tok = tokens[idx]
if tok == name:
idx += 1
continue
if tok.startswith(f"{name}="):
idx += 1
continue
updated.append(tok)
idx += 1
return updated
def _csv_values(raw: str) -> list[str]:
return [piece.strip() for piece in str(raw).split(",") if piece.strip()]
def _alpha_token(alpha: str) -> str:
return str(alpha).replace(".", "p").replace("-", "m")
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _alive_nodes() -> list[tuple[str, str]]:
seen: set[str] = set() seen: set[str] = set()
ips: list[str] = [] nodes: list[tuple[str, str]] = []
for node in ray.nodes(): for node in ray.nodes():
if not bool(node.get("Alive", False)): if not bool(node.get("Alive", False)):
continue continue
node_id = str(node.get("NodeID", "")).strip()
ip = str(node.get("NodeManagerAddress", "")).strip() ip = str(node.get("NodeManagerAddress", "")).strip()
if not ip or ip in seen: if not node_id or not ip or node_id in seen:
continue continue
seen.add(ip) seen.add(node_id)
ips.append(ip) nodes.append((node_id, ip))
return sorted(ips) return sorted(nodes, key=lambda item: (item[1], item[0]))
def _benchmark_cells(
tokens: list[str], *, compare_robust: bool
) -> list[tuple[str, str, str, bool]]:
tiers = _csv_values(
_get_flag_value(tokens, "--tiers", "static,surge,linear,qtable,ppo")
)
alphas = _csv_values(_get_flag_value(tokens, "--alpha-values", "0.0,0.3,0.6"))
base_no_robust = _has_flag(tokens, "--no-robust")
if compare_robust:
modes = [("robust", False), ("no_robust", True)]
else:
modes = [("no_robust", True)] if base_no_robust else [("robust", False)]
return [
(tier, alpha, mode_label, no_robust)
for tier in tiers
for alpha in alphas
for mode_label, no_robust in modes
]
def _thread_limited_env(env: dict[str, str], threads: int) -> dict[str, str]:
bounded = dict(env)
n = str(max(1, int(threads)))
for key in (
"OMP_NUM_THREADS",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
"NUMEXPR_NUM_THREADS",
"VECLIB_MAXIMUM_THREADS",
"BLIS_NUM_THREADS",
):
bounded[key] = n
return bounded
@contextlib.contextmanager
def _semaphore_guard(semaphore: threading.Semaphore | None):
if semaphore is None:
yield
return
semaphore.acquire()
try:
yield
finally:
semaphore.release()
def _run_benchmark_cells_parallel(
*,
root: str,
env: dict[str, str],
base_tokens: list[str],
compare_robust: bool,
inner_workers: int,
inner_threads: int,
max_heavy_workers: int,
rank: int,
) -> int:
cells = _benchmark_cells(base_tokens, compare_robust=compare_robust)
if not cells:
return 0
cwd = str(Path(root))
base_out = _get_flag_value(base_tokens, "--output-dir", "engine/studies/results")
max_workers = max(1, min(int(inner_workers), len(cells)))
heavy_tiers = {"ppo", "a2c", "dqn"}
heavy_limit = max(1, int(max_heavy_workers))
heavy_sem = threading.Semaphore(heavy_limit)
print(
{
"rank": int(rank),
"benchmark_cells": len(cells),
"inner_workers": int(max_workers),
"inner_threads": int(max(1, int(inner_threads))),
"heavy_limit": int(heavy_limit),
}
)
def _run_cell(
index: int,
total: int,
tier: str,
alpha: str,
mode_label: str,
no_robust: bool,
) -> tuple[str, str, str, int]:
tokens = list(base_tokens)
tokens = _set_flag_value(tokens, "--tiers", tier)
tokens = _set_flag_value(tokens, "--alpha-values", alpha)
if no_robust:
if not _has_flag(tokens, "--no-robust"):
tokens.append("--no-robust")
else:
tokens = _remove_flag(tokens, "--no-robust")
cell_out = (
Path(base_out)
/ f"tier_{tier}"
/ f"mode_{mode_label}"
/ f"alpha_{_alpha_token(alpha)}"
)
tokens = _set_flag_value(tokens, "--output-dir", str(cell_out))
cmd = [sys.executable, "-m", "engine.train", *tokens]
cell_env = _thread_limited_env(env, int(inner_threads))
cell_env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "0"
print(
{
"rank": int(rank),
"cell": f"{index}/{total}",
"tier": tier,
"mode": mode_label,
"alpha": alpha,
"command": " ".join(cmd),
}
)
heavy_guard = heavy_sem if str(tier).lower() in heavy_tiers else None
with _semaphore_guard(heavy_guard):
proc = subprocess.run(cmd, cwd=cwd, env=cell_env)
return tier, alpha, mode_label, int(proc.returncode)
failures: list[tuple[str, str, str, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [
pool.submit(_run_cell, idx, len(cells), tier, alpha, mode_label, no_robust)
for idx, (tier, alpha, mode_label, no_robust) in enumerate(cells, start=1)
]
for fut in concurrent.futures.as_completed(futures):
tier, alpha, mode_label, code = fut.result()
if code != 0:
failures.append((tier, alpha, mode_label, code))
if failures:
print({"rank": int(rank), "benchmark_failures": failures})
return 1
return 0
def _run_sweep_agents_parallel(
*,
root: str,
env: dict[str, str],
base_tokens: list[str],
run_kind: str,
rank: int,
agents_per_node: int,
agent_count: int,
inner_threads: int,
tpu_agent_slots: int,
) -> int:
total = max(1, int(agents_per_node))
cwd = str(Path(root))
wants_tpu = str(env.get("JAX_PLATFORMS", "")).strip().lower() == "tpu"
tpu_slots = max(0, int(tpu_agent_slots))
print(
{
"rank": int(rank),
"sweep_agents": int(total),
"agent_count": int(agent_count),
"inner_threads": int(max(1, int(inner_threads))),
"jax_platform": str(env.get("JAX_PLATFORMS", "")),
"tpu_agent_slots": int(tpu_slots),
}
)
def _run_agent(slot: int) -> int:
tokens = list(base_tokens)
if int(agent_count) > 0 and not _has_flag(tokens, "--count"):
tokens.extend(["--count", str(int(agent_count))])
if _has_flag(tokens, "--group"):
base_group = _get_flag_value(tokens, "--group", "ray-sweep")
tokens = _set_flag_value(tokens, "--group", f"{base_group}-a{slot}")
if run_kind == "benchmark":
out_dir = _get_flag_value(tokens, "--output-dir", "engine/studies/results")
tokens = _set_flag_value(
tokens, "--output-dir", str(Path(out_dir) / f"agent_{slot}")
)
if run_kind == "train":
model_dir = _get_flag_value(tokens, "--model-dir", "engine/models")
tokens = _set_flag_value(
tokens, "--model-dir", str(Path(model_dir) / f"agent_{slot}")
)
cmd = [sys.executable, "-m", "engine.train", *tokens]
agent_env = _thread_limited_env(env, int(inner_threads))
if wants_tpu and tpu_slots > 0 and int(slot) > tpu_slots:
agent_env["JAX_PLATFORMS"] = "cpu"
agent_env["JAX_PLATFORM_NAME"] = "cpu"
agent_env["PHANTOM_SWEEP_AGENT_SLOT"] = str(int(slot))
print(
{
"rank": int(rank),
"agent_slot": int(slot),
"jax_platform": str(agent_env.get("JAX_PLATFORMS", "")),
"command": " ".join(cmd),
}
)
proc = subprocess.run(cmd, cwd=cwd, env=agent_env)
return int(proc.returncode)
failures: list[tuple[int, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=total) as pool:
future_map = {
pool.submit(_run_agent, slot): slot for slot in range(1, total + 1)
}
for future in concurrent.futures.as_completed(future_map):
slot = int(future_map[future])
code = int(future.result())
if code != 0:
failures.append((slot, code))
if failures:
print({"rank": int(rank), "sweep_failures": failures})
return 1
return 0
@ray.remote(max_retries=0) @ray.remote(max_retries=0)
@@ -44,6 +331,8 @@ def _train_on_node(
root: str, root: str,
run_kind: str, run_kind: str,
entry_args: str, entry_args: str,
node_id: str,
node_ip: str,
rank: int, rank: int,
world_size: int, world_size: int,
coordinator_ip: str, coordinator_ip: str,
@@ -54,17 +343,32 @@ def _train_on_node(
output_root: str, output_root: str,
wandb_entity: str, wandb_entity: str,
wandb_project: str, wandb_project: str,
agents_per_node: int,
agent_count: int,
inner_workers: int,
inner_threads: int,
max_heavy_workers: int,
sync_jax: bool, sync_jax: bool,
) -> int: ) -> int:
env = dict(os.environ) env = dict(os.environ)
env["PYTHONUNBUFFERED"] = "1" env["PYTHONUNBUFFERED"] = "1"
requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower() requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower()
if world_size > 1 and requested_platform == "tpu": allow_multi_node_tpu = _truthy(env.get("PHANTOM_ALLOW_MULTI_NODE_TPU"))
if world_size > 1 and requested_platform == "tpu" and not allow_multi_node_tpu:
requested_platform = "cpu" requested_platform = "cpu"
print( print(
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs" "PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs "
"(set PHANTOM_ALLOW_MULTI_NODE_TPU=1 to keep TPU for JAX workloads)"
)
elif world_size > 1 and requested_platform == "tpu" and allow_multi_node_tpu:
print(
"PHANTOM_DISTRIBUTED_NOTE: keeping JAX_PLATFORMS=tpu in multi-node mixed mode"
) )
env["JAX_PLATFORMS"] = requested_platform env["JAX_PLATFORMS"] = requested_platform
if requested_platform == "cpu":
env["JAX_PLATFORM_NAME"] = "cpu"
else:
env.pop("JAX_PLATFORM_NAME", None)
# Keep each train process in single-host mode to avoid accidental global stalls. # Keep each train process in single-host mode to avoid accidental global stalls.
env["CLOUD_TPU_TASK_ID"] = "0" env["CLOUD_TPU_TASK_ID"] = "0"
if run_kind == "benchmark": if run_kind == "benchmark":
@@ -96,13 +400,29 @@ def _train_on_node(
) )
tokens = _entry_tokens(run_kind, entry_args) tokens = _entry_tokens(run_kind, entry_args)
is_sweep_agent = _has_flag(tokens, "--sweep-agent")
seed = int(base_seed + rank) seed = int(base_seed + rank)
if not _has_flag(tokens, "--seed"): if not is_sweep_agent and not _has_flag(tokens, "--seed"):
tokens.extend(["--seed", str(seed)]) tokens.extend(["--seed", str(seed)])
if run_kind == "train" and not _has_flag(tokens, "--group"): if run_kind == "train" and not _has_flag(tokens, "--group"):
tokens.extend(["--group", run_group]) tokens.extend(["--group", run_group])
if is_sweep_agent and int(agent_count) > 0 and not _has_flag(tokens, "--count"):
tokens.extend(["--count", str(int(agent_count))])
try:
tpu_agent_slots = int(
str(
env.get(
"PHANTOM_TPU_AGENT_SLOTS",
"1" if requested_platform == "tpu" else "0",
)
).strip()
)
except ValueError:
tpu_agent_slots = 1 if requested_platform == "tpu" else 0
if ( if (
run_kind == "benchmark" run_kind == "benchmark"
and output_root and output_root
@@ -112,9 +432,36 @@ def _train_on_node(
out_dir.parent.mkdir(parents=True, exist_ok=True) out_dir.parent.mkdir(parents=True, exist_ok=True)
tokens.extend(["--output-dir", str(out_dir)]) tokens.extend(["--output-dir", str(out_dir)])
if is_sweep_agent and int(agents_per_node) > 1:
return _run_sweep_agents_parallel(
root=root,
env=env,
base_tokens=tokens,
run_kind=run_kind,
rank=rank,
agents_per_node=int(agents_per_node),
agent_count=int(agent_count),
inner_threads=int(inner_threads),
tpu_agent_slots=int(max(0, tpu_agent_slots)),
)
if run_kind == "benchmark" and int(inner_workers) > 1 and not is_sweep_agent:
return _run_benchmark_cells_parallel(
root=root,
env=env,
base_tokens=tokens,
compare_robust=bool(compare_robust),
inner_workers=int(inner_workers),
inner_threads=int(inner_threads),
max_heavy_workers=int(max_heavy_workers),
rank=rank,
)
cmd = [sys.executable, "-m", "engine.train", *tokens] cmd = [sys.executable, "-m", "engine.train", *tokens]
print( print(
{ {
"node_id": node_id,
"node_ip": node_ip,
"rank": int(rank), "rank": int(rank),
"run_kind": run_kind, "run_kind": run_kind,
"seed": int(seed), "seed": int(seed),
@@ -124,7 +471,9 @@ def _train_on_node(
"command": " ".join(cmd), "command": " ".join(cmd),
} }
) )
proc = subprocess.run(cmd, cwd=cwd, env=env) proc = subprocess.run(
cmd, cwd=cwd, env=_thread_limited_env(env, int(inner_threads))
)
return int(proc.returncode) return int(proc.returncode)
@@ -145,6 +494,12 @@ def main() -> None:
parser.add_argument("--output-root", type=str, default="") parser.add_argument("--output-root", type=str, default="")
parser.add_argument("--wandb-entity", type=str, default="") parser.add_argument("--wandb-entity", type=str, default="")
parser.add_argument("--wandb-project", type=str, default="") parser.add_argument("--wandb-project", type=str, default="")
parser.add_argument("--agents-per-node", type=int, default=1)
parser.add_argument("--agent-count", type=int, default=0)
parser.add_argument("--inner-workers", type=int, default=1)
parser.add_argument("--inner-threads", type=int, default=1)
parser.add_argument("--max-heavy-workers", type=int, default=2)
parser.add_argument("--worker-cpus", type=float, default=1.0)
args = parser.parse_args() args = parser.parse_args()
entry_args = str(args.entry_args or args.train_args).strip() entry_args = str(args.entry_args or args.train_args).strip()
@@ -153,21 +508,24 @@ def main() -> None:
ray.init(address="auto") ray.init(address="auto")
node_ips = _alive_node_ips() node_entries = _alive_nodes()
if not node_ips: if not node_entries:
raise RuntimeError("No alive Ray nodes found") raise RuntimeError("No alive Ray nodes found")
requested = int(args.num_nodes) requested = int(args.num_nodes)
if requested > 0: if requested > 0:
node_ips = node_ips[:requested] node_entries = node_entries[:requested]
world_size = len(node_ips) world_size = len(node_entries)
coordinator_ip = node_ips[0] coordinator_ip = node_entries[0][1]
run_group = args.run_group or f"ray-dist-{int(time.time())}" run_group = args.run_group or f"ray-dist-{int(time.time())}"
print( print(
{ {
"nodes": node_ips, "nodes": [
{"node_id": node_id, "node_ip": node_ip}
for node_id, node_ip in node_entries
],
"world_size": world_size, "world_size": world_size,
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}", "coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
"run_kind": str(args.run_kind), "run_kind": str(args.run_kind),
@@ -175,18 +533,35 @@ def main() -> None:
"run_group": run_group, "run_group": run_group,
"compare_robust": bool(args.compare_robust), "compare_robust": bool(args.compare_robust),
"output_root": str(args.output_root), "output_root": str(args.output_root),
"agents_per_node": int(args.agents_per_node),
"agent_count": int(args.agent_count),
"inner_workers": int(args.inner_workers),
"inner_threads": int(args.inner_threads),
"max_heavy_workers": int(args.max_heavy_workers),
} }
) )
futures = [] futures = []
root = str(Path(__file__).resolve().parents[1]) root = str(Path(__file__).resolve().parents[1])
for rank, node_ip in enumerate(node_ips): for rank, (node_id, node_ip) in enumerate(node_entries):
resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)} resources: dict[str, float] = {}
tpu_per_task = float(args.tpu_per_task)
if tpu_per_task > 0.0:
resources["TPU"] = tpu_per_task
futures.append( futures.append(
_train_on_node.options(resources=resources).remote( _train_on_node.options(
resources=resources,
num_cpus=float(args.worker_cpus),
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
root=root, root=root,
run_kind=str(args.run_kind), run_kind=str(args.run_kind),
entry_args=entry_args, entry_args=entry_args,
node_id=node_id,
node_ip=node_ip,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
coordinator_ip=coordinator_ip, coordinator_ip=coordinator_ip,
@@ -197,6 +572,11 @@ def main() -> None:
output_root=str(args.output_root), output_root=str(args.output_root),
wandb_entity=str(args.wandb_entity), wandb_entity=str(args.wandb_entity),
wandb_project=str(args.wandb_project), wandb_project=str(args.wandb_project),
agents_per_node=int(args.agents_per_node),
agent_count=int(args.agent_count),
inner_workers=int(args.inner_workers),
inner_threads=int(args.inner_threads),
max_heavy_workers=int(args.max_heavy_workers),
sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"), sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"),
) )
) )

View File

@@ -4,6 +4,7 @@
# RAY_MODE=single -> one run (default) # RAY_MODE=single -> one run (default)
# RAY_MODE=distributed -> one run per TPU node (experimental) # RAY_MODE=distributed -> one run per TPU node (experimental)
# RAY_MODE=benchmark -> one benchmark run per TPU node (overnight) # RAY_MODE=benchmark -> one benchmark run per TPU node (overnight)
# RAY_MODE=sweep -> distributed W&B sweep agents
set -euo pipefail set -euo pipefail
@@ -28,7 +29,14 @@ env = dotenv_values(".env")
# Filter out empty/None values # Filter out empty/None values
env_vars = {k: v for k, v in env.items() if v} env_vars = {k: v for k, v in env.items() if v}
env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0")) env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0"))
for k in ("WANDB_ENTITY", "WANDB_PROJECT", "PHANTOM_BENCHMARK_COMPARE_ROBUST"): for k in (
"WANDB_ENTITY",
"WANDB_PROJECT",
"PHANTOM_BENCHMARK_COMPARE_ROBUST",
"PHANTOM_JAX_PLATFORM",
"PHANTOM_ALLOW_MULTI_NODE_TPU",
"PHANTOM_TPU_AGENT_SLOTS",
):
if os.getenv(k): if os.getenv(k):
env_vars[k] = os.getenv(k) env_vars[k] = os.getenv(k)
@@ -52,6 +60,15 @@ print(json.dumps({
RAY_MODE="${RAY_MODE:-single}" RAY_MODE="${RAY_MODE:-single}"
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}" TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}"
BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}" BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}"
INNER_WORKERS="${INNER_WORKERS:-16}"
INNER_THREADS="${INNER_THREADS:-1}"
MAX_HEAVY_WORKERS="${MAX_HEAVY_WORKERS:-3}"
WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}"
SWEEP_KIND="${SWEEP_KIND:-benchmark}"
SWEEP_METHOD="${SWEEP_METHOD:-random}"
SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}"
AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}"
AGENT_COUNT="${AGENT_COUNT:-0}"
SUBMIT_ARGS=() SUBMIT_ARGS=()
if [ "${RAY_NO_WAIT:-0}" = "1" ]; then if [ "${RAY_NO_WAIT:-0}" = "1" ]; then
@@ -104,6 +121,10 @@ if [ "$RAY_MODE" = "benchmark" ]; then
--output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}" --output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}"
--wandb-entity "${WANDB_ENTITY:-lusiana}" --wandb-entity "${WANDB_ENTITY:-lusiana}"
--wandb-project "${WANDB_PROJECT:-capstone_tpu}" --wandb-project "${WANDB_PROJECT:-capstone_tpu}"
--inner-workers "${INNER_WORKERS}"
--inner-threads "${INNER_THREADS}"
--max-heavy-workers "${MAX_HEAVY_WORKERS}"
--worker-cpus "${WORKER_CPUS}"
) )
if [ "${COMPARE_ROBUST:-1}" = "1" ]; then if [ "${COMPARE_ROBUST:-1}" = "1" ]; then
DIST_ARGS+=(--compare-robust) DIST_ARGS+=(--compare-robust)
@@ -112,5 +133,97 @@ if [ "$RAY_MODE" = "benchmark" ]; then
exit 0 exit 0
fi fi
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', or 'benchmark')." >&2 if [ "$RAY_MODE" = "sweep" ]; then
SWEEP_PROJECT="${WANDB_PROJECT:-capstone_tpu}"
SWEEP_ENTITY="${WANDB_ENTITY:-lusiana}"
SWEEP_ID_VALUE="${SWEEP_ID:-}"
SWEEP_NUM_NODES="${NUM_NODES:-5}"
PY_SWEEP_BIN="${PY_SWEEP_BIN:-}"
if [ -z "$PY_SWEEP_BIN" ]; then
for cand in "$ROOT/.venv/bin/python" "$ROOT/.venv-ray/bin/python" python3 python; do
if [ "$cand" = "python3" ] || [ "$cand" = "python" ]; then
command -v "$cand" >/dev/null 2>&1 || continue
elif [ ! -x "$cand" ]; then
continue
fi
if "$cand" - <<'PY' >/dev/null 2>&1
import sys
from pathlib import Path
cwd = str(Path.cwd())
sys.path = [p for p in sys.path if p not in {'', cwd}]
import wandb
print(wandb.__name__)
PY
then
PY_SWEEP_BIN="$cand"
break
fi
done
fi
if [ -z "$PY_SWEEP_BIN" ]; then
echo "No python interpreter with wandb is available for sweep creation." >&2
exit 1
fi
if [ -z "$SWEEP_ID_VALUE" ]; then
if [ -z "${WANDB_API_KEY:-}" ]; then
export WANDB_API_KEY
WANDB_API_KEY="$($PY_SWEEP_BIN - <<'PY'
from dotenv import dotenv_values
print(dotenv_values('.env').get('WANDB_API_KEY', '').strip())
PY
)"
fi
if [ -z "${WANDB_API_KEY:-}" ]; then
echo "WANDB_API_KEY is required to create a sweep." >&2
exit 1
fi
SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \
--kind "$SWEEP_KIND" \
--project "$SWEEP_PROJECT" \
--entity "$SWEEP_ENTITY" \
--method "$SWEEP_METHOD" \
--run-cap "$SWEEP_RUN_CAP")"
fi
SWEEP_ENTRY_ARGS="${SWEEP_ENTRY_ARGS:-}"
if [ -z "$SWEEP_ENTRY_ARGS" ]; then
SWEEP_ENTRY_ARGS="--sweep-agent --sweep-id $SWEEP_ID_VALUE --project $SWEEP_PROJECT --device cpu"
fi
if [ "$AGENT_COUNT" = "0" ] && [ "${SWEEP_RUN_CAP:-0}" -gt 0 ]; then
TOTAL_AGENTS=$((SWEEP_NUM_NODES * AGENTS_PER_NODE))
if [ "$TOTAL_AGENTS" -gt 0 ]; then
AGENT_COUNT=$(((SWEEP_RUN_CAP + TOTAL_AGENTS - 1) / TOTAL_AGENTS))
echo "Derived AGENT_COUNT=$AGENT_COUNT from SWEEP_RUN_CAP=$SWEEP_RUN_CAP across $TOTAL_AGENTS agents"
fi
fi
DIST_ARGS=(
python
scripts/ray_distributed_train.py
--run-kind "$SWEEP_KIND"
--entry-args "$SWEEP_ENTRY_ARGS"
--num-nodes "${SWEEP_NUM_NODES}"
--tpu-per-task "${TPU_PER_TASK:-0}"
--base-seed "${BASE_SEED:-42}"
--wandb-entity "$SWEEP_ENTITY"
--wandb-project "$SWEEP_PROJECT"
--agents-per-node "$AGENTS_PER_NODE"
--agent-count "$AGENT_COUNT"
--inner-threads "$INNER_THREADS"
--worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}"
)
if [ "$SWEEP_KIND" = "benchmark" ]; then
DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}")
fi
if [ "${COMPARE_ROBUST:-0}" = "1" ]; then
DIST_ARGS+=(--compare-robust)
fi
echo "SWEEP_ID=$SWEEP_ID_VALUE"
"$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}"
exit 0
fi
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', 'benchmark', or 'sweep')." >&2
exit 1 exit 1

View File

@@ -97,6 +97,8 @@ while true; do
# Determine runtime version # Determine runtime version
RT_VERSION=${RUNTIME_VERSION:-"tpu-ubuntu2204-base"} RT_VERSION=${RUNTIME_VERSION:-"tpu-ubuntu2204-base"}
CREATE_LOG="/tmp/tpu_create_${QR_NAME}.log"
gcloud compute tpus queued-resources create $QR_NAME \ gcloud compute tpus queued-resources create $QR_NAME \
--project=$PROJECT_ID \ --project=$PROJECT_ID \
--node-id=$QR_NAME \ --node-id=$QR_NAME \
@@ -104,20 +106,23 @@ while true; do
--accelerator-type=$ACCEL_TYPE \ --accelerator-type=$ACCEL_TYPE \
--runtime-version=$RT_VERSION \ --runtime-version=$RT_VERSION \
$SPOT_FLAG \ $SPOT_FLAG \
--internal-ips \
--metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \ --metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \
--metadata "$METADATA" 2>&1 | tee /tmp/tpu_create_${QR_NAME}.log --metadata "$METADATA" 2>&1 | tee "$CREATE_LOG"
CREATE_EXIT=${PIPESTATUS[0]}
if [ $? -eq 0 ]; then if [ $CREATE_EXIT -eq 0 ]; then
echo "[$(date)] Successfully queued $QR_NAME." echo "[$(date)] Successfully queued $QR_NAME."
RETRY_DELAY=60 RETRY_DELAY=60
elif grep -q "IN_USE_ADDRESSES" /tmp/tpu_create_${QR_NAME}.log 2>/dev/null; then elif grep -q "IN_USE_ADDRESSES" "$CREATE_LOG" 2>/dev/null; then
echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s" echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s"
sleep $RETRY_DELAY sleep $RETRY_DELAY
RETRY_DELAY=$((RETRY_DELAY * 2)) RETRY_DELAY=$((RETRY_DELAY * 2))
[ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY [ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY
continue continue
else else
echo "[$(date)] Failed to queue $QR_NAME." echo "[$(date)] Failed to queue $QR_NAME (exit=$CREATE_EXIT)."
RETRY_DELAY=60 RETRY_DELAY=60
fi fi
else else