cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -108,49 +108,6 @@ PY
image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}"
docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" .
docker push "$image_ref:gpu-latest"
docker build -f docker/Trainer.dockerfile --target tpu -t "$image_ref:tpu-latest" .
docker push "$image_ref:tpu-latest"
;;
train-tpu-pod)
  # Copy scripts/tpu_pod_run.sh to /tmp on every worker of the TPU named by
  # TPU_NAME, then run it there with the sweep settings passed as environment
  # assignments on the remote command line.
  #
  # Required: TPU_NAME, SWEEP_ID, WANDB_API_KEY.
  # Optional: TPU_ZONE (default us-central2-b), TPU_PROJECT (default
  #           phantom-trc), AGENT_COUNT (default 0).
  load_sweep_env
  require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
  require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
  require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"

  # Flags shared by the scp and ssh invocations below.
  tpu_flags=(
    --zone="${TPU_ZONE:-us-central2-b}"
    --project="${TPU_PROJECT:-phantom-trc}"
    --worker=all
  )

  # Shell-quote every value with %q instead of splicing it between literal
  # single quotes: a value containing a single quote would otherwise end the
  # quoting early and change (or inject into) the command run on the workers.
  # NOTE(review): assumes the remote side parses the command with a shell that
  # accepts bash-style quoting for unusual characters - confirm.
  printf -v remote_cmd \
    'WANDB_API_KEY=%q SWEEP_ID=%q AGENT_COUNT=%q sh /tmp/tpu_pod_run.sh' \
    "$WANDB_API_KEY" "$SWEEP_ID" "${AGENT_COUNT:-0}"

  gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh "$TPU_NAME":/tmp/tpu_pod_run.sh "${tpu_flags[@]}"
  gcloud compute tpus tpu-vm ssh "$TPU_NAME" "${tpu_flags[@]}" --command="$remote_cmd"
  ;;
train-tpu-vm-prepare)
  # Stage a TPU VM for training: run scripts/tpu_sync_repo.sh with the TPU and
  # repository locations in its environment, then copy scripts/tpu_vm_train.sh
  # to /tmp on every worker.
  #
  # Required: TPU_NAME.
  # Optional: TPU_ZONE (default us-central2-b), TPU_PROJECT (default
  #           phantom-trc), TPU_REPO_DIR (default /tmp/PHANTOM).
  require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"

  # The assignments below apply only to this one `sh` invocation.
  TPU_NAME="${TPU_NAME}" \
    TPU_ZONE="${TPU_ZONE:-us-central2-b}" \
    TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" \
    LOCAL_REPO_DIR="${PWD}" \
    REMOTE_REPO_DIR="${TPU_REPO_DIR:-/tmp/PHANTOM}" \
    sh scripts/tpu_sync_repo.sh

  gcloud compute tpus tpu-vm scp \
    scripts/tpu_vm_train.sh \
    "${TPU_NAME}:/tmp/tpu_vm_train.sh" \
    --zone="${TPU_ZONE:-us-central2-b}" \
    --project="${TPU_PROJECT:-phantom-trc}" \
    --worker=all
  ;;
train-tpu-vm-run)
  # Run /tmp/tpu_vm_train.sh on every worker of the TPU named by TPU_NAME,
  # passing the repository directory, the training arguments and the W&B key
  # as environment assignments on the remote command line.
  #
  # Required: TPU_NAME, LOCAL_TRAIN_ARGS.
  # Optional: TPU_ZONE (default us-central2-b), TPU_PROJECT (default
  #           phantom-trc), TPU_REPO_DIR (default /tmp/PHANTOM),
  #           WANDB_API_KEY (default empty).
  load_sweep_env
  require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
  require_var LOCAL_TRAIN_ARGS "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"

  # Shell-quote every value with %q instead of splicing it between literal
  # single quotes: LOCAL_TRAIN_ARGS is free-form, and a single quote inside it
  # would otherwise end the quoting early and change (or inject into) the
  # command run on the workers.
  # NOTE(review): assumes the remote side parses the command with a shell that
  # accepts bash-style quoting for unusual characters - confirm.
  printf -v remote_cmd \
    'REPO_DIR=%q TRAIN_ARGS=%q WANDB_API_KEY=%q sh /tmp/tpu_vm_train.sh' \
    "${TPU_REPO_DIR:-/tmp/PHANTOM}" "$LOCAL_TRAIN_ARGS" "${WANDB_API_KEY:-}"

  gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
    --zone="${TPU_ZONE:-us-central2-b}" \
    --project="${TPU_PROJECT:-phantom-trc}" \
    --worker=all \
    --command="$remote_cmd"
  ;;
train-tpu-vm-sweep)
  # Launch scripts/tpu_vm_sweep_agent.py locally with the sweep and TPU
  # coordinates as command-line flags and WANDB_API_KEY in its environment.
  #
  # Required: TPU_NAME, SWEEP_ID, WANDB_API_KEY.
  # Optional: TPU_ZONE (default us-central2-b), TPU_PROJECT (default
  #           phantom-trc), TPU_REPO_DIR (default /tmp/PHANTOM), AGENT_COUNT.
  load_sweep_env
  require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
  require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
  require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"

  args=(
    --sweep-id "${SWEEP_ID}"
    --tpu-name "${TPU_NAME}"
    --tpu-zone "${TPU_ZONE:-us-central2-b}"
    --tpu-project "${TPU_PROJECT:-phantom-trc}"
    --tpu-repo-dir "${TPU_REPO_DIR:-/tmp/PHANTOM}"
  )

  # --count is forwarded only when AGENT_COUNT is set, non-empty and not "0";
  # defaulting an unset/empty value to "0" folds both checks into one test.
  if [[ "${AGENT_COUNT:-0}" != "0" ]]; then
    args+=(--count "${AGENT_COUNT}")
  fi

  WANDB_API_KEY="${WANDB_API_KEY}" \
    python3 scripts/tpu_vm_sweep_agent.py "${args[@]}"
  ;;
*)
printf '%s\n' "Unknown research command: $cmd" >&2