mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: bulk tpu reorchestration
This commit is contained in:
@@ -40,23 +40,35 @@ if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENT
|
||||
PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||
fi
|
||||
elif [ "$CRED_TYPE" = "authorized_user" ]; then
|
||||
echo "Authenticating gcloud using authorized_user refresh token..."
|
||||
echo "Using authorized_user credentials via credential file override..."
|
||||
export CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE="$GOOGLE_APPLICATION_CREDENTIALS"
|
||||
|
||||
AUTH_ACCOUNT="$GCP_ACCOUNT"
|
||||
if [ -z "$AUTH_ACCOUNT" ]; then
|
||||
AUTH_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||
fi
|
||||
if [ -z "$AUTH_ACCOUNT" ]; then
|
||||
AUTH_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
|
||||
if gcloud auth print-access-token >/dev/null 2>&1; then
|
||||
ACTIVE_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
|
||||
if [ -z "$ACTIVE_ACCOUNT" ] || [ "$ACTIVE_ACCOUNT" = "(unset)" ]; then
|
||||
ACTIVE_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||
fi
|
||||
|
||||
REFRESH_TOKEN=$(jq -r '.refresh_token // empty' "$GOOGLE_APPLICATION_CREDENTIALS")
|
||||
if [ -z "$AUTH_ACCOUNT" ] || [ -z "$REFRESH_TOKEN" ]; then
|
||||
echo "Error: authorized_user credentials require GCP_ACCOUNT (or embedded account) and refresh_token."
|
||||
if [ -n "$ACTIVE_ACCOUNT" ] && [ "$ACTIVE_ACCOUNT" != "(unset)" ]; then
|
||||
echo "Using gcloud account: $ACTIVE_ACCOUNT"
|
||||
else
|
||||
echo "Using gcloud credential override from $GOOGLE_APPLICATION_CREDENTIALS"
|
||||
fi
|
||||
else
|
||||
echo "Warning: credential file override token check failed. Falling back to mounted gcloud config."
|
||||
unset CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE
|
||||
|
||||
if [ -n "$GCP_ACCOUNT" ]; then
|
||||
gcloud config set account "$GCP_ACCOUNT" >/dev/null 2>&1 || true
|
||||
fi
|
||||
|
||||
ACTIVE_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true)
|
||||
if [ -z "$ACTIVE_ACCOUNT" ] || [ "$ACTIVE_ACCOUNT" = "(unset)" ]; then
|
||||
echo "Error: no active gcloud account available. Run 'gcloud auth login' on host and mount ~/.config/gcloud, or use a service account key."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
gcloud auth activate-refresh-token "$AUTH_ACCOUNT" "$REFRESH_TOKEN"
|
||||
echo "Using gcloud account: $ACTIVE_ACCOUNT"
|
||||
fi
|
||||
else
|
||||
echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config."
|
||||
fi
|
||||
|
||||
@@ -92,9 +92,9 @@ def _truthy(value: str | bool | None) -> bool:
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _alive_nodes() -> list[tuple[str, str]]:
|
||||
def _alive_nodes() -> list[tuple[str, str, bool, float]]:
|
||||
seen: set[str] = set()
|
||||
nodes: list[tuple[str, str]] = []
|
||||
nodes: list[tuple[str, str, bool, float]] = []
|
||||
for node in ray.nodes():
|
||||
if not bool(node.get("Alive", False)):
|
||||
continue
|
||||
@@ -102,9 +102,50 @@ def _alive_nodes() -> list[tuple[str, str]]:
|
||||
ip = str(node.get("NodeManagerAddress", "")).strip()
|
||||
if not node_id or not ip or node_id in seen:
|
||||
continue
|
||||
resources = node.get("Resources", {}) or {}
|
||||
is_head = bool(resources.get("node:__internal_head__", 0.0))
|
||||
tpu = float(resources.get("TPU", 0.0))
|
||||
seen.add(node_id)
|
||||
nodes.append((node_id, ip))
|
||||
return sorted(nodes, key=lambda item: (item[1], item[0]))
|
||||
nodes.append((node_id, ip, is_head, tpu))
|
||||
return sorted(nodes, key=lambda item: (item[1], item[2], -item[3], item[0]))
|
||||
|
||||
|
||||
def _dedupe_nodes_for_tpu(
|
||||
nodes: list[tuple[str, str, bool, float]],
|
||||
) -> tuple[list[tuple[str, str]], list[dict[str, str | float | bool]]]:
|
||||
selected: dict[str, tuple[str, str, bool, float]] = {}
|
||||
dropped: list[dict[str, str | float | bool]] = []
|
||||
|
||||
def _score(item: tuple[str, str, bool, float]) -> tuple[int, float, str]:
|
||||
node_id, _ip, is_head, tpu = item
|
||||
return (1 if bool(is_head) else 0, -float(tpu), str(node_id))
|
||||
|
||||
for item in nodes:
|
||||
node_id, ip, is_head, tpu = item
|
||||
existing = selected.get(ip)
|
||||
if existing is None:
|
||||
selected[ip] = item
|
||||
continue
|
||||
|
||||
keep, drop = (
|
||||
(item, existing) if _score(item) < _score(existing) else (existing, item)
|
||||
)
|
||||
selected[ip] = keep
|
||||
dropped.append(
|
||||
{
|
||||
"ip": str(ip),
|
||||
"dropped_node_id": str(drop[0]),
|
||||
"dropped_is_head": bool(drop[2]),
|
||||
"dropped_tpu": float(drop[3]),
|
||||
"kept_node_id": str(keep[0]),
|
||||
"kept_is_head": bool(keep[2]),
|
||||
"kept_tpu": float(keep[3]),
|
||||
}
|
||||
)
|
||||
|
||||
entries = [(node_id, ip) for ip, (node_id, _ip, _is_head, _tpu) in selected.items()]
|
||||
entries.sort(key=lambda item: (item[1], item[0]))
|
||||
return entries, dropped
|
||||
|
||||
|
||||
def _benchmark_cells(
|
||||
@@ -369,7 +410,18 @@ def _train_on_node(
|
||||
env["JAX_PLATFORM_NAME"] = "cpu"
|
||||
else:
|
||||
env.pop("JAX_PLATFORM_NAME", None)
|
||||
# Keep each train process in single-host mode to avoid accidental global stalls.
|
||||
if requested_platform == "tpu" and world_size > 1 and allow_multi_node_tpu:
|
||||
env["CLOUD_TPU_TASK_ID"] = str(int(rank))
|
||||
print(
|
||||
{
|
||||
"rank": int(rank),
|
||||
"node_ip": str(node_ip),
|
||||
"jax_platform": "tpu",
|
||||
"cloud_tpu_task_id": str(env["CLOUD_TPU_TASK_ID"]),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Keep each process in single-host mode when TPU multi-host is disabled.
|
||||
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||
if run_kind == "benchmark":
|
||||
env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0"
|
||||
@@ -508,12 +560,36 @@ def main() -> None:
|
||||
|
||||
ray.init(address="auto")
|
||||
|
||||
node_entries = _alive_nodes()
|
||||
if not node_entries:
|
||||
node_records = _alive_nodes()
|
||||
if not node_records:
|
||||
raise RuntimeError("No alive Ray nodes found")
|
||||
|
||||
if float(args.tpu_per_task) > 0.0:
|
||||
node_entries, dropped = _dedupe_nodes_for_tpu(node_records)
|
||||
if dropped:
|
||||
print(
|
||||
{
|
||||
"tpu_host_dedupe": True,
|
||||
"alive_ray_nodes": len(node_records),
|
||||
"unique_tpu_hosts": len(node_entries),
|
||||
"dropped": dropped,
|
||||
}
|
||||
)
|
||||
else:
|
||||
node_entries = [
|
||||
(node_id, node_ip) for node_id, node_ip, _is_head, _tpu in node_records
|
||||
]
|
||||
|
||||
requested = int(args.num_nodes)
|
||||
if requested > 0:
|
||||
if requested > len(node_entries):
|
||||
print(
|
||||
{
|
||||
"requested_nodes": int(requested),
|
||||
"available_nodes": int(len(node_entries)),
|
||||
"note": "requested nodes exceed available hosts; capping",
|
||||
}
|
||||
)
|
||||
node_entries = node_entries[:requested]
|
||||
|
||||
world_size = len(node_entries)
|
||||
|
||||
@@ -66,6 +66,7 @@ MAX_HEAVY_WORKERS="${MAX_HEAVY_WORKERS:-3}"
|
||||
WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}"
|
||||
SWEEP_KIND="${SWEEP_KIND:-benchmark}"
|
||||
SWEEP_METHOD="${SWEEP_METHOD:-random}"
|
||||
SWEEP_PROFILE="${SWEEP_PROFILE:-default}"
|
||||
SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}"
|
||||
AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}"
|
||||
AGENT_COUNT="${AGENT_COUNT:-0}"
|
||||
@@ -180,6 +181,7 @@ PY
|
||||
fi
|
||||
SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \
|
||||
--kind "$SWEEP_KIND" \
|
||||
--profile "$SWEEP_PROFILE" \
|
||||
--project "$SWEEP_PROJECT" \
|
||||
--entity "$SWEEP_ENTITY" \
|
||||
--method "$SWEEP_METHOD" \
|
||||
@@ -199,10 +201,22 @@ PY
|
||||
fi
|
||||
fi
|
||||
|
||||
SWEEP_RUN_KIND="$SWEEP_KIND"
|
||||
if [ "$SWEEP_KIND" = "ppo_calibration" ] || [ "$SWEEP_KIND" = "ppo_block_a" ] || [ "$SWEEP_KIND" = "ppo_shift_screen" ]; then
|
||||
SWEEP_RUN_KIND="benchmark"
|
||||
fi
|
||||
if [ "$SWEEP_KIND" = "ppo_rl_study" ]; then
|
||||
SWEEP_RUN_KIND="train"
|
||||
fi
|
||||
if [ "$SWEEP_RUN_KIND" != "benchmark" ] && [ "$SWEEP_RUN_KIND" != "train" ]; then
|
||||
echo "Unsupported SWEEP_KIND='$SWEEP_KIND' (expected 'benchmark', 'train', 'ppo_calibration', 'ppo_block_a', 'ppo_shift_screen', or 'ppo_rl_study')." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DIST_ARGS=(
|
||||
python
|
||||
scripts/ray_distributed_train.py
|
||||
--run-kind "$SWEEP_KIND"
|
||||
--run-kind "$SWEEP_RUN_KIND"
|
||||
--entry-args "$SWEEP_ENTRY_ARGS"
|
||||
--num-nodes "${SWEEP_NUM_NODES}"
|
||||
--tpu-per-task "${TPU_PER_TASK:-0}"
|
||||
@@ -214,13 +228,17 @@ PY
|
||||
--inner-threads "$INNER_THREADS"
|
||||
--worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}"
|
||||
)
|
||||
if [ "$SWEEP_KIND" = "benchmark" ]; then
|
||||
if [ "$SWEEP_RUN_KIND" = "benchmark" ]; then
|
||||
DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}")
|
||||
fi
|
||||
if [ "${COMPARE_ROBUST:-0}" = "1" ]; then
|
||||
DIST_ARGS+=(--compare-robust)
|
||||
fi
|
||||
echo "SWEEP_ID=$SWEEP_ID_VALUE"
|
||||
if [ "$SWEEP_KIND" = "train" ] && [ "$SWEEP_PROFILE" = "robust_revenue" ]; then
|
||||
echo "When this sweep finishes, compare best robust config vs no_robust with:"
|
||||
echo "python scripts/wandb_compare_best.py --entity $SWEEP_ENTITY --project $SWEEP_PROJECT --sweep-id $SWEEP_ID_VALUE --submit --ray-no-wait"
|
||||
fi
|
||||
"$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
@@ -3,6 +3,7 @@ QR_NAME="v4-32-us-ondemand"
|
||||
ACCEL_TYPE="v4-32"
|
||||
RUNTIME_VERSION="tpu-ubuntu2204-base"
|
||||
IS_SPOT="false"
|
||||
INTERNAL_IPS="false"
|
||||
RUN_ID="phantom_v4_od_1"
|
||||
HF_REPO="velocitatem/capstone"
|
||||
TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof"
|
||||
@@ -3,6 +3,7 @@ QR_NAME="v4-32-us-spot"
|
||||
ACCEL_TYPE="v4-32"
|
||||
RUNTIME_VERSION="tpu-ubuntu2204-base"
|
||||
IS_SPOT="true"
|
||||
INTERNAL_IPS="false"
|
||||
RUN_ID="phantom_v4_spot_1"
|
||||
HF_REPO="velocitatem/capstone"
|
||||
TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof"
|
||||
@@ -1,6 +1,6 @@
|
||||
ZONE="europe-west4-b"
|
||||
QR_NAME="v5e-64-eu-spot"
|
||||
ACCEL_TYPE="v5litepod-64"
|
||||
QR_NAME="v5e-32-eu-spot"
|
||||
ACCEL_TYPE="v5litepod-32"
|
||||
RUNTIME_VERSION="tpu-ubuntu2204-base"
|
||||
IS_SPOT="true"
|
||||
RUN_ID="phantom_v5e_eu_1"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ZONE="us-central1-a"
|
||||
QR_NAME="v5e-64-us-spot"
|
||||
ACCEL_TYPE="v5litepod-64"
|
||||
QR_NAME="v5e-32-us-spot"
|
||||
ACCEL_TYPE="v5litepod-32"
|
||||
RUNTIME_VERSION="tpu-ubuntu2204-base"
|
||||
IS_SPOT="true"
|
||||
RUN_ID="phantom_v5e_us_1"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ZONE="europe-west4-a"
|
||||
QR_NAME="v6e-64-eu-spot"
|
||||
ACCEL_TYPE="v6e-64"
|
||||
QR_NAME="v6e-32-eu-spot"
|
||||
ACCEL_TYPE="v6e-32"
|
||||
RUNTIME_VERSION="tpu-ubuntu2204-base"
|
||||
IS_SPOT="true"
|
||||
RUN_ID="phantom_v6e_eu_1"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
ZONE="us-east1-d"
|
||||
QR_NAME="v6e-64-us-spot"
|
||||
ACCEL_TYPE="v6e-64"
|
||||
QR_NAME="v6e-32-us-spot"
|
||||
ACCEL_TYPE="v6e-32"
|
||||
RUNTIME_VERSION="tpu-ubuntu2204-base"
|
||||
IS_SPOT="true"
|
||||
RUN_ID="phantom_v6e_us_1"
|
||||
|
||||
@@ -58,7 +58,7 @@ RETRY_DELAY=60
|
||||
MAX_RETRY_DELAY=300
|
||||
|
||||
while true; do
|
||||
STATE=$(gcloud compute tpus queued-resources describe $QR_NAME --zone=$ZONE --project=$PROJECT_ID --format="value(state)" 2>/dev/null)
|
||||
STATE=$(gcloud compute tpus queued-resources describe $QR_NAME --zone=$ZONE --project=$PROJECT_ID --format="value(state.state)" 2>/dev/null)
|
||||
|
||||
if [ -z "$STATE" ] || [[ "$STATE" == *"SUSPENDED"* ]] || [[ "$STATE" == *"FAILED"* ]]; then
|
||||
echo "[$(date)] Cluster '${STATE:-MISSING}' - cleaning IPs and re-queuing..."
|
||||
@@ -85,6 +85,11 @@ while true; do
|
||||
SPOT_FLAG="--spot"
|
||||
fi
|
||||
|
||||
IP_FLAG="--internal-ips"
|
||||
if [ "${INTERNAL_IPS:-true}" != "true" ]; then
|
||||
IP_FLAG=""
|
||||
fi
|
||||
|
||||
# Prepare metadata
|
||||
METADATA="HF_TOKEN=$HF_TOKEN,RUN_ID=$RUN_ID,HF_REPO=$HF_REPO,ACCEL_TYPE=$ACCEL_TYPE,GITHUB_REPO=$GITHUB_REPO,BRANCH=$BRANCH"
|
||||
if [ -n "$WANDB_API_KEY" ]; then
|
||||
@@ -106,7 +111,7 @@ while true; do
|
||||
--accelerator-type=$ACCEL_TYPE \
|
||||
--runtime-version=$RT_VERSION \
|
||||
$SPOT_FLAG \
|
||||
--internal-ips \
|
||||
$IP_FLAG \
|
||||
--metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \
|
||||
--metadata "$METADATA" 2>&1 | tee "$CREATE_LOG"
|
||||
|
||||
@@ -115,8 +120,8 @@ while true; do
|
||||
if [ $CREATE_EXIT -eq 0 ]; then
|
||||
echo "[$(date)] Successfully queued $QR_NAME."
|
||||
RETRY_DELAY=60
|
||||
elif grep -q "IN_USE_ADDRESSES" "$CREATE_LOG" 2>/dev/null; then
|
||||
echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s"
|
||||
elif grep -Eq "IN_USE_ADDRESSES|RESOURCE_EXHAUSTED|Quota limit|QUOTA_EXCEEDED" "$CREATE_LOG" 2>/dev/null; then
|
||||
echo "[$(date)] Quota pressure detected - backing off ${RETRY_DELAY}s"
|
||||
sleep $RETRY_DELAY
|
||||
RETRY_DELAY=$((RETRY_DELAY * 2))
|
||||
[ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY
|
||||
|
||||
Reference in New Issue
Block a user