mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: refactor for sweeps and IP configs
This commit is contained in:
@@ -10,9 +10,13 @@ services:
|
|||||||
- HF_TOKEN=${HF_TOKEN}
|
- HF_TOKEN=${HF_TOKEN}
|
||||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||||
- GITHUB_TOKEN=${GITHUB_TOKEN}
|
- GITHUB_TOKEN=${GITHUB_TOKEN}
|
||||||
|
- GOOGLE_APPLICATION_CREDENTIALS=/secrets/gcp-sa.json
|
||||||
|
- GCP_ACCOUNT=${GCP_ACCOUNT:-}
|
||||||
|
- WATCHDOG_CONFIG_PATTERN=${WATCHDOG_CONFIG_PATTERN:-v6e_*.conf}
|
||||||
- CLOUDSDK_CONFIG=/.config/gcloud
|
- CLOUDSDK_CONFIG=/.config/gcloud
|
||||||
volumes:
|
volumes:
|
||||||
- ~/.config/gcloud:/.config/gcloud:rw
|
- ~/.config/gcloud:/.config/gcloud:rw
|
||||||
|
- ./secrets/gcp-sa.json:/secrets/gcp-sa.json:ro
|
||||||
|
|
||||||
tensorboard-rl:
|
tensorboard-rl:
|
||||||
image: tensorflow/tensorflow:latest
|
image: tensorflow/tensorflow:latest
|
||||||
|
|||||||
@@ -36,24 +36,54 @@ if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENT
|
|||||||
echo "Authenticating gcloud using service account key..."
|
echo "Authenticating gcloud using service account key..."
|
||||||
gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS"
|
gcloud auth activate-service-account --key-file="$GOOGLE_APPLICATION_CREDENTIALS"
|
||||||
|
|
||||||
# Extract project ID from the key file
|
if [ -z "$PROJECT_ID" ]; then
|
||||||
PROJECT_ID=$(jq -r '.project_id' "$GOOGLE_APPLICATION_CREDENTIALS")
|
PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||||
if [ -n "$PROJECT_ID" ] && [ "$PROJECT_ID" != "null" ]; then
|
|
||||||
gcloud config set project "$PROJECT_ID"
|
|
||||||
echo "Set project to $PROJECT_ID"
|
|
||||||
fi
|
fi
|
||||||
|
elif [ "$CRED_TYPE" = "authorized_user" ]; then
|
||||||
|
echo "Authenticating gcloud using authorized_user refresh token..."
|
||||||
|
|
||||||
|
AUTH_ACCOUNT="$GCP_ACCOUNT"
|
||||||
|
if [ -z "$AUTH_ACCOUNT" ]; then
|
||||||
|
AUTH_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||||
|
fi
|
||||||
|
if [ -z "$AUTH_ACCOUNT" ]; then
|
||||||
|
AUTH_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
|
||||||
|
fi
|
||||||
|
|
||||||
|
REFRESH_TOKEN=$(jq -r '.refresh_token // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||||
|
if [ -z "$AUTH_ACCOUNT" ] || [ -z "$REFRESH_TOKEN" ]; then
|
||||||
|
echo "Error: authorized_user credentials require GCP_ACCOUNT (or embedded account) and refresh_token."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
gcloud auth activate-refresh-token "$AUTH_ACCOUNT" "$REFRESH_TOKEN"
|
||||||
else
|
else
|
||||||
echo "Note: Using application default credentials or mounted gcloud config..."
|
echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config."
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
echo "Note: Assuming gcloud config is mounted from host."
|
echo "Note: Assuming gcloud config is mounted from host."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ -n "$PROJECT_ID" ]; then
|
||||||
|
gcloud config set project "$PROJECT_ID"
|
||||||
|
echo "Set project to $PROJECT_ID"
|
||||||
|
fi
|
||||||
|
|
||||||
# Run the watchdogs in the background using bash instead of tmux
|
# Run the watchdogs in the background using bash instead of tmux
|
||||||
# Tmux needs a TTY to attach properly which we might not have in docker
|
# Tmux needs a TTY to attach properly which we might not have in docker
|
||||||
# Stagger startups by 15s to prevent simultaneous TPU creation quota hits
|
# Stagger startups by 15s to prevent simultaneous TPU creation quota hits
|
||||||
|
CONFIG_PATTERN=${WATCHDOG_CONFIG_PATTERN:-"*.conf"}
|
||||||
|
shopt -s nullglob
|
||||||
|
CONFIGS=(/app/tpu_orchestration/configs/$CONFIG_PATTERN)
|
||||||
|
|
||||||
|
if [ ${#CONFIGS[@]} -eq 0 ]; then
|
||||||
|
echo "Error: no watchdog configs matched pattern '$CONFIG_PATTERN'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Using watchdog config pattern: $CONFIG_PATTERN"
|
||||||
DELAY=0
|
DELAY=0
|
||||||
for conf in /app/tpu_orchestration/configs/*.conf; do
|
for conf in "${CONFIGS[@]}"; do
|
||||||
echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)"
|
echo "Starting watchdog for $(basename "$conf" .conf) (delay: ${DELAY}s)"
|
||||||
(sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") &
|
(sleep $DELAY && /app/tpu_orchestration/watchdog.sh "$conf") &
|
||||||
DELAY=$((DELAY + 15))
|
DELAY=$((DELAY + 15))
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
|
import concurrent.futures
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
||||||
|
|
||||||
|
|
||||||
def _has_flag(tokens: list[str], name: str) -> bool:
|
def _has_flag(tokens: list[str], name: str) -> bool:
|
||||||
@@ -24,18 +28,301 @@ def _entry_tokens(run_kind: str, entry_args: str) -> list[str]:
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def _alive_node_ips() -> list[str]:
|
def _get_flag_value(tokens: list[str], name: str, default: str = "") -> str:
|
||||||
|
for idx, tok in enumerate(tokens):
|
||||||
|
if tok == name and idx + 1 < len(tokens):
|
||||||
|
return str(tokens[idx + 1])
|
||||||
|
if tok.startswith(f"{name}="):
|
||||||
|
return str(tok.split("=", 1)[1])
|
||||||
|
return str(default)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_flag_value(tokens: list[str], name: str, value: str) -> list[str]:
|
||||||
|
updated: list[str] = []
|
||||||
|
replaced = False
|
||||||
|
idx = 0
|
||||||
|
while idx < len(tokens):
|
||||||
|
tok = tokens[idx]
|
||||||
|
if tok == name:
|
||||||
|
replaced = True
|
||||||
|
updated.extend([name, str(value)])
|
||||||
|
idx += 2
|
||||||
|
continue
|
||||||
|
if tok.startswith(f"{name}="):
|
||||||
|
replaced = True
|
||||||
|
updated.append(f"{name}={value}")
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
updated.append(tok)
|
||||||
|
idx += 1
|
||||||
|
if not replaced:
|
||||||
|
updated.extend([name, str(value)])
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_flag(tokens: list[str], name: str) -> list[str]:
|
||||||
|
updated: list[str] = []
|
||||||
|
idx = 0
|
||||||
|
while idx < len(tokens):
|
||||||
|
tok = tokens[idx]
|
||||||
|
if tok == name:
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
if tok.startswith(f"{name}="):
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
updated.append(tok)
|
||||||
|
idx += 1
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
def _csv_values(raw: str) -> list[str]:
|
||||||
|
return [piece.strip() for piece in str(raw).split(",") if piece.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def _alpha_token(alpha: str) -> str:
|
||||||
|
return str(alpha).replace(".", "p").replace("-", "m")
|
||||||
|
|
||||||
|
|
||||||
|
def _truthy(value: str | bool | None) -> bool:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||||
|
|
||||||
|
|
||||||
|
def _alive_nodes() -> list[tuple[str, str]]:
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
ips: list[str] = []
|
nodes: list[tuple[str, str]] = []
|
||||||
for node in ray.nodes():
|
for node in ray.nodes():
|
||||||
if not bool(node.get("Alive", False)):
|
if not bool(node.get("Alive", False)):
|
||||||
continue
|
continue
|
||||||
|
node_id = str(node.get("NodeID", "")).strip()
|
||||||
ip = str(node.get("NodeManagerAddress", "")).strip()
|
ip = str(node.get("NodeManagerAddress", "")).strip()
|
||||||
if not ip or ip in seen:
|
if not node_id or not ip or node_id in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(ip)
|
seen.add(node_id)
|
||||||
ips.append(ip)
|
nodes.append((node_id, ip))
|
||||||
return sorted(ips)
|
return sorted(nodes, key=lambda item: (item[1], item[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_cells(
|
||||||
|
tokens: list[str], *, compare_robust: bool
|
||||||
|
) -> list[tuple[str, str, str, bool]]:
|
||||||
|
tiers = _csv_values(
|
||||||
|
_get_flag_value(tokens, "--tiers", "static,surge,linear,qtable,ppo")
|
||||||
|
)
|
||||||
|
alphas = _csv_values(_get_flag_value(tokens, "--alpha-values", "0.0,0.3,0.6"))
|
||||||
|
base_no_robust = _has_flag(tokens, "--no-robust")
|
||||||
|
if compare_robust:
|
||||||
|
modes = [("robust", False), ("no_robust", True)]
|
||||||
|
else:
|
||||||
|
modes = [("no_robust", True)] if base_no_robust else [("robust", False)]
|
||||||
|
return [
|
||||||
|
(tier, alpha, mode_label, no_robust)
|
||||||
|
for tier in tiers
|
||||||
|
for alpha in alphas
|
||||||
|
for mode_label, no_robust in modes
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _thread_limited_env(env: dict[str, str], threads: int) -> dict[str, str]:
|
||||||
|
bounded = dict(env)
|
||||||
|
n = str(max(1, int(threads)))
|
||||||
|
for key in (
|
||||||
|
"OMP_NUM_THREADS",
|
||||||
|
"MKL_NUM_THREADS",
|
||||||
|
"OPENBLAS_NUM_THREADS",
|
||||||
|
"NUMEXPR_NUM_THREADS",
|
||||||
|
"VECLIB_MAXIMUM_THREADS",
|
||||||
|
"BLIS_NUM_THREADS",
|
||||||
|
):
|
||||||
|
bounded[key] = n
|
||||||
|
return bounded
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _semaphore_guard(semaphore: threading.Semaphore | None):
|
||||||
|
if semaphore is None:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
semaphore.acquire()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
semaphore.release()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_benchmark_cells_parallel(
|
||||||
|
*,
|
||||||
|
root: str,
|
||||||
|
env: dict[str, str],
|
||||||
|
base_tokens: list[str],
|
||||||
|
compare_robust: bool,
|
||||||
|
inner_workers: int,
|
||||||
|
inner_threads: int,
|
||||||
|
max_heavy_workers: int,
|
||||||
|
rank: int,
|
||||||
|
) -> int:
|
||||||
|
cells = _benchmark_cells(base_tokens, compare_robust=compare_robust)
|
||||||
|
if not cells:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cwd = str(Path(root))
|
||||||
|
base_out = _get_flag_value(base_tokens, "--output-dir", "engine/studies/results")
|
||||||
|
max_workers = max(1, min(int(inner_workers), len(cells)))
|
||||||
|
heavy_tiers = {"ppo", "a2c", "dqn"}
|
||||||
|
heavy_limit = max(1, int(max_heavy_workers))
|
||||||
|
heavy_sem = threading.Semaphore(heavy_limit)
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"rank": int(rank),
|
||||||
|
"benchmark_cells": len(cells),
|
||||||
|
"inner_workers": int(max_workers),
|
||||||
|
"inner_threads": int(max(1, int(inner_threads))),
|
||||||
|
"heavy_limit": int(heavy_limit),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_cell(
|
||||||
|
index: int,
|
||||||
|
total: int,
|
||||||
|
tier: str,
|
||||||
|
alpha: str,
|
||||||
|
mode_label: str,
|
||||||
|
no_robust: bool,
|
||||||
|
) -> tuple[str, str, str, int]:
|
||||||
|
tokens = list(base_tokens)
|
||||||
|
tokens = _set_flag_value(tokens, "--tiers", tier)
|
||||||
|
tokens = _set_flag_value(tokens, "--alpha-values", alpha)
|
||||||
|
if no_robust:
|
||||||
|
if not _has_flag(tokens, "--no-robust"):
|
||||||
|
tokens.append("--no-robust")
|
||||||
|
else:
|
||||||
|
tokens = _remove_flag(tokens, "--no-robust")
|
||||||
|
|
||||||
|
cell_out = (
|
||||||
|
Path(base_out)
|
||||||
|
/ f"tier_{tier}"
|
||||||
|
/ f"mode_{mode_label}"
|
||||||
|
/ f"alpha_{_alpha_token(alpha)}"
|
||||||
|
)
|
||||||
|
tokens = _set_flag_value(tokens, "--output-dir", str(cell_out))
|
||||||
|
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
||||||
|
cell_env = _thread_limited_env(env, int(inner_threads))
|
||||||
|
cell_env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "0"
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"rank": int(rank),
|
||||||
|
"cell": f"{index}/{total}",
|
||||||
|
"tier": tier,
|
||||||
|
"mode": mode_label,
|
||||||
|
"alpha": alpha,
|
||||||
|
"command": " ".join(cmd),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
heavy_guard = heavy_sem if str(tier).lower() in heavy_tiers else None
|
||||||
|
with _semaphore_guard(heavy_guard):
|
||||||
|
proc = subprocess.run(cmd, cwd=cwd, env=cell_env)
|
||||||
|
return tier, alpha, mode_label, int(proc.returncode)
|
||||||
|
|
||||||
|
failures: list[tuple[str, str, str, int]] = []
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
|
futures = [
|
||||||
|
pool.submit(_run_cell, idx, len(cells), tier, alpha, mode_label, no_robust)
|
||||||
|
for idx, (tier, alpha, mode_label, no_robust) in enumerate(cells, start=1)
|
||||||
|
]
|
||||||
|
for fut in concurrent.futures.as_completed(futures):
|
||||||
|
tier, alpha, mode_label, code = fut.result()
|
||||||
|
if code != 0:
|
||||||
|
failures.append((tier, alpha, mode_label, code))
|
||||||
|
|
||||||
|
if failures:
|
||||||
|
print({"rank": int(rank), "benchmark_failures": failures})
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _run_sweep_agents_parallel(
|
||||||
|
*,
|
||||||
|
root: str,
|
||||||
|
env: dict[str, str],
|
||||||
|
base_tokens: list[str],
|
||||||
|
run_kind: str,
|
||||||
|
rank: int,
|
||||||
|
agents_per_node: int,
|
||||||
|
agent_count: int,
|
||||||
|
inner_threads: int,
|
||||||
|
tpu_agent_slots: int,
|
||||||
|
) -> int:
|
||||||
|
total = max(1, int(agents_per_node))
|
||||||
|
cwd = str(Path(root))
|
||||||
|
wants_tpu = str(env.get("JAX_PLATFORMS", "")).strip().lower() == "tpu"
|
||||||
|
tpu_slots = max(0, int(tpu_agent_slots))
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"rank": int(rank),
|
||||||
|
"sweep_agents": int(total),
|
||||||
|
"agent_count": int(agent_count),
|
||||||
|
"inner_threads": int(max(1, int(inner_threads))),
|
||||||
|
"jax_platform": str(env.get("JAX_PLATFORMS", "")),
|
||||||
|
"tpu_agent_slots": int(tpu_slots),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_agent(slot: int) -> int:
|
||||||
|
tokens = list(base_tokens)
|
||||||
|
if int(agent_count) > 0 and not _has_flag(tokens, "--count"):
|
||||||
|
tokens.extend(["--count", str(int(agent_count))])
|
||||||
|
|
||||||
|
if _has_flag(tokens, "--group"):
|
||||||
|
base_group = _get_flag_value(tokens, "--group", "ray-sweep")
|
||||||
|
tokens = _set_flag_value(tokens, "--group", f"{base_group}-a{slot}")
|
||||||
|
|
||||||
|
if run_kind == "benchmark":
|
||||||
|
out_dir = _get_flag_value(tokens, "--output-dir", "engine/studies/results")
|
||||||
|
tokens = _set_flag_value(
|
||||||
|
tokens, "--output-dir", str(Path(out_dir) / f"agent_{slot}")
|
||||||
|
)
|
||||||
|
if run_kind == "train":
|
||||||
|
model_dir = _get_flag_value(tokens, "--model-dir", "engine/models")
|
||||||
|
tokens = _set_flag_value(
|
||||||
|
tokens, "--model-dir", str(Path(model_dir) / f"agent_{slot}")
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
||||||
|
agent_env = _thread_limited_env(env, int(inner_threads))
|
||||||
|
if wants_tpu and tpu_slots > 0 and int(slot) > tpu_slots:
|
||||||
|
agent_env["JAX_PLATFORMS"] = "cpu"
|
||||||
|
agent_env["JAX_PLATFORM_NAME"] = "cpu"
|
||||||
|
agent_env["PHANTOM_SWEEP_AGENT_SLOT"] = str(int(slot))
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"rank": int(rank),
|
||||||
|
"agent_slot": int(slot),
|
||||||
|
"jax_platform": str(agent_env.get("JAX_PLATFORMS", "")),
|
||||||
|
"command": " ".join(cmd),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
proc = subprocess.run(cmd, cwd=cwd, env=agent_env)
|
||||||
|
return int(proc.returncode)
|
||||||
|
|
||||||
|
failures: list[tuple[int, int]] = []
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=total) as pool:
|
||||||
|
future_map = {
|
||||||
|
pool.submit(_run_agent, slot): slot for slot in range(1, total + 1)
|
||||||
|
}
|
||||||
|
for future in concurrent.futures.as_completed(future_map):
|
||||||
|
slot = int(future_map[future])
|
||||||
|
code = int(future.result())
|
||||||
|
if code != 0:
|
||||||
|
failures.append((slot, code))
|
||||||
|
|
||||||
|
if failures:
|
||||||
|
print({"rank": int(rank), "sweep_failures": failures})
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@ray.remote(max_retries=0)
|
@ray.remote(max_retries=0)
|
||||||
@@ -44,6 +331,8 @@ def _train_on_node(
|
|||||||
root: str,
|
root: str,
|
||||||
run_kind: str,
|
run_kind: str,
|
||||||
entry_args: str,
|
entry_args: str,
|
||||||
|
node_id: str,
|
||||||
|
node_ip: str,
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
coordinator_ip: str,
|
coordinator_ip: str,
|
||||||
@@ -54,17 +343,32 @@ def _train_on_node(
|
|||||||
output_root: str,
|
output_root: str,
|
||||||
wandb_entity: str,
|
wandb_entity: str,
|
||||||
wandb_project: str,
|
wandb_project: str,
|
||||||
|
agents_per_node: int,
|
||||||
|
agent_count: int,
|
||||||
|
inner_workers: int,
|
||||||
|
inner_threads: int,
|
||||||
|
max_heavy_workers: int,
|
||||||
sync_jax: bool,
|
sync_jax: bool,
|
||||||
) -> int:
|
) -> int:
|
||||||
env = dict(os.environ)
|
env = dict(os.environ)
|
||||||
env["PYTHONUNBUFFERED"] = "1"
|
env["PYTHONUNBUFFERED"] = "1"
|
||||||
requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower()
|
requested_platform = str(env.get("PHANTOM_JAX_PLATFORM", "tpu")).strip().lower()
|
||||||
if world_size > 1 and requested_platform == "tpu":
|
allow_multi_node_tpu = _truthy(env.get("PHANTOM_ALLOW_MULTI_NODE_TPU"))
|
||||||
|
if world_size > 1 and requested_platform == "tpu" and not allow_multi_node_tpu:
|
||||||
requested_platform = "cpu"
|
requested_platform = "cpu"
|
||||||
print(
|
print(
|
||||||
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs "
|
"PHANTOM_DISTRIBUTED_NOTE: forcing JAX_PLATFORMS=cpu for multi-node SB3 runs "
|
||||||
|
"(set PHANTOM_ALLOW_MULTI_NODE_TPU=1 to keep TPU for JAX workloads)"
|
||||||
|
)
|
||||||
|
elif world_size > 1 and requested_platform == "tpu" and allow_multi_node_tpu:
|
||||||
|
print(
|
||||||
|
"PHANTOM_DISTRIBUTED_NOTE: keeping JAX_PLATFORMS=tpu in multi-node mixed mode"
|
||||||
)
|
)
|
||||||
env["JAX_PLATFORMS"] = requested_platform
|
env["JAX_PLATFORMS"] = requested_platform
|
||||||
|
if requested_platform == "cpu":
|
||||||
|
env["JAX_PLATFORM_NAME"] = "cpu"
|
||||||
|
else:
|
||||||
|
env.pop("JAX_PLATFORM_NAME", None)
|
||||||
# Keep each train process in single-host mode to avoid accidental global stalls.
|
# Keep each train process in single-host mode to avoid accidental global stalls.
|
||||||
env["CLOUD_TPU_TASK_ID"] = "0"
|
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||||
if run_kind == "benchmark":
|
if run_kind == "benchmark":
|
||||||
@@ -96,13 +400,29 @@ def _train_on_node(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tokens = _entry_tokens(run_kind, entry_args)
|
tokens = _entry_tokens(run_kind, entry_args)
|
||||||
|
is_sweep_agent = _has_flag(tokens, "--sweep-agent")
|
||||||
seed = int(base_seed + rank)
|
seed = int(base_seed + rank)
|
||||||
if not _has_flag(tokens, "--seed"):
|
if not is_sweep_agent and not _has_flag(tokens, "--seed"):
|
||||||
tokens.extend(["--seed", str(seed)])
|
tokens.extend(["--seed", str(seed)])
|
||||||
|
|
||||||
if run_kind == "train" and not _has_flag(tokens, "--group"):
|
if run_kind == "train" and not _has_flag(tokens, "--group"):
|
||||||
tokens.extend(["--group", run_group])
|
tokens.extend(["--group", run_group])
|
||||||
|
|
||||||
|
if is_sweep_agent and int(agent_count) > 0 and not _has_flag(tokens, "--count"):
|
||||||
|
tokens.extend(["--count", str(int(agent_count))])
|
||||||
|
|
||||||
|
try:
|
||||||
|
tpu_agent_slots = int(
|
||||||
|
str(
|
||||||
|
env.get(
|
||||||
|
"PHANTOM_TPU_AGENT_SLOTS",
|
||||||
|
"1" if requested_platform == "tpu" else "0",
|
||||||
|
)
|
||||||
|
).strip()
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
tpu_agent_slots = 1 if requested_platform == "tpu" else 0
|
||||||
|
|
||||||
if (
|
if (
|
||||||
run_kind == "benchmark"
|
run_kind == "benchmark"
|
||||||
and output_root
|
and output_root
|
||||||
@@ -112,9 +432,36 @@ def _train_on_node(
|
|||||||
out_dir.parent.mkdir(parents=True, exist_ok=True)
|
out_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||||
tokens.extend(["--output-dir", str(out_dir)])
|
tokens.extend(["--output-dir", str(out_dir)])
|
||||||
|
|
||||||
|
if is_sweep_agent and int(agents_per_node) > 1:
|
||||||
|
return _run_sweep_agents_parallel(
|
||||||
|
root=root,
|
||||||
|
env=env,
|
||||||
|
base_tokens=tokens,
|
||||||
|
run_kind=run_kind,
|
||||||
|
rank=rank,
|
||||||
|
agents_per_node=int(agents_per_node),
|
||||||
|
agent_count=int(agent_count),
|
||||||
|
inner_threads=int(inner_threads),
|
||||||
|
tpu_agent_slots=int(max(0, tpu_agent_slots)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if run_kind == "benchmark" and int(inner_workers) > 1 and not is_sweep_agent:
|
||||||
|
return _run_benchmark_cells_parallel(
|
||||||
|
root=root,
|
||||||
|
env=env,
|
||||||
|
base_tokens=tokens,
|
||||||
|
compare_robust=bool(compare_robust),
|
||||||
|
inner_workers=int(inner_workers),
|
||||||
|
inner_threads=int(inner_threads),
|
||||||
|
max_heavy_workers=int(max_heavy_workers),
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
||||||
print(
|
print(
|
||||||
{
|
{
|
||||||
|
"node_id": node_id,
|
||||||
|
"node_ip": node_ip,
|
||||||
"rank": int(rank),
|
"rank": int(rank),
|
||||||
"run_kind": run_kind,
|
"run_kind": run_kind,
|
||||||
"seed": int(seed),
|
"seed": int(seed),
|
||||||
@@ -124,7 +471,9 @@ def _train_on_node(
|
|||||||
"command": " ".join(cmd),
|
"command": " ".join(cmd),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
proc = subprocess.run(cmd, cwd=cwd, env=env)
|
proc = subprocess.run(
|
||||||
|
cmd, cwd=cwd, env=_thread_limited_env(env, int(inner_threads))
|
||||||
|
)
|
||||||
return int(proc.returncode)
|
return int(proc.returncode)
|
||||||
|
|
||||||
|
|
||||||
@@ -145,6 +494,12 @@ def main() -> None:
|
|||||||
parser.add_argument("--output-root", type=str, default="")
|
parser.add_argument("--output-root", type=str, default="")
|
||||||
parser.add_argument("--wandb-entity", type=str, default="")
|
parser.add_argument("--wandb-entity", type=str, default="")
|
||||||
parser.add_argument("--wandb-project", type=str, default="")
|
parser.add_argument("--wandb-project", type=str, default="")
|
||||||
|
parser.add_argument("--agents-per-node", type=int, default=1)
|
||||||
|
parser.add_argument("--agent-count", type=int, default=0)
|
||||||
|
parser.add_argument("--inner-workers", type=int, default=1)
|
||||||
|
parser.add_argument("--inner-threads", type=int, default=1)
|
||||||
|
parser.add_argument("--max-heavy-workers", type=int, default=2)
|
||||||
|
parser.add_argument("--worker-cpus", type=float, default=1.0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
entry_args = str(args.entry_args or args.train_args).strip()
|
entry_args = str(args.entry_args or args.train_args).strip()
|
||||||
@@ -153,21 +508,24 @@ def main() -> None:
|
|||||||
|
|
||||||
ray.init(address="auto")
|
ray.init(address="auto")
|
||||||
|
|
||||||
node_ips = _alive_node_ips()
|
node_entries = _alive_nodes()
|
||||||
if not node_ips:
|
if not node_entries:
|
||||||
raise RuntimeError("No alive Ray nodes found")
|
raise RuntimeError("No alive Ray nodes found")
|
||||||
|
|
||||||
requested = int(args.num_nodes)
|
requested = int(args.num_nodes)
|
||||||
if requested > 0:
|
if requested > 0:
|
||||||
node_ips = node_ips[:requested]
|
node_entries = node_entries[:requested]
|
||||||
|
|
||||||
world_size = len(node_ips)
|
world_size = len(node_entries)
|
||||||
coordinator_ip = node_ips[0]
|
coordinator_ip = node_entries[0][1]
|
||||||
run_group = args.run_group or f"ray-dist-{int(time.time())}"
|
run_group = args.run_group or f"ray-dist-{int(time.time())}"
|
||||||
|
|
||||||
print(
|
print(
|
||||||
{
|
{
|
||||||
"nodes": node_ips,
|
"nodes": [
|
||||||
|
{"node_id": node_id, "node_ip": node_ip}
|
||||||
|
for node_id, node_ip in node_entries
|
||||||
|
],
|
||||||
"world_size": world_size,
|
"world_size": world_size,
|
||||||
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
|
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
|
||||||
"run_kind": str(args.run_kind),
|
"run_kind": str(args.run_kind),
|
||||||
@@ -175,18 +533,35 @@ def main() -> None:
|
|||||||
"run_group": run_group,
|
"run_group": run_group,
|
||||||
"compare_robust": bool(args.compare_robust),
|
"compare_robust": bool(args.compare_robust),
|
||||||
"output_root": str(args.output_root),
|
"output_root": str(args.output_root),
|
||||||
|
"agents_per_node": int(args.agents_per_node),
|
||||||
|
"agent_count": int(args.agent_count),
|
||||||
|
"inner_workers": int(args.inner_workers),
|
||||||
|
"inner_threads": int(args.inner_threads),
|
||||||
|
"max_heavy_workers": int(args.max_heavy_workers),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
futures = []
|
futures = []
|
||||||
root = str(Path(__file__).resolve().parents[1])
|
root = str(Path(__file__).resolve().parents[1])
|
||||||
for rank, node_ip in enumerate(node_ips):
|
for rank, (node_id, node_ip) in enumerate(node_entries):
|
||||||
resources = {f"node:{node_ip}": 0.01, "TPU": float(args.tpu_per_task)}
|
resources: dict[str, float] = {}
|
||||||
|
tpu_per_task = float(args.tpu_per_task)
|
||||||
|
if tpu_per_task > 0.0:
|
||||||
|
resources["TPU"] = tpu_per_task
|
||||||
futures.append(
|
futures.append(
|
||||||
_train_on_node.options(resources=resources).remote(
|
_train_on_node.options(
|
||||||
|
resources=resources,
|
||||||
|
num_cpus=float(args.worker_cpus),
|
||||||
|
scheduling_strategy=NodeAffinitySchedulingStrategy(
|
||||||
|
node_id=node_id,
|
||||||
|
soft=False,
|
||||||
|
),
|
||||||
|
).remote(
|
||||||
root=root,
|
root=root,
|
||||||
run_kind=str(args.run_kind),
|
run_kind=str(args.run_kind),
|
||||||
entry_args=entry_args,
|
entry_args=entry_args,
|
||||||
|
node_id=node_id,
|
||||||
|
node_ip=node_ip,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
coordinator_ip=coordinator_ip,
|
coordinator_ip=coordinator_ip,
|
||||||
@@ -197,6 +572,11 @@ def main() -> None:
|
|||||||
output_root=str(args.output_root),
|
output_root=str(args.output_root),
|
||||||
wandb_entity=str(args.wandb_entity),
|
wandb_entity=str(args.wandb_entity),
|
||||||
wandb_project=str(args.wandb_project),
|
wandb_project=str(args.wandb_project),
|
||||||
|
agents_per_node=int(args.agents_per_node),
|
||||||
|
agent_count=int(args.agent_count),
|
||||||
|
inner_workers=int(args.inner_workers),
|
||||||
|
inner_threads=int(args.inner_threads),
|
||||||
|
max_heavy_workers=int(args.max_heavy_workers),
|
||||||
sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"),
|
sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
# RAY_MODE=single -> one run (default)
|
# RAY_MODE=single -> one run (default)
|
||||||
# RAY_MODE=distributed -> one run per TPU node (experimental)
|
# RAY_MODE=distributed -> one run per TPU node (experimental)
|
||||||
# RAY_MODE=benchmark -> one benchmark run per TPU node (overnight)
|
# RAY_MODE=benchmark -> one benchmark run per TPU node (overnight)
|
||||||
|
# RAY_MODE=sweep -> distributed W&B sweep agents
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
@@ -28,7 +29,14 @@ env = dotenv_values(".env")
|
|||||||
# Filter out empty/None values
|
# Filter out empty/None values
|
||||||
env_vars = {k: v for k, v in env.items() if v}
|
env_vars = {k: v for k, v in env.items() if v}
|
||||||
env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0"))
|
env_vars.setdefault("CLOUD_TPU_TASK_ID", os.getenv("CLOUD_TPU_TASK_ID", "0"))
|
||||||
for k in ("WANDB_ENTITY", "WANDB_PROJECT", "PHANTOM_BENCHMARK_COMPARE_ROBUST"):
|
for k in (
|
||||||
|
"WANDB_ENTITY",
|
||||||
|
"WANDB_PROJECT",
|
||||||
|
"PHANTOM_BENCHMARK_COMPARE_ROBUST",
|
||||||
|
"PHANTOM_JAX_PLATFORM",
|
||||||
|
"PHANTOM_ALLOW_MULTI_NODE_TPU",
|
||||||
|
"PHANTOM_TPU_AGENT_SLOTS",
|
||||||
|
):
|
||||||
if os.getenv(k):
|
if os.getenv(k):
|
||||||
env_vars[k] = os.getenv(k)
|
env_vars[k] = os.getenv(k)
|
||||||
|
|
||||||
@@ -52,6 +60,15 @@ print(json.dumps({
|
|||||||
RAY_MODE="${RAY_MODE:-single}"
|
RAY_MODE="${RAY_MODE:-single}"
|
||||||
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}"
|
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --total-timesteps 1000000}"
|
||||||
BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}"
|
BENCHMARK_ARGS="${BENCHMARK_ARGS:---project capstone_tpu --tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.1,0.25,0.4,0.6,0.8 --episodes 12 --total-timesteps 30000 --max-steps 100 --robust-radius 0.2 --robust-points 7 --robust-rollouts 1 --lambda-coi 0.2 --eta-ux 0.5 --reward-profit-weight 1.0 --device cpu}"
|
||||||
|
INNER_WORKERS="${INNER_WORKERS:-16}"
|
||||||
|
INNER_THREADS="${INNER_THREADS:-1}"
|
||||||
|
MAX_HEAVY_WORKERS="${MAX_HEAVY_WORKERS:-3}"
|
||||||
|
WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}"
|
||||||
|
SWEEP_KIND="${SWEEP_KIND:-benchmark}"
|
||||||
|
SWEEP_METHOD="${SWEEP_METHOD:-random}"
|
||||||
|
SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}"
|
||||||
|
AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}"
|
||||||
|
AGENT_COUNT="${AGENT_COUNT:-0}"
|
||||||
|
|
||||||
SUBMIT_ARGS=()
|
SUBMIT_ARGS=()
|
||||||
if [ "${RAY_NO_WAIT:-0}" = "1" ]; then
|
if [ "${RAY_NO_WAIT:-0}" = "1" ]; then
|
||||||
@@ -104,6 +121,10 @@ if [ "$RAY_MODE" = "benchmark" ]; then
|
|||||||
--output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}"
|
--output-root "${OUTPUT_ROOT:-engine/studies/results/overnight}"
|
||||||
--wandb-entity "${WANDB_ENTITY:-lusiana}"
|
--wandb-entity "${WANDB_ENTITY:-lusiana}"
|
||||||
--wandb-project "${WANDB_PROJECT:-capstone_tpu}"
|
--wandb-project "${WANDB_PROJECT:-capstone_tpu}"
|
||||||
|
--inner-workers "${INNER_WORKERS}"
|
||||||
|
--inner-threads "${INNER_THREADS}"
|
||||||
|
--max-heavy-workers "${MAX_HEAVY_WORKERS}"
|
||||||
|
--worker-cpus "${WORKER_CPUS}"
|
||||||
)
|
)
|
||||||
if [ "${COMPARE_ROBUST:-1}" = "1" ]; then
|
if [ "${COMPARE_ROBUST:-1}" = "1" ]; then
|
||||||
DIST_ARGS+=(--compare-robust)
|
DIST_ARGS+=(--compare-robust)
|
||||||
@@ -112,5 +133,97 @@ if [ "$RAY_MODE" = "benchmark" ]; then
|
|||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', or 'benchmark')." >&2
|
if [ "$RAY_MODE" = "sweep" ]; then
|
||||||
|
SWEEP_PROJECT="${WANDB_PROJECT:-capstone_tpu}"
|
||||||
|
SWEEP_ENTITY="${WANDB_ENTITY:-lusiana}"
|
||||||
|
SWEEP_ID_VALUE="${SWEEP_ID:-}"
|
||||||
|
SWEEP_NUM_NODES="${NUM_NODES:-5}"
|
||||||
|
PY_SWEEP_BIN="${PY_SWEEP_BIN:-}"
|
||||||
|
if [ -z "$PY_SWEEP_BIN" ]; then
|
||||||
|
for cand in "$ROOT/.venv/bin/python" "$ROOT/.venv-ray/bin/python" python3 python; do
|
||||||
|
if [ "$cand" = "python3" ] || [ "$cand" = "python" ]; then
|
||||||
|
command -v "$cand" >/dev/null 2>&1 || continue
|
||||||
|
elif [ ! -x "$cand" ]; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
if "$cand" - <<'PY' >/dev/null 2>&1
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
cwd = str(Path.cwd())
|
||||||
|
sys.path = [p for p in sys.path if p not in {'', cwd}]
|
||||||
|
import wandb
|
||||||
|
print(wandb.__name__)
|
||||||
|
PY
|
||||||
|
then
|
||||||
|
PY_SWEEP_BIN="$cand"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
if [ -z "$PY_SWEEP_BIN" ]; then
|
||||||
|
echo "No python interpreter with wandb is available for sweep creation." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$SWEEP_ID_VALUE" ]; then
|
||||||
|
if [ -z "${WANDB_API_KEY:-}" ]; then
|
||||||
|
export WANDB_API_KEY
|
||||||
|
WANDB_API_KEY="$($PY_SWEEP_BIN - <<'PY'
|
||||||
|
from dotenv import dotenv_values
|
||||||
|
print(dotenv_values('.env').get('WANDB_API_KEY', '').strip())
|
||||||
|
PY
|
||||||
|
)"
|
||||||
|
fi
|
||||||
|
if [ -z "${WANDB_API_KEY:-}" ]; then
|
||||||
|
echo "WANDB_API_KEY is required to create a sweep." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \
|
||||||
|
--kind "$SWEEP_KIND" \
|
||||||
|
--project "$SWEEP_PROJECT" \
|
||||||
|
--entity "$SWEEP_ENTITY" \
|
||||||
|
--method "$SWEEP_METHOD" \
|
||||||
|
--run-cap "$SWEEP_RUN_CAP")"
|
||||||
|
fi
|
||||||
|
|
||||||
|
SWEEP_ENTRY_ARGS="${SWEEP_ENTRY_ARGS:-}"
|
||||||
|
if [ -z "$SWEEP_ENTRY_ARGS" ]; then
|
||||||
|
SWEEP_ENTRY_ARGS="--sweep-agent --sweep-id $SWEEP_ID_VALUE --project $SWEEP_PROJECT --device cpu"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$AGENT_COUNT" = "0" ] && [ "${SWEEP_RUN_CAP:-0}" -gt 0 ]; then
|
||||||
|
TOTAL_AGENTS=$((SWEEP_NUM_NODES * AGENTS_PER_NODE))
|
||||||
|
if [ "$TOTAL_AGENTS" -gt 0 ]; then
|
||||||
|
AGENT_COUNT=$(((SWEEP_RUN_CAP + TOTAL_AGENTS - 1) / TOTAL_AGENTS))
|
||||||
|
echo "Derived AGENT_COUNT=$AGENT_COUNT from SWEEP_RUN_CAP=$SWEEP_RUN_CAP across $TOTAL_AGENTS agents"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
DIST_ARGS=(
|
||||||
|
python
|
||||||
|
scripts/ray_distributed_train.py
|
||||||
|
--run-kind "$SWEEP_KIND"
|
||||||
|
--entry-args "$SWEEP_ENTRY_ARGS"
|
||||||
|
--num-nodes "${SWEEP_NUM_NODES}"
|
||||||
|
--tpu-per-task "${TPU_PER_TASK:-0}"
|
||||||
|
--base-seed "${BASE_SEED:-42}"
|
||||||
|
--wandb-entity "$SWEEP_ENTITY"
|
||||||
|
--wandb-project "$SWEEP_PROJECT"
|
||||||
|
--agents-per-node "$AGENTS_PER_NODE"
|
||||||
|
--agent-count "$AGENT_COUNT"
|
||||||
|
--inner-threads "$INNER_THREADS"
|
||||||
|
--worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}"
|
||||||
|
)
|
||||||
|
if [ "$SWEEP_KIND" = "benchmark" ]; then
|
||||||
|
DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}")
|
||||||
|
fi
|
||||||
|
if [ "${COMPARE_ROBUST:-0}" = "1" ]; then
|
||||||
|
DIST_ARGS+=(--compare-robust)
|
||||||
|
fi
|
||||||
|
echo "SWEEP_ID=$SWEEP_ID_VALUE"
|
||||||
|
"$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Unsupported RAY_MODE='$RAY_MODE' (expected 'single', 'distributed', 'benchmark', or 'sweep')." >&2
|
||||||
exit 1
|
exit 1
|
||||||
|
|||||||
@@ -97,6 +97,8 @@ while true; do
|
|||||||
# Determine runtime version
|
# Determine runtime version
|
||||||
RT_VERSION=${RUNTIME_VERSION:-"tpu-ubuntu2204-base"}
|
RT_VERSION=${RUNTIME_VERSION:-"tpu-ubuntu2204-base"}
|
||||||
|
|
||||||
|
CREATE_LOG="/tmp/tpu_create_${QR_NAME}.log"
|
||||||
|
|
||||||
gcloud compute tpus queued-resources create $QR_NAME \
|
gcloud compute tpus queued-resources create $QR_NAME \
|
||||||
--project=$PROJECT_ID \
|
--project=$PROJECT_ID \
|
||||||
--node-id=$QR_NAME \
|
--node-id=$QR_NAME \
|
||||||
@@ -104,20 +106,23 @@ while true; do
|
|||||||
--accelerator-type=$ACCEL_TYPE \
|
--accelerator-type=$ACCEL_TYPE \
|
||||||
--runtime-version=$RT_VERSION \
|
--runtime-version=$RT_VERSION \
|
||||||
$SPOT_FLAG \
|
$SPOT_FLAG \
|
||||||
|
--internal-ips \
|
||||||
--metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \
|
--metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \
|
||||||
--metadata "$METADATA" 2>&1 | tee /tmp/tpu_create_${QR_NAME}.log
|
--metadata "$METADATA" 2>&1 | tee "$CREATE_LOG"
|
||||||
|
|
||||||
if [ $? -eq 0 ]; then
|
CREATE_EXIT=${PIPESTATUS[0]}
|
||||||
|
|
||||||
|
if [ $CREATE_EXIT -eq 0 ]; then
|
||||||
echo "[$(date)] Successfully queued $QR_NAME."
|
echo "[$(date)] Successfully queued $QR_NAME."
|
||||||
RETRY_DELAY=60
|
RETRY_DELAY=60
|
||||||
elif grep -q "IN_USE_ADDRESSES" /tmp/tpu_create_${QR_NAME}.log 2>/dev/null; then
|
elif grep -q "IN_USE_ADDRESSES" "$CREATE_LOG" 2>/dev/null; then
|
||||||
echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s"
|
echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s"
|
||||||
sleep $RETRY_DELAY
|
sleep $RETRY_DELAY
|
||||||
RETRY_DELAY=$((RETRY_DELAY * 2))
|
RETRY_DELAY=$((RETRY_DELAY * 2))
|
||||||
[ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY
|
[ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY
|
||||||
continue
|
continue
|
||||||
else
|
else
|
||||||
echo "[$(date)] Failed to queue $QR_NAME."
|
echo "[$(date)] Failed to queue $QR_NAME (exit=$CREATE_EXIT)."
|
||||||
RETRY_DELAY=60
|
RETRY_DELAY=60
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
|
|||||||
Reference in New Issue
Block a user