feat: training update

This commit is contained in:
2026-02-27 09:33:04 +01:00
parent dac1e58a0d
commit e50d643fbf
3 changed files with 122 additions and 23 deletions

View File

@@ -26,6 +26,8 @@ RETRY_SECONDS ?= 20
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
TPU_NAME ?=
TPU_ZONE ?= us-central2-b
TPU_PROJECT ?= phantom-trc
TPU_REPO_DIR ?= /tmp/PHANTOM
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
@@ -33,7 +35,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
.PHONY: help
help:
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines"
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
@echo "docker.train.publish"
@echo ""
@echo "Local wandb run:"
@@ -165,12 +167,47 @@ train.tpu.pod:
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all
@$(SWEEP_ENV_LOAD); \
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all \
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all \
--command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh"
.PHONY: train.tpu.vm.prepare
train.tpu.vm.prepare:
@test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1)
TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" \
LOCAL_REPO_DIR="$(CURDIR)" REMOTE_REPO_DIR="$(TPU_REPO_DIR)" \
sh scripts/tpu_sync_repo.sh
gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh $(TPU_NAME):/tmp/tpu_vm_train.sh \
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all
.PHONY: train.tpu.vm.run
train.tpu.vm.run:
@test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1)
@test -n "$(LOCAL_TRAIN_ARGS)" || (echo "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000" && exit 1)
@$(SWEEP_ENV_LOAD); \
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all \
--command="REPO_DIR='$(TPU_REPO_DIR)' TRAIN_ARGS='$(LOCAL_TRAIN_ARGS)' WANDB_API_KEY='$$WANDB_API_KEY' sh /tmp/tpu_vm_train.sh"
.PHONY: train.tpu.vm
train.tpu.vm: train.tpu.vm.prepare train.tpu.vm.run
.PHONY: train.tpu.vm.sweep
train.tpu.vm.sweep:
@test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1)
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123" && exit 1)
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
@$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" \
python3 scripts/tpu_vm_sweep_agent.py \
--sweep-id "$(SWEEP_ID)" \
--tpu-name "$(TPU_NAME)" \
--tpu-zone "$(TPU_ZONE)" \
--tpu-project "$(TPU_PROJECT)" \
--tpu-repo-dir "$(TPU_REPO_DIR)" \
$(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT),)
.PHONY: pdf clean watch run.webapp test count-lines all
pdf: pdf.build
clean: pdf.clean