mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: training update
This commit is contained in:
43
Makefile
43
Makefile
@@ -26,6 +26,8 @@ RETRY_SECONDS ?= 20
|
||||
# Fully qualified reference of the phantom-trainer container image.
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer

# TPU connection settings. All use ?= so a value from the command line or the
# environment wins over the default given here.
# TPU_NAME deliberately has no default; the train.tpu.vm.* targets refuse to
# run while it is empty.
TPU_NAME ?=
TPU_ZONE ?= us-central2-b
# Passed to gcloud as --project.
TPU_PROJECT ?= phantom-trc
# Repository directory on the TPU VM (handed to the sync script as
# REMOTE_REPO_DIR and to the remote training script as REPO_DIR).
TPU_REPO_DIR ?= /tmp/PHANTOM

# Shell fragment meant to be prefixed to a recipe line: with allexport (set -a)
# switched on, source SWEEP_ENV_FILE when that file exists so every variable it
# defines is exported, then switch allexport off again. The "|| true" keeps the
# fragment from failing when the file is absent or sourcing returns non-zero.
# Recursive (=) assignment, so $(SWEEP_ENV_FILE) is resolved at the point of use.
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
|
||||
|
||||
@@ -33,7 +35,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines"
|
||||
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
|
||||
@echo "docker.train.publish"
|
||||
@echo ""
|
||||
@echo "Local wandb run:"
|
||||
@@ -165,12 +167,47 @@ train.tpu.pod:
|
||||
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
|
||||
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \
|
||||
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all
|
||||
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all
|
||||
@$(SWEEP_ENV_LOAD); \
|
||||
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
|
||||
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all \
|
||||
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all \
|
||||
--command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh"
|
||||
|
||||
# Stage a TPU VM for training.
#   1. Abort early unless TPU_NAME is set.
#   2. Run scripts/tpu_sync_repo.sh with the TPU coordinates plus the local
#      checkout ($(CURDIR)) and the remote target dir ($(TPU_REPO_DIR)) in its
#      environment (presumably this syncs the repo to the VM; the script is
#      not visible here).
#   3. Copy scripts/tpu_vm_train.sh to /tmp/tpu_vm_train.sh on every worker.
.PHONY: train.tpu.vm.prepare
train.tpu.vm.prepare:
	@[ -n "$(TPU_NAME)" ] || { echo "TPU_NAME required, e.g. TPU_NAME=TPUlong"; exit 1; }
	TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" \
		LOCAL_REPO_DIR="$(CURDIR)" REMOTE_REPO_DIR="$(TPU_REPO_DIR)" \
		sh scripts/tpu_sync_repo.sh
	gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh $(TPU_NAME):/tmp/tpu_vm_train.sh \
		--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all
|
||||
|
||||
# Start a training run on every worker of the TPU VM.
# Requires TPU_NAME and LOCAL_TRAIN_ARGS. The sweep env file is loaded first
# so that WANDB_API_KEY (if defined there) can be forwarded; the remote side
# executes /tmp/tpu_vm_train.sh with REPO_DIR, TRAIN_ARGS and WANDB_API_KEY
# set on its command line. Unlike train.tpu.vm.sweep, this target does not
# insist on WANDB_API_KEY being non-empty.
.PHONY: train.tpu.vm.run
train.tpu.vm.run:
	@[ -n "$(TPU_NAME)" ] || { echo "TPU_NAME required, e.g. TPU_NAME=TPUlong"; exit 1; }
	@[ -n "$(LOCAL_TRAIN_ARGS)" ] || { echo "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"; exit 1; }
	@$(SWEEP_ENV_LOAD); \
	gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
		--zone=$(TPU_ZONE) \
		--project=$(TPU_PROJECT) \
		--worker=all \
		--command="REPO_DIR='$(TPU_REPO_DIR)' TRAIN_ARGS='$(LOCAL_TRAIN_ARGS)' WANDB_API_KEY='$$WANDB_API_KEY' sh /tmp/tpu_vm_train.sh"
|
||||
|
||||
# One-shot TPU VM training: stage the VM, then launch the run.
#
# The two steps are executed as sequential sub-makes instead of being listed
# as prerequisites. Sibling prerequisites carry no ordering guarantee under
# `make -j`, and train.tpu.vm.run executes /tmp/tpu_vm_train.sh, which only
# exists on the workers after train.tpu.vm.prepare has copied it there.
# Recipe lines always run one after another, and a failing prepare step stops
# the recipe before the run step is attempted.
#
# Variables given on the command line (TPU_NAME, LOCAL_TRAIN_ARGS, ...) are
# handed to the sub-makes automatically through MAKEFLAGS.
.PHONY: train.tpu.vm
train.tpu.vm:
	$(MAKE) train.tpu.vm.prepare
	$(MAKE) train.tpu.vm.run
|
||||
|
||||
# Run the sweep agent driver (scripts/tpu_vm_sweep_agent.py) from this machine
# against a TPU VM.
# Preconditions, checked in order: TPU_NAME, SWEEP_ID, and a non-empty
# WANDB_API_KEY once the sweep env file has been loaded.
# The script receives the sweep id and the TPU name/zone/project/repo dir as
# flags; --count is appended only when AGENT_COUNT is non-empty and not 0.
.PHONY: train.tpu.vm.sweep
train.tpu.vm.sweep:
	@[ -n "$(TPU_NAME)" ] || { echo "TPU_NAME required, e.g. TPU_NAME=TPUlong"; exit 1; }
	@[ -n "$(SWEEP_ID)" ] || { echo "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123"; exit 1; }
	@$(SWEEP_ENV_LOAD); [ -n "$$WANDB_API_KEY" ] || { echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)"; exit 1; }
	@$(SWEEP_ENV_LOAD); \
	WANDB_API_KEY="$$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py \
		--sweep-id "$(SWEEP_ID)" \
		--tpu-name "$(TPU_NAME)" \
		--tpu-zone "$(TPU_ZONE)" \
		--tpu-project "$(TPU_PROJECT)" \
		--tpu-repo-dir "$(TPU_REPO_DIR)" \
		$(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT))
|
||||
|
||||
.PHONY: pdf clean watch run.webapp test count-lines all
|
||||
pdf: pdf.build
|
||||
clean: pdf.clean
|
||||
|
||||
Reference in New Issue
Block a user