diff --git a/docker/TPUWatchdog.dockerfile b/docker/TPUWatchdog.dockerfile index 66c0c3f..83358f1 100644 --- a/docker/TPUWatchdog.dockerfile +++ b/docker/TPUWatchdog.dockerfile @@ -40,23 +40,35 @@ if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ] && [ -f "$GOOGLE_APPLICATION_CREDENT PROJECT_ID=$(jq -r '.project_id // empty' "$GOOGLE_APPLICATION_CREDENTIALS") fi elif [ "$CRED_TYPE" = "authorized_user" ]; then - echo "Authenticating gcloud using authorized_user refresh token..." + echo "Using authorized_user credentials via credential file override..." + export CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE="$GOOGLE_APPLICATION_CREDENTIALS" - AUTH_ACCOUNT="$GCP_ACCOUNT" - if [ -z "$AUTH_ACCOUNT" ]; then - AUTH_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS") - fi - if [ -z "$AUTH_ACCOUNT" ]; then - AUTH_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true) - fi + if gcloud auth print-access-token >/dev/null 2>&1; then + ACTIVE_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true) + if [ -z "$ACTIVE_ACCOUNT" ] || [ "$ACTIVE_ACCOUNT" = "(unset)" ]; then + ACTIVE_ACCOUNT=$(jq -r '.account // empty' "$GOOGLE_APPLICATION_CREDENTIALS") + fi - REFRESH_TOKEN=$(jq -r '.refresh_token // empty' "$GOOGLE_APPLICATION_CREDENTIALS") - if [ -z "$AUTH_ACCOUNT" ] || [ -z "$REFRESH_TOKEN" ]; then - echo "Error: authorized_user credentials require GCP_ACCOUNT (or embedded account) and refresh_token." - exit 1 - fi + if [ -n "$ACTIVE_ACCOUNT" ] && [ "$ACTIVE_ACCOUNT" != "(unset)" ]; then + echo "Using gcloud account: $ACTIVE_ACCOUNT" + else + echo "Using gcloud credential override from $GOOGLE_APPLICATION_CREDENTIALS" + fi + else + echo "Warning: credential file override token check failed. Falling back to mounted gcloud config." + unset CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE - gcloud auth activate-refresh-token "$AUTH_ACCOUNT" "$REFRESH_TOKEN" + if [ -n "$GCP_ACCOUNT" ]; then + gcloud config set account "$GCP_ACCOUNT" >/dev/null 2>&1 || true + fi + + ACTIVE_ACCOUNT=$(gcloud config get-value account 2>/dev/null || true) + if [ -z "$ACTIVE_ACCOUNT" ] || [ "$ACTIVE_ACCOUNT" = "(unset)" ]; then + echo "Error: no active gcloud account available. Run 'gcloud auth login' on host and mount ~/.config/gcloud, or use a service account key." + exit 1 + fi + echo "Using gcloud account: $ACTIVE_ACCOUNT" + fi else echo "Warning: unsupported credential file type '$CRED_TYPE'. Falling back to mounted gcloud config." fi diff --git a/scripts/ray_distributed_train.py b/scripts/ray_distributed_train.py index 3395a8f..773fddd 100644 --- a/scripts/ray_distributed_train.py +++ b/scripts/ray_distributed_train.py @@ -92,9 +92,9 @@ def _truthy(value: str | bool | None) -> bool: return str(value).strip().lower() in {"1", "true", "yes", "on"} -def _alive_nodes() -> list[tuple[str, str]]: +def _alive_nodes() -> list[tuple[str, str, bool, float]]: seen: set[str] = set() - nodes: list[tuple[str, str]] = [] + nodes: list[tuple[str, str, bool, float]] = [] for node in ray.nodes(): if not bool(node.get("Alive", False)): continue @@ -102,9 +102,50 @@ def _alive_nodes() -> list[tuple[str, str]]: ip = str(node.get("NodeManagerAddress", "")).strip() if not node_id or not ip or node_id in seen: continue + resources = node.get("Resources", {}) or {} + is_head = bool(resources.get("node:__internal_head__", 0.0)) + tpu = float(resources.get("TPU", 0.0)) seen.add(node_id) - nodes.append((node_id, ip)) - return sorted(nodes, key=lambda item: (item[1], item[0])) + nodes.append((node_id, ip, is_head, tpu)) + return sorted(nodes, key=lambda item: (item[1], item[2], -item[3], item[0])) + + +def _dedupe_nodes_for_tpu( + nodes: list[tuple[str, str, bool, float]], +) -> tuple[list[tuple[str, str]], list[dict[str, str | float | bool]]]: + selected: dict[str, tuple[str, str, bool, float]] = {} + dropped: list[dict[str, str | float | bool]] = [] + + def _score(item: tuple[str, str, bool, float]) -> tuple[int, float, str]: + node_id, _ip, is_head, tpu = item + return (1 if bool(is_head) else 0, -float(tpu), str(node_id)) + + for item in nodes: + node_id, ip, is_head, tpu = item + existing = selected.get(ip) + if existing is None: + selected[ip] = item + continue + + keep, drop = ( + (item, existing) if _score(item) < _score(existing) else (existing, item) + ) + selected[ip] = keep + dropped.append( + { + "ip": str(ip), + "dropped_node_id": str(drop[0]), + "dropped_is_head": bool(drop[2]), + "dropped_tpu": float(drop[3]), + "kept_node_id": str(keep[0]), + "kept_is_head": bool(keep[2]), + "kept_tpu": float(keep[3]), + } + ) + + entries = [(node_id, ip) for ip, (node_id, _ip, _is_head, _tpu) in selected.items()] + entries.sort(key=lambda item: (item[1], item[0])) + return entries, dropped def _benchmark_cells( @@ -369,8 +410,19 @@ def _train_on_node( env["JAX_PLATFORM_NAME"] = "cpu" else: env.pop("JAX_PLATFORM_NAME", None) - # Keep each train process in single-host mode to avoid accidental global stalls. - env["CLOUD_TPU_TASK_ID"] = "0" + if requested_platform == "tpu" and world_size > 1 and allow_multi_node_tpu: + env["CLOUD_TPU_TASK_ID"] = str(int(rank)) + print( + { + "rank": int(rank), + "node_ip": str(node_ip), + "jax_platform": "tpu", + "cloud_tpu_task_id": str(env["CLOUD_TPU_TASK_ID"]), + } + ) + else: + # Keep each process in single-host mode when TPU multi-host is disabled. + env["CLOUD_TPU_TASK_ID"] = "0" if run_kind == "benchmark": env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0" if wandb_entity: @@ -508,12 +560,36 @@ def main() -> None: ray.init(address="auto") - node_entries = _alive_nodes() - if not node_entries: + node_records = _alive_nodes() + if not node_records: raise RuntimeError("No alive Ray nodes found") + if float(args.tpu_per_task) > 0.0: + node_entries, dropped = _dedupe_nodes_for_tpu(node_records) + if dropped: + print( + { + "tpu_host_dedupe": True, + "alive_ray_nodes": len(node_records), + "unique_tpu_hosts": len(node_entries), + "dropped": dropped, + } + ) + else: + node_entries = [ + (node_id, node_ip) for node_id, node_ip, _is_head, _tpu in node_records + ] + requested = int(args.num_nodes) if requested > 0: + if requested > len(node_entries): + print( + { + "requested_nodes": int(requested), + "available_nodes": int(len(node_entries)), + "note": "requested nodes exceed available hosts; capping", + } + ) node_entries = node_entries[:requested] world_size = len(node_entries) diff --git a/submit_ray_job.sh b/submit_ray_job.sh index a6065ec..b4a2630 100755 --- a/submit_ray_job.sh +++ b/submit_ray_job.sh @@ -66,6 +66,7 @@ MAX_HEAVY_WORKERS="${MAX_HEAVY_WORKERS:-3}" WORKER_CPUS="${WORKER_CPUS:-$((INNER_WORKERS * INNER_THREADS))}" SWEEP_KIND="${SWEEP_KIND:-benchmark}" SWEEP_METHOD="${SWEEP_METHOD:-random}" +SWEEP_PROFILE="${SWEEP_PROFILE:-default}" SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-0}" AGENTS_PER_NODE="${AGENTS_PER_NODE:-16}" AGENT_COUNT="${AGENT_COUNT:-0}" @@ -180,6 +181,7 @@ PY fi SWEEP_ID_VALUE="$($PY_SWEEP_BIN "$ROOT/scripts/wandb_create_sweep.py" \ --kind "$SWEEP_KIND" \ + --profile "$SWEEP_PROFILE" \ --project "$SWEEP_PROJECT" \ --entity "$SWEEP_ENTITY" \ --method "$SWEEP_METHOD" \ @@ -199,10 +201,22 @@ PY fi fi + SWEEP_RUN_KIND="$SWEEP_KIND" + if [ "$SWEEP_KIND" = "ppo_calibration" ] || [ "$SWEEP_KIND" = "ppo_block_a" ] || [ "$SWEEP_KIND" = "ppo_shift_screen" ]; then + SWEEP_RUN_KIND="benchmark" + fi + if [ "$SWEEP_KIND" = "ppo_rl_study" ]; then + SWEEP_RUN_KIND="train" + fi + if [ "$SWEEP_RUN_KIND" != "benchmark" ] && [ "$SWEEP_RUN_KIND" != "train" ]; then + echo "Unsupported SWEEP_KIND='$SWEEP_KIND' (expected 'benchmark', 'train', 'ppo_calibration', 'ppo_block_a', 'ppo_shift_screen', or 'ppo_rl_study')." >&2 + exit 1 + fi + DIST_ARGS=( python scripts/ray_distributed_train.py - --run-kind "$SWEEP_KIND" + --run-kind "$SWEEP_RUN_KIND" --entry-args "$SWEEP_ENTRY_ARGS" --num-nodes "${SWEEP_NUM_NODES}" --tpu-per-task "${TPU_PER_TASK:-0}" @@ -214,13 +228,17 @@ PY --inner-threads "$INNER_THREADS" --worker-cpus "${WORKER_CPUS:-$((AGENTS_PER_NODE * INNER_THREADS))}" ) - if [ "$SWEEP_KIND" = "benchmark" ]; then + if [ "$SWEEP_RUN_KIND" = "benchmark" ]; then DIST_ARGS+=(--output-root "${OUTPUT_ROOT:-engine/studies/results/sweeps}") fi if [ "${COMPARE_ROBUST:-0}" = "1" ]; then DIST_ARGS+=(--compare-robust) fi echo "SWEEP_ID=$SWEEP_ID_VALUE" + if [ "$SWEEP_KIND" = "train" ] && [ "$SWEEP_PROFILE" = "robust_revenue" ]; then + echo "When this sweep finishes, compare best robust config vs no_robust with:" + echo "python scripts/wandb_compare_best.py --entity $SWEEP_ENTITY --project $SWEEP_PROJECT --sweep-id $SWEEP_ID_VALUE --submit --ray-no-wait" + fi "$RAY_BIN" "${COMMON_ARGS[@]}" "${DIST_ARGS[@]}" exit 0 fi diff --git a/tpu_orchestration/configs/v4_od_us.conf b/tpu_orchestration/configs/v4_od_us.conf index ba75d7f..42bda3e 100644 --- a/tpu_orchestration/configs/v4_od_us.conf +++ b/tpu_orchestration/configs/v4_od_us.conf @@ -3,6 +3,7 @@ QR_NAME="v4-32-us-ondemand" ACCEL_TYPE="v4-32" RUNTIME_VERSION="tpu-ubuntu2204-base" IS_SPOT="false" +INTERNAL_IPS="false" RUN_ID="phantom_v4_od_1" HF_REPO="velocitatem/capstone" -TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" \ No newline at end of file +TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" diff --git a/tpu_orchestration/configs/v4_spot_us.conf b/tpu_orchestration/configs/v4_spot_us.conf index 2e31a18..25e9427 100644 --- a/tpu_orchestration/configs/v4_spot_us.conf +++ b/tpu_orchestration/configs/v4_spot_us.conf @@ -3,6 +3,7 @@ QR_NAME="v4-32-us-spot" ACCEL_TYPE="v4-32" RUNTIME_VERSION="tpu-ubuntu2204-base" IS_SPOT="true" +INTERNAL_IPS="false" RUN_ID="phantom_v4_spot_1" HF_REPO="velocitatem/capstone" -TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" \ No newline at end of file +TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" diff --git a/tpu_orchestration/configs/v5e_eu.conf b/tpu_orchestration/configs/v5e_eu.conf index 89ef604..573cc5f 100644 --- a/tpu_orchestration/configs/v5e_eu.conf +++ b/tpu_orchestration/configs/v5e_eu.conf @@ -1,8 +1,8 @@ ZONE="europe-west4-b" -QR_NAME="v5e-64-eu-spot" -ACCEL_TYPE="v5litepod-64" +QR_NAME="v5e-32-eu-spot" +ACCEL_TYPE="v5litepod-32" RUNTIME_VERSION="tpu-ubuntu2204-base" IS_SPOT="true" RUN_ID="phantom_v5e_eu_1" HF_REPO="velocitatem/capstone" -TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" \ No newline at end of file +TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" diff --git a/tpu_orchestration/configs/v5e_us.conf b/tpu_orchestration/configs/v5e_us.conf index a77c50e..c212eac 100644 --- a/tpu_orchestration/configs/v5e_us.conf +++ b/tpu_orchestration/configs/v5e_us.conf @@ -1,8 +1,8 @@ ZONE="us-central1-a" -QR_NAME="v5e-64-us-spot" -ACCEL_TYPE="v5litepod-64" +QR_NAME="v5e-32-us-spot" +ACCEL_TYPE="v5litepod-32" RUNTIME_VERSION="tpu-ubuntu2204-base" IS_SPOT="true" RUN_ID="phantom_v5e_us_1" HF_REPO="velocitatem/capstone" -TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" \ No newline at end of file +TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" diff --git a/tpu_orchestration/configs/v6e_eu.conf b/tpu_orchestration/configs/v6e_eu.conf index ae7bcc3..55d3e3e 100644 --- a/tpu_orchestration/configs/v6e_eu.conf +++ b/tpu_orchestration/configs/v6e_eu.conf @@ -1,8 +1,8 @@ ZONE="europe-west4-a" -QR_NAME="v6e-64-eu-spot" -ACCEL_TYPE="v6e-64" +QR_NAME="v6e-32-eu-spot" +ACCEL_TYPE="v6e-32" RUNTIME_VERSION="tpu-ubuntu2204-base" IS_SPOT="true" RUN_ID="phantom_v6e_eu_1" HF_REPO="velocitatem/capstone" -TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" \ No newline at end of file +TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" diff --git a/tpu_orchestration/configs/v6e_us.conf b/tpu_orchestration/configs/v6e_us.conf index a5fe55d..8145d3d 100644 --- a/tpu_orchestration/configs/v6e_us.conf +++ b/tpu_orchestration/configs/v6e_us.conf @@ -1,8 +1,8 @@ ZONE="us-east1-d" -QR_NAME="v6e-64-us-spot" -ACCEL_TYPE="v6e-64" +QR_NAME="v6e-32-us-spot" +ACCEL_TYPE="v6e-32" RUNTIME_VERSION="tpu-ubuntu2204-base" IS_SPOT="true" RUN_ID="phantom_v6e_us_1" HF_REPO="velocitatem/capstone" -TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" \ No newline at end of file +TRAIN_CMD="python -m engine.train --sweep-agent --sweep-id lusiana/capstone/oasdorof" diff --git a/tpu_orchestration/watchdog.sh b/tpu_orchestration/watchdog.sh index 7e7a0fc..1a01447 100755 --- a/tpu_orchestration/watchdog.sh +++ b/tpu_orchestration/watchdog.sh @@ -58,7 +58,7 @@ RETRY_DELAY=60 MAX_RETRY_DELAY=300 while true; do - STATE=$(gcloud compute tpus queued-resources describe $QR_NAME --zone=$ZONE --project=$PROJECT_ID --format="value(state)" 2>/dev/null) + STATE=$(gcloud compute tpus queued-resources describe $QR_NAME --zone=$ZONE --project=$PROJECT_ID --format="value(state.state)" 2>/dev/null) if [ -z "$STATE" ] || [[ "$STATE" == *"SUSPENDED"* ]] || [[ "$STATE" == *"FAILED"* ]]; then echo "[$(date)] Cluster '${STATE:-MISSING}' - cleaning IPs and re-queuing..." @@ -84,6 +84,11 @@ while true; do if [ "$IS_SPOT" = "true" ]; then SPOT_FLAG="--spot" fi + + IP_FLAG="--internal-ips" + if [ "${INTERNAL_IPS:-true}" != "true" ]; then + IP_FLAG="" + fi # Prepare metadata METADATA="HF_TOKEN=$HF_TOKEN,RUN_ID=$RUN_ID,HF_REPO=$HF_REPO,ACCEL_TYPE=$ACCEL_TYPE,GITHUB_REPO=$GITHUB_REPO,BRANCH=$BRANCH" @@ -106,7 +111,7 @@ while true; do --accelerator-type=$ACCEL_TYPE \ --runtime-version=$RT_VERSION \ $SPOT_FLAG \ - --internal-ips \ + $IP_FLAG \ --metadata-from-file startup-script=$(dirname $0)/tpu_startup.sh \ --metadata "$METADATA" 2>&1 | tee "$CREATE_LOG" @@ -115,8 +120,8 @@ while true; do if [ $CREATE_EXIT -eq 0 ]; then echo "[$(date)] Successfully queued $QR_NAME." RETRY_DELAY=60 - elif grep -q "IN_USE_ADDRESSES" "$CREATE_LOG" 2>/dev/null; then - echo "[$(date)] IP quota hit - backing off ${RETRY_DELAY}s" + elif grep -Eq "IN_USE_ADDRESSES|RESOURCE_EXHAUSTED|Quota limit|QUOTA_EXCEEDED" "$CREATE_LOG" 2>/dev/null; then + echo "[$(date)] Quota pressure detected - backing off ${RETRY_DELAY}s" sleep $RETRY_DELAY RETRY_DELAY=$((RETRY_DELAY * 2)) [ $RETRY_DELAY -gt $MAX_RETRY_DELAY ] && RETRY_DELAY=$MAX_RETRY_DELAY