Files
PHANTOM/scripts/tpu_pod_run.sh
2026-02-27 12:45:46 +01:00

33 lines
1.3 KiB
Bash
Executable File

#!/usr/bin/env sh
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
# Authenticates with Artifact Registry using the VM's service account metadata token,
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
#
# Required env vars: WANDB_API_KEY, SWEEP_ID
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
#
# Exits non-zero if a required variable is unset or empty, if the metadata
# token cannot be fetched or parsed, or if any docker step fails.
set -eu

# Fail fast: validate required inputs before logging in or pulling the image,
# so a misconfigured launch does not spend time on a pull it cannot use.
: "${WANDB_API_KEY:?WANDB_API_KEY must be set and non-empty}"
: "${SWEEP_ID:?SWEEP_ID must be set and non-empty}"

REGISTRY_HOST="us-central1-docker.pkg.dev"
IMAGE="${REGISTRY_HOST}/phantom-trc/phantom/phantom-trainer:tpu-latest"
AGENT_COUNT="${AGENT_COUNT:-1}"
METADATA_TOKEN_URL="http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"

# use VM service account — no manual key needed on the pod
# curl runs on its own (not inside a pipeline) so its exit status is checked
# directly; plain sh has no portable `pipefail`. -S makes curl print the
# reason for a failure even though -s silences the progress output.
TOKEN_JSON=$(curl -sfS -H "Metadata-Flavor: Google" "$METADATA_TOKEN_URL") || {
  echo "error: could not fetch an access token from the metadata server" >&2
  exit 1
}
TOKEN=$(printf '%s' "$TOKEN_JSON" \
  | python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')

# The token is fed on stdin (never on the command line); printf is used
# instead of echo so the value is written verbatim.
printf '%s\n' "$TOKEN" | sudo docker login -u oauth2accesstoken \
  --password-stdin "https://${REGISTRY_HOST}"

sudo docker pull "$IMAGE"

# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
# --network host lets JAX reach the other pod workers for distributed init
# NOTE(review): `-e NAME=value` places WANDB_API_KEY in the docker CLI argv,
# where it is visible to other local users via `ps` — consider an env file.
sudo docker run --rm \
  --privileged \
  --network host \
  --volume /dev:/dev \
  -e WANDB_API_KEY="$WANDB_API_KEY" \
  -e SWEEP_ID="$SWEEP_ID" \
  -e AGENT_COUNT="$AGENT_COUNT" \
  "$IMAGE"