chore: bulk tpu reorchestration

This commit is contained in:
2026-03-15 21:14:41 +01:00
parent 52b4dcdce3
commit a9c091050c
10 changed files with 155 additions and 42 deletions

View File

@@ -40,23 +40,35 @@ if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENT
PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS") PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
fi fi
elif [ "$CRED_TYPE" = "authorized_user" ]; then elif [ "$CRED_TYPE" = "authorized_user" ]; then
echo "Authenticating gcloud using authorized_user refresh token..." echo "Using authorized_user credentials via credential file override..."
export CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE="$GOOGLE_APPLICATION_CREDENTIALS"
AUTH_ACCOUNT="$GCP_ACCOUNT" if gcloud auth print-access-token >/dev/null 2>&1; then
if [ -z "$AUTH_ACCOUNT" ]; then ACTIVE_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
AUTH_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS") if [ -z "$ACTIVE_ACCOUNT" ] || [ "$ACTIVE_ACCOUNT" = "(unset)" ]; then
fi ACTIVE_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
if [ -z "$AUTH_ACCOUNT" ]; then fi
AUTH_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
fi
REFRESH_TOKEN=$(jq -r '.refresh_token // empty' "$GOOGLE_APPLICATION_CREDENTIALS") if [ -n "$ACTIVE_ACCOUNT" ] && [ "$ACTIVE_ACCOUNT" != "(unset)" ]; then
if [ -z "$AUTH_ACCOUNT" ] || [ -z "$REFRESH_TOKEN" ]; then echo "Using gcloud account: $ACTIVE_ACCOUNT"
echo "Error: authorized_user credentials require GCP_ACCOUNT (or embedded account) and refresh_token." else
exit 1 echo "Using gcloud credential override from $GOOGLE_APPLICATION_CREDENTIALS"
fi fi
else
echo "Warning: credential file override token check failed. Falling back to mounted gcloud config."
unset CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE
gcloud auth activate-refresh-token "$AUTH_ACCOUNT" "$REFRESH_TOKEN" if [ -n "$GCP_ACCOUNT" ]; then
gcloud config set account "$GCP_ACCOUNT" >/dev/null 2>&1 || true
fi
ACTIVE_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
if [ -z "$ACTIVE_ACCOUNT" ] || [ "$ACTIVE_ACCOUNT" = "(unset)" ]; then
echo "Error: no active gcloud account available. Run 'gcloud auth login' on host and mount ~/.config/gcloud, or use a service account key."
exit 1
fi
echo "Using gcloud account: $ACTIVE_ACCOUNT"
fi
else else
echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config." echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config."
fi fi

View File

@@ -92,9 +92,9 @@ def _truthy(value: str | bool | None) -> bool:
return str(value).strip().lower() in {"1", "true", "yes", "on"} return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _alive_nodes() -> list[tuple[str, str]]: def _alive_nodes() -> list[tuple[str, str, bool, float]]:
seen: set[str] = set() seen: set[str] = set()
nodes: list[tuple[str, str]] = [] nodes: list[tuple[str, str, bool, float]] = []
for node in ray.nodes(): for node in ray.nodes():
if not bool(node.get("Alive", False)): if not bool(node.get("Alive", False)):
continue continue
@@ -102,9 +102,50 @@ def _alive_nodes() -> list[tuple[str, str]]:
ip = str(node.get("NodeManagerAddress", "")).strip() ip = str(node.get("NodeManagerAddress", "")).strip()
if not node_id or not ip or node_id in seen: if not node_id or not ip or node_id in seen:
continue continue
resources = node.get("Resources", {}) or {}
is_head = bool(resources.get("node:__internal_head__", 0.0))
tpu = float(resources.get("TPU", 0.0))
seen.add(node_id) seen.add(node_id)
nodes.append((node_id, ip)) nodes.append((node_id, ip, is_head, tpu))
return sorted(nodes, key=lambda item: (item[1], item[0])) return sorted(nodes, key=lambda item: (item[1], item[2], -item[3], item[0]))
def _dedupe_nodes_for_tpu(
nodes: list[tuple[str, str, bool, float]],
) -> tuple[list[tuple[str, str]], list[dict[str, str | float | bool]]]:
selected: dict[str, tuple[str, str, bool, float]] = {}
dropped: list[dict[str, str | float | bool]] = []
def _score(item: tuple[str, str, bool, float]) -> tuple[int, float, str]:
node_id, _ip, is_head, tpu = item
return (1 if bool(is_head) else 0, -float(tpu), str(node_id))
for item in nodes:
node_id, ip, is_head, tpu = item
existing = selected.get(ip)
if existing is None:
selected[ip] = item
continue
keep, drop = (
(item, existing) if _score(item) < _score(existing) else (existing, item)
)
selected[ip] = keep
dropped.append(
{
"ip": str(ip),
"dropped_node_id": str(drop[0]),
"dropped_is_head": bool(drop[2]),
"dropped_tpu": float(drop[3]),
"kept_node_id": str(keep[0]),
"kept_is_head": bool(keep[2]),
"kept_tpu": float(keep[3]),
}
)
entries = [(node_id, ip) for ip, (node_id, _ip, _is_head, _tpu) in selected.items()]
entries.sort(key=lambda item: (item[1], item[0]))
return entries, dropped
def _benchmark_cells( def _benchmark_cells(
@@ -369,8 +410,19 @@ def _train_on_node(
env["JAX_PLATFORM_NAME"] = "cpu" env["JAX_PLATFORM_NAME"] = "cpu"
else: else:
env.pop("JAX_PLATFORM_NAME", None) env.pop("JAX_PLATFORM_NAME", None)
# Keep each train process in single-host mode to avoid accidental global stalls. if requested_platform == "tpu" and world_size > 1 and allow_multi_node_tpu:
env["CLOUD_TPU_TASK_ID"] = "0" env["CLOUD_TPU_TASK_ID"] = str(int(rank))
print(
{
"rank": int(rank),
"node_ip": str(node_ip),
"jax_platform": "tpu",
"cloud_tpu_task_id": str(env["CLOUD_TPU_TASK_ID"]),
}
)
else:
# Keep each process in single-host mode when TPU multi-host is disabled.
env["CLOUD_TPU_TASK_ID"] = "0"
if run_kind == "benchmark": if run_kind == "benchmark":
env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0" env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0"
if wandb_entity: if wandb_entity:
@@ -508,12 +560,36 @@ def main() -> None:
ray.init(address="auto") ray.init(address="auto")
node_entries = _alive_nodes() node_records = _alive_nodes()
if not node_entries: if not node_records:
raise RuntimeError("No alive Ray nodes found") raise RuntimeError("No alive Ray nodes found")
if float(args.tpu_per_task) > 0.0:
node_entries, dropped = _dedupe_nodes_for_tpu(node_records)
if dropped:
print(
{
"tpu_host_dedupe": True,
"alive_ray_nodes": len(node_records),
"unique_tpu_hosts": len(node_entries),
"dropped": dropped,
}
)
else:
node_entries = [
(node_id, node_ip) for node_id, node_ip, _is_head, _tpu in node_records
]
requested = int(args.num_nodes) requested = int(args.num_nodes)
if requested > 0: if requested > 0:
if requested > len(node_entries):
print(
{
"requested_nodes": int(requested),
"available_nodes": int(len(node_entries)),
"note": "requested nodes exceed available hosts; capping",
}
)
node_entries = node_entries[:requested] node_entries = node_entries[:requested]
world_size = len(node_entries) world_size = len(node_entries)

View File

@@ -66,6 +66,7 @@ MAX_HEAVY_WORKERS="${MAX_HEAVY_WORKERS:-3}"
WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}" WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}"
SWEEP_KIND="${SWEEP_KIND:-benchmark}" SWEEP_KIND="${SWEEP_KIND:-benchmark}"
SWEEP_METHOD="${SWEEP_METHOD:-random}" SWEEP_METHOD="${SWEEP_METHOD:-random}"
SWEEP_PROFILE="${SWEEP_PROFILE:-default}"
SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}" SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}"
AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}" AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}"
AGENT_COUNT="${AGENT_COUNT:-0}" AGENT_COUNT="${AGENT_COUNT:-0}"
@@ -180,6 +181,7 @@ PY
fi fi
SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \ SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \
--kind "$SWEEP_KIND" \ --kind "$SWEEP_KIND" \
--profile "$SWEEP_PROFILE" \
--project "$SWEEP_PROJECT" \ --project "$SWEEP_PROJECT" \
--entity "$SWEEP_ENTITY" \ --entity "$SWEEP_ENTITY" \
--method "$SWEEP_METHOD" \ --method "$SWEEP_METHOD" \
@@ -199,10 +201,22 @@ PY
fi fi
fi fi
SWEEP_RUN_KIND="$SWEEP_KIND"
if [ "$SWEEP_KIND" = "ppo_calibration" ] || [ "$SWEEP_KIND" = "ppo_block_a" ] || [ "$SWEEP_KIND" = "ppo_shift_screen" ]; then
SWEEP_RUN_KIND="benchmark"
fi
if [ "$SWEEP_KIND" = "ppo_rl_study" ]; then
SWEEP_RUN_KIND="train"
fi
if [ "$SWEEP_RUN_KIND" != "benchmark" ] && [ "$SWEEP_RUN_KIND" != "train" ]; then
echo "Unsupported SWEEP_KIND='$SWEEP_KIND' (expected 'benchmark', 'train', 'ppo_calibration', 'ppo_block_a', 'ppo_shift_screen', or 'ppo_rl_study')." >&2
exit 1
fi
DIST_ARGS=( DIST_ARGS=(
python python
scripts/ray_distributed_train.py scripts/ray_distributed_train.py
--run-kind "$SWEEP_KIND" --run-kind "$SWEEP_RUN_KIND"
--entry-args "$SWEEP_ENTRY_ARGS" --entry-args "$SWEEP_ENTRY_ARGS"
--num-nodes "${SWEEP_NUM_NODES}" --num-nodes "${SWEEP_NUM_NODES}"
--tpu-per-task "${TPU_PER_TASK:-0}" --tpu-per-task "${TPU_PER_TASK:-0}"
@@ -214,13 +228,17 @@ PY
--inner-threads "$INNER_THREADS" --inner-threads "$INNER_THREADS"
--worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}" --worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}"
) )
if [ "$SWEEP_KIND" = "benchmark" ]; then if [ "$SWEEP_RUN_KIND" = "benchmark" ]; then
DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}") DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}")
fi fi
if [ "${COMPARE_ROBUST:-0}" = "1" ]; then if [ "${COMPARE_ROBUST:-0}" = "1" ]; then
DIST_ARGS+=(--compare-robust) DIST_ARGS+=(--compare-robust)
fi fi
echo "SWEEP_ID=$SWEEP_ID_VALUE" echo "SWEEP_ID=$SWEEP_ID_VALUE"
if [ "$SWEEP_KIND" = "train" ] && [ "$SWEEP_PROFILE" = "robust_revenue" ]; then
echo "When this sweep finishes, compare best robust config vs no_robust with:"
echo "python scripts/wandb_compare_best.py --entity $SWEEP_ENTITY --project $SWEEP_PROJECT --sweep-id $SWEEP_ID_VALUE --submit --ray-no-wait"
fi
"$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}" "$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}"
exit 0 exit 0
fi fi

View File

@@ -3,6 +3,7 @@ QR_NAME="v4-32-us-ondemand"
ACCEL_TYPE="v4-32" ACCEL_TYPE="v4-32"
RUNTIME_VERSION="tpu-ubuntu2204-base" RUNTIME_VERSION="tpu-ubuntu2204-base"
IS_SPOT="false" IS_SPOT="false"
INTERNAL_IPS="false"
RUN_ID="phantom_v4_od_1" RUN_ID="phantom_v4_od_1"
HF_REPO="velocitatem/capstone" HF_REPO="velocitatem/capstone"
TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof"

View File

@@ -3,6 +3,7 @@ QR_NAME="v4-32-us-spot"
ACCEL_TYPE="v4-32" ACCEL_TYPE="v4-32"
RUNTIME_VERSION="tpu-ubuntu2204-base" RUNTIME_VERSION="tpu-ubuntu2204-base"
IS_SPOT="true" IS_SPOT="true"
INTERNAL_IPS="false"
RUN_ID="phantom_v4_spot_1" RUN_ID="phantom_v4_spot_1"
HF_REPO="velocitatem/capstone" HF_REPO="velocitatem/capstone"
TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof"

View File

@@ -1,6 +1,6 @@
ZONE="europe-west4-b" ZONE="europe-west4-b"
QR_NAME="v5e-64-eu-spot" QR_NAME="v5e-32-eu-spot"
ACCEL_TYPE="v5litepod-64" ACCEL_TYPE="v5litepod-32"
RUNTIME_VERSION="tpu-ubuntu2204-base" RUNTIME_VERSION="tpu-ubuntu2204-base"
IS_SPOT="true" IS_SPOT="true"
RUN_ID="phantom_v5e_eu_1" RUN_ID="phantom_v5e_eu_1"

View File

@@ -1,6 +1,6 @@
ZONE="us-central1-a" ZONE="us-central1-a"
QR_NAME="v5e-64-us-spot" QR_NAME="v5e-32-us-spot"
ACCEL_TYPE="v5litepod-64" ACCEL_TYPE="v5litepod-32"
RUNTIME_VERSION="tpu-ubuntu2204-base" RUNTIME_VERSION="tpu-ubuntu2204-base"
IS_SPOT="true" IS_SPOT="true"
RUN_ID="phantom_v5e_us_1" RUN_ID="phantom_v5e_us_1"

View File

@@ -1,6 +1,6 @@
ZONE="europe-west4-a" ZONE="europe-west4-a"
QR_NAME="v6e-64-eu-spot" QR_NAME="v6e-32-eu-spot"
ACCEL_TYPE="v6e-64" ACCEL_TYPE="v6e-32"
RUNTIME_VERSION="tpu-ubuntu2204-base" RUNTIME_VERSION="tpu-ubuntu2204-base"
IS_SPOT="true" IS_SPOT="true"
RUN_ID="phantom_v6e_eu_1" RUN_ID="phantom_v6e_eu_1"

View File

@@ -1,6 +1,6 @@
ZONE="us-east1-d" ZONE="us-east1-d"
QR_NAME="v6e-64-us-spot" QR_NAME="v6e-32-us-spot"
ACCEL_TYPE="v6e-64" ACCEL_TYPE="v6e-32"
RUNTIME_VERSION="tpu-ubuntu2204-base" RUNTIME_VERSION="tpu-ubuntu2204-base"
IS_SPOT="true" IS_SPOT="true"
RUN_ID="phantom_v6e_us_1" RUN_ID="phantom_v6e_us_1"

View File

@@ -58,7 +58,7 @@ RETRY_DELAY=60
MAX_RETRY_DELAY=300 MAX_RETRY_DELAY=300
while true; do while true; do
STATE=$(gcloud compute tpus queued-resources describe $QR_NAME --zone=$ZONE --project=$PROJECT_ID --format="value(state)" 2>/dev/null) STATE=$(gcloud compute tpus queued-resources describe $QR_NAME --zone=$ZONE --project=$PROJECT_ID --format="value(state.state)" 2>/dev/null)
if [ -z "$STATE" ] || [[ "$STATE" == *"SUSPENDED"* ]] || [[ "$STATE" == *"FAILED"* ]]; then if [ -z "$STATE" ] || [[ "$STATE" == *"SUSPENDED"* ]] || [[ "$STATE" == *"FAILED"* ]]; then
echo "[$(date)] Cluster '${STATE:-MISSING}' - cleaning IPs and re-queuing..." echo "[$(date)] Cluster '${STATE:-MISSING}' - cleaning IPs and re-queuing..."
@@ -85,6 +85,11 @@ while true; do
SPOT_FLAG="--spot" SPOT_FLAG="--spot"
fi fi
IP_FLAG="--internal-ips"
if [ "${INTERNAL_IPS:-true}" != "true" ]; then
IP_FLAG=""
fi
# Prepare metadata # Prepare metadata
METADATA="HF_TOKEN=$HF_TOKEN,RUN_ID=$RUN_ID,HF_REPO=$HF_REPO,ACCEL_TYPE=$ACCEL_TYPE,GITHUB_REPO=$GITHUB_REPO,BRANCH=$BRANCH" METADATA="HF_TOKEN=$HF_TOKEN,RUN_ID=$RUN_ID,HF_REPO=$HF_REPO,ACCEL_TYPE=$ACCEL_TYPE,GITHUB_REPO=$GITHUB_REPO,BRANCH=$BRANCH"
if [ -n "$WANDB_API_KEY" ]; then if [ -n "$WANDB_API_KEY" ]; then
@@ -106,7 +111,7 @@ while true; do
--accelerator-type=$ACCEL_TYPE \ --accelerator-type=$ACCEL_TYPE \
--runtime-version=$RT_VERSION \ --runtime-version=$RT_VERSION \
$SPOT_FLAG \ $SPOT_FLAG \
--internal-ips \ $IP_FLAG \
--metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \ --metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \
--metadata "$METADATA" 2>&1 | tee "$CREATE_LOG" --metadata "$METADATA" 2>&1 | tee "$CREATE_LOG"
@@ -115,8 +120,8 @@ while true; do
if [ $CREATE_EXIT -eq 0 ]; then if [ $CREATE_EXIT -eq 0 ]; then
echo "[$(date)] Successfully queued $QR_NAME." echo "[$(date)] Successfully queued $QR_NAME."
RETRY_DELAY=60 RETRY_DELAY=60
elif grep -q "IN_USE_ADDRESSES" "$CREATE_LOG" 2>/dev/null; then elif grep -Eq "IN_USE_ADDRESSES|RESOURCE_EXHAUSTED|Quota limit|QUOTA_EXCEEDED" "$CREATE_LOG" 2>/dev/null; then
echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s" echo "[$(date)] Quota pressure detected - backing off ${RETRY_DELAY}s"
sleep $RETRY_DELAY sleep $RETRY_DELAY
RETRY_DELAY=$((RETRY_DELAY * 2)) RETRY_DELAY=$((RETRY_DELAY * 2))
[ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY [ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY