mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactored training approaches
This commit is contained in:
291
Makefile
291
Makefile
@@ -8,57 +8,44 @@ VENV := .venv
|
|||||||
PYTHON := $(VENV)/bin/python
|
PYTHON := $(VENV)/bin/python
|
||||||
PIP := $(VENV)/bin/pip
|
PIP := $(VENV)/bin/pip
|
||||||
PYTEST := $(VENV)/bin/pytest
|
PYTEST := $(VENV)/bin/pytest
|
||||||
TPU_NAME ?= phantom-tpu
|
|
||||||
TPU_ZONE ?= us-central2-b
|
SWEEP_ENV_FILE ?= .env.sweep
|
||||||
TPU_TYPE ?= v4-32
|
|
||||||
TPU_RUNTIME ?= tpu-vm-v4-base
|
WANDB_ENTITY ?=
|
||||||
TPU_PROJECT ?= phantom-trc
|
WANDB_PROJECT ?= phantom-pricing
|
||||||
TPU_NETWORK ?= tpu-network
|
|
||||||
TPU_SUBNETWORK ?= tpu-network
|
|
||||||
TPU_USE_SPOT ?= 0
|
|
||||||
TPU_EXTRA_CREATE_FLAGS ?=
|
|
||||||
TPU_WORKDIR ?= ~/PHANTOM
|
|
||||||
TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env
|
|
||||||
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
|
|
||||||
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
|
||||||
TPU_VENV ?= .venv-tpu
|
|
||||||
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=online
|
|
||||||
SWEEP_ID ?=
|
SWEEP_ID ?=
|
||||||
SWEEP_COUNT ?= 5
|
LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000
|
||||||
QUEUE_SCRIPT ?= scripts/queue_sweep.sh
|
AGENT_COUNT ?= 0
|
||||||
TPU_QUEUE_TYPE ?=
|
|
||||||
TPU_QUEUE_ZONES ?= europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b
|
REPO_URL ?=
|
||||||
TPU_QUEUE_REUSE_EXISTING ?= 1
|
BRANCH ?= main
|
||||||
TPU_QUEUE_KEEP_ALIVE ?= 1
|
WORKDIR ?= $(HOME)/PHANTOM-agent
|
||||||
TPU_QUEUE_STRICT_QUOTA ?= 0
|
AGENT_LOOP ?= 1
|
||||||
TPU_QUEUE_DOWNSHIFT_ON_QUOTA ?= 1
|
RETRY_SECONDS ?= 20
|
||||||
TPU_QUEUE_FILTER_ZONE ?=
|
|
||||||
TPU_QUEUE_FILTER_TYPE ?=
|
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
|
||||||
TPU_QUEUE_EXECUTION_MODE ?= venv
|
TPU_NAME ?=
|
||||||
TPU_QUEUE_SYNC_METHOD ?= tar
|
TPU_ZONE ?= us-central2-b
|
||||||
TPU_QUEUE_SKIP_SYNC ?= 0
|
|
||||||
TPU_QUEUE_DOCKER_IMAGE ?=
|
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
|
||||||
TPU_QUEUE_DOCKER_PULL ?= 1
|
|
||||||
TPU_QUEUE_DOCKER_AUTO_INSTALL ?= 1
|
|
||||||
TPU_QUEUE_SSH_BATCH_MODE ?= 1
|
|
||||||
TPU_QUEUE_SSH_CONNECT_TIMEOUT ?= 12
|
|
||||||
TPU_QUEUE_SSH_KEY_FILE ?= $(HOME)/.ssh/google_compute_engine
|
|
||||||
TPU_QUEUE_REQUIRE_SSH_AGENT ?= 1
|
|
||||||
TPU_QUEUE_AUTO_SSH_ADD ?= 1
|
|
||||||
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)
|
|
||||||
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
|
|
||||||
|
|
||||||
.DEFAULT_GOAL := help
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help:
|
help:
|
||||||
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.* | tpu.queue.*"
|
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines"
|
||||||
@echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot"
|
@echo "docker.train.publish"
|
||||||
@echo "Queued sweep: SWEEP_ID=entity/project/id make tpu.queue.sweep"
|
@echo ""
|
||||||
@echo "Queued sweep filters: TPU_QUEUE_FILTER_TYPE=v6e TPU_QUEUE_FILTER_ZONE=europe-west4-a"
|
@echo "Local wandb run:"
|
||||||
@echo "Docker queue: make tpu.queue.sweep.docker TPU_QUEUE_DOCKER_IMAGE=gcr.io/<project>/<image>:tag"
|
@echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'"
|
||||||
@echo "Docker queue without sync: add TPU_QUEUE_SKIP_SYNC=1"
|
@echo ""
|
||||||
@echo "If SSH key is encrypted: run ssh-add ~/.ssh/google_compute_engine first"
|
@echo "Local sweep agent from this repo:"
|
||||||
|
@echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5"
|
||||||
|
@echo ""
|
||||||
|
@echo "Bootstrap private repo worker from anywhere:"
|
||||||
|
@echo " make train.bootstrap REPO_URL=https://github.com/org/repo.git BRANCH=main SWEEP_ID=entity/project/id"
|
||||||
|
@echo ""
|
||||||
|
@echo "Config source: $(SWEEP_ENV_FILE) (auto-loaded)"
|
||||||
|
|
||||||
$(BUILDDIR):
|
$(BUILDDIR):
|
||||||
mkdir -p paper/$(BUILDDIR)
|
mkdir -p paper/$(BUILDDIR)
|
||||||
@@ -115,173 +102,39 @@ $(VENV):
|
|||||||
install: $(VENV)
|
install: $(VENV)
|
||||||
$(PIP) install -r requirements.txt
|
$(PIP) install -r requirements.txt
|
||||||
|
|
||||||
.PHONY: tpu.setup
|
.PHONY: train
|
||||||
tpu.setup:
|
train: install
|
||||||
@command -v gcloud >/dev/null 2>&1 || (echo "gcloud CLI not found. Install from https://cloud.google.com/sdk/docs/install" && exit 1)
|
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||||
@gcloud auth login --update-adc
|
@$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" \
|
||||||
@gcloud auth application-default login
|
$(PYTHON) -m engine.train $(LOCAL_TRAIN_ARGS)
|
||||||
@gcloud config set project "$(TPU_PROJECT)"
|
|
||||||
|
|
||||||
.PHONY: tpu.check.zone
|
.PHONY: train.agent
|
||||||
tpu.check.zone:
|
train.agent: install
|
||||||
@case "$(TPU_ZONE)" in \
|
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||||
europe-west4-a|us-central2-b|us-central1-a|us-east1-d|europe-west4-b) ;; \
|
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
|
||||||
*) echo "Unsupported TPU_ZONE='$(TPU_ZONE)'. Allowed zones: europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b"; exit 1 ;; \
|
@$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" \
|
||||||
esac
|
$(PYTHON) -m engine.train --sweep-agent --sweep-id "$(SWEEP_ID)" \
|
||||||
|
$(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT),)
|
||||||
|
|
||||||
.PHONY: tpu.create.v4.ondemand
|
.PHONY: train.bootstrap
|
||||||
tpu.create.v4.ondemand:
|
train.bootstrap:
|
||||||
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=tpu-network
|
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||||
|
@$(SWEEP_ENV_LOAD); test -n "$$GITHUB_TOKEN" || (echo "GITHUB_TOKEN required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||||
.PHONY: tpu.create.v4.spot
|
@test -n "$(REPO_URL)" || (echo "REPO_URL required, e.g. REPO_URL=https://github.com/org/repo.git" && exit 1)
|
||||||
tpu.create.v4.spot:
|
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
|
||||||
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=tpu-network
|
@$(SWEEP_ENV_LOAD); \
|
||||||
|
WANDB_API_KEY="$$WANDB_API_KEY" \
|
||||||
.PHONY: tpu.create
|
WANDB_ENTITY="$(WANDB_ENTITY)" \
|
||||||
tpu.create: tpu.check.zone
|
WANDB_PROJECT="$(WANDB_PROJECT)" \
|
||||||
@if gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" >/dev/null 2>&1; then \
|
GITHUB_TOKEN="$$GITHUB_TOKEN" \
|
||||||
STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)'); \
|
REPO_URL="$(REPO_URL)" \
|
||||||
echo "TPU VM $(TPU_NAME) already exists in $(TPU_ZONE) with state=$$STATE, skipping create"; \
|
BRANCH="$(BRANCH)" \
|
||||||
else \
|
WORKDIR="$(WORKDIR)" \
|
||||||
$(TPU_CREATE_CMD); \
|
SWEEP_ID="$(SWEEP_ID)" \
|
||||||
fi
|
AGENT_COUNT="$(AGENT_COUNT)" \
|
||||||
|
AGENT_LOOP="$(AGENT_LOOP)" \
|
||||||
.PHONY: tpu.ensure
|
RETRY_SECONDS="$(RETRY_SECONDS)" \
|
||||||
tpu.ensure: tpu.check.zone
|
bash scripts/wandb_agent_bootstrap.sh
|
||||||
@set -e; \
|
|
||||||
STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)' 2>/dev/null || true); \
|
|
||||||
if [ -z "$$STATE" ]; then \
|
|
||||||
echo "TPU VM $(TPU_NAME) not found in $(TPU_ZONE), creating"; \
|
|
||||||
$(TPU_CREATE_CMD); \
|
|
||||||
elif [ "$$STATE" = "READY" ]; then \
|
|
||||||
echo "TPU VM $(TPU_NAME) is READY"; \
|
|
||||||
elif [ "$$STATE" = "PREEMPTED" ] || [ "$$STATE" = "TERMINATED" ] || [ "$$STATE" = "FAILED" ]; then \
|
|
||||||
echo "TPU VM $(TPU_NAME) is in terminal state $$STATE, recreating"; \
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet || true; \
|
|
||||||
$(TPU_CREATE_CMD); \
|
|
||||||
else \
|
|
||||||
echo "TPU VM $(TPU_NAME) is in state $$STATE; wait or recreate manually"; \
|
|
||||||
exit 1; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
.PHONY: tpu.status
|
|
||||||
tpu.status:
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)"
|
|
||||||
|
|
||||||
.PHONY: tpu.ssh
|
|
||||||
tpu.ssh:
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)"
|
|
||||||
|
|
||||||
.PHONY: tpu.prepare
|
|
||||||
tpu.prepare: tpu.ensure
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command "mkdir -p $(TPU_WORKDIR)"
|
|
||||||
|
|
||||||
.PHONY: tpu.deploy
|
|
||||||
tpu.deploy: tpu.prepare
|
|
||||||
@for p in $(TPU_SYNC_PATHS); do \
|
|
||||||
if [ ! -e "$$p" ]; then continue; fi; \
|
|
||||||
if [ -d "$$p" ]; then \
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp --recurse "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
|
|
||||||
else \
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
|
|
||||||
fi; \
|
|
||||||
done
|
|
||||||
|
|
||||||
.PHONY: tpu.install
|
|
||||||
tpu.install: tpu.ensure
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && PYBIN=$$(command -v python3.11 || command -v python3.10 || command -v python3) && $$PYBIN -m venv $(TPU_VENV) && $(TPU_VENV)/bin/pip install --upgrade pip setuptools wheel && $(TPU_VENV)/bin/pip install -r requirements.txt && $(TPU_VENV)/bin/pip install -r engine/jax/requirements.txt && $(TPU_VENV)/bin/pip install "jax[tpu]" -f $(TPU_JAX_WHEEL_URL)'
|
|
||||||
|
|
||||||
.PHONY: tpu.check.remote
|
|
||||||
tpu.check.remote: tpu.ensure
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'set -e; mkdir -p $(TPU_WORKDIR); cd $(TPU_WORKDIR); test -f engine/train.py || (echo "Missing code on TPU VM. Run: make tpu.deploy" && exit 2); test -x $(TPU_VENV)/bin/python || (echo "Missing TPU venv. Run: make tpu.install" && exit 3)'
|
|
||||||
|
|
||||||
.PHONY: tpu.train
|
|
||||||
tpu.train: tpu.check.remote
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && if [ -f .env ]; then set -a && . ./.env && set +a; fi && $(TPU_TRAIN_ENV) $(TPU_VENV)/bin/python -m engine.train $(TPU_TRAIN_ARGS)'
|
|
||||||
|
|
||||||
.PHONY: tpu.bootstrap
|
|
||||||
tpu.bootstrap: tpu.ensure tpu.deploy tpu.install
|
|
||||||
|
|
||||||
.PHONY: tpu.delete
|
|
||||||
tpu.delete:
|
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.sweep
|
|
||||||
tpu.queue.sweep:
|
|
||||||
@set -e; \
|
|
||||||
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
|
|
||||||
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
|
|
||||||
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
|
|
||||||
if ! ssh-add -l >/dev/null 2>&1; then \
|
|
||||||
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
|
|
||||||
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
|
|
||||||
fi; \
|
|
||||||
fi; \
|
|
||||||
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_QUEUE_FILTER_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_FILTER_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.worker
|
|
||||||
tpu.queue.worker:
|
|
||||||
@set -e; \
|
|
||||||
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
|
|
||||||
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
|
|
||||||
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
|
|
||||||
if ! ssh-add -l >/dev/null 2>&1; then \
|
|
||||||
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
|
|
||||||
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
|
|
||||||
fi; \
|
|
||||||
fi; \
|
|
||||||
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.sweep.docker
|
|
||||||
tpu.queue.sweep.docker:
|
|
||||||
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
|
||||||
@$(MAKE) tpu.queue.sweep TPU_QUEUE_EXECUTION_MODE=docker
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.worker.docker
|
|
||||||
tpu.queue.worker.docker:
|
|
||||||
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
|
||||||
@$(MAKE) tpu.queue.worker TPU_QUEUE_EXECUTION_MODE=docker
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.docker.build
|
|
||||||
tpu.queue.docker.build:
|
|
||||||
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
|
||||||
docker build -f docker/TPUSweep.Dockerfile -t "$(TPU_QUEUE_DOCKER_IMAGE)" .
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.docker.push
|
|
||||||
tpu.queue.docker.push:
|
|
||||||
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
|
||||||
docker push "$(TPU_QUEUE_DOCKER_IMAGE)"
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.status
|
|
||||||
tpu.queue.status:
|
|
||||||
@set -e; \
|
|
||||||
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
|
|
||||||
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
|
|
||||||
else \
|
|
||||||
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
|
|
||||||
fi; \
|
|
||||||
for ZONE in $(TPU_QUEUE_ZONES); do \
|
|
||||||
echo "--- $$ZONE ---"; \
|
|
||||||
if ! $$QCMD list --zone="$$ZONE"; then \
|
|
||||||
echo "Skipping $$ZONE (unavailable or no permission)"; \
|
|
||||||
fi; \
|
|
||||||
done
|
|
||||||
|
|
||||||
.PHONY: tpu.queue.clean
|
|
||||||
tpu.queue.clean:
|
|
||||||
@set -e; \
|
|
||||||
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
|
|
||||||
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
|
|
||||||
else \
|
|
||||||
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
|
|
||||||
fi; \
|
|
||||||
for ZONE in $(TPU_QUEUE_ZONES); do \
|
|
||||||
$$QCMD list --zone="$$ZONE" --format='value(name)' 2>/dev/null | while read -r NAME; do \
|
|
||||||
case "$$NAME" in \
|
|
||||||
qr-*) echo "Deleting $$NAME ($$ZONE)"; $$QCMD delete "$$NAME" --zone="$$ZONE" --quiet ;; \
|
|
||||||
esac; \
|
|
||||||
done; \
|
|
||||||
done
|
|
||||||
|
|
||||||
.PHONY: stats.lines
|
.PHONY: stats.lines
|
||||||
stats.lines:
|
stats.lines:
|
||||||
@@ -299,6 +152,24 @@ wordcount:
|
|||||||
$(SRCDIR)/chapters/05-discussion.tex \
|
$(SRCDIR)/chapters/05-discussion.tex \
|
||||||
$(SRCDIR)/chapters/06-conclusion.tex
|
$(SRCDIR)/chapters/06-conclusion.tex
|
||||||
|
|
||||||
|
.PHONY: docker.train.publish
|
||||||
|
docker.train.publish:
|
||||||
|
docker build -f docker/Trainer.dockerfile --target gpu -t $(TRAIN_IMAGE_REF):gpu-latest .
|
||||||
|
docker push $(TRAIN_IMAGE_REF):gpu-latest
|
||||||
|
docker build -f docker/Trainer.dockerfile --target tpu -t $(TRAIN_IMAGE_REF):tpu-latest .
|
||||||
|
docker push $(TRAIN_IMAGE_REF):tpu-latest
|
||||||
|
|
||||||
|
.PHONY: train.tpu.pod
|
||||||
|
train.tpu.pod:
|
||||||
|
@test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1)
|
||||||
|
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
|
||||||
|
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||||
|
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \
|
||||||
|
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all
|
||||||
|
@$(SWEEP_ENV_LOAD); \
|
||||||
|
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
|
||||||
|
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all \
|
||||||
|
--command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh"
|
||||||
|
|
||||||
.PHONY: pdf clean watch run.webapp test count-lines all
|
.PHONY: pdf clean watch run.webapp test count-lines all
|
||||||
pdf: pdf.build
|
pdf: pdf.build
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ COPY engine /app/engine
|
|||||||
ENV PYTHONPATH=/app \
|
ENV PYTHONPATH=/app \
|
||||||
PHANTOM_USE_JAX=1 \
|
PHANTOM_USE_JAX=1 \
|
||||||
PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
|
PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
|
||||||
JAX_PLATFORMS=tpu,cpu \
|
|
||||||
XLA_PYTHON_CLIENT_PREALLOCATE=false
|
XLA_PYTHON_CLIENT_PREALLOCATE=false
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
||||||
|
|||||||
@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
|
|||||||
n_states: int,
|
n_states: int,
|
||||||
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
||||||
k_actor, k_product, k_step = jax.random.split(key, 3)
|
k_actor, k_product, k_step = jax.random.split(key, 3)
|
||||||
|
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
|
||||||
|
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
|
||||||
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
|
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
|
||||||
actors = (actor_draw < alpha).astype(jnp.int32)
|
actors = (actor_draw < alpha).astype(jnp.int32)
|
||||||
products = jax.random.randint(
|
products = jax.random.randint(
|
||||||
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
|
|||||||
)
|
)
|
||||||
|
|
||||||
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
|
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
|
||||||
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
|
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
|
||||||
|
|
||||||
def _scan_step(carry, _):
|
def _scan_step(carry, _):
|
||||||
states, active, rng = carry
|
states, active, rng = carry
|
||||||
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
|
|||||||
probs_a = agent_T[states]
|
probs_a = agent_T[states]
|
||||||
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
|
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
|
||||||
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
|
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
|
||||||
next_state = jnp.where(active, next_state, int(term_idx))
|
next_state = jnp.where(active, next_state, term_idx_i32)
|
||||||
emitted = jnp.where(active, next_state, -1)
|
emitted = jnp.where(active, next_state, -1)
|
||||||
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
|
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
|
||||||
next_active = active & (~is_terminal)
|
next_active = active & (~is_terminal)
|
||||||
carry_states = jnp.where(next_active, next_state, int(term_idx))
|
carry_states = jnp.where(next_active, next_state, term_idx_i32)
|
||||||
return (carry_states, next_active, rng), emitted
|
return (carry_states, next_active, rng), emitted
|
||||||
|
|
||||||
_, state_t = jax.lax.scan(
|
_, state_t = jax.lax.scan(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
flax>=0.8.0
|
flax==0.10.7
|
||||||
optax>=0.2.0
|
optax==0.2.7
|
||||||
distrax>=0.1.5
|
distrax==0.1.5
|
||||||
orbax-checkpoint>=0.5.0
|
orbax-checkpoint==0.11.32
|
||||||
chex>=0.1.8
|
chex==0.1.90
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -3,11 +3,16 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parents[2]))
|
sys.path.insert(0, str(Path(__file__).parents[2]))
|
||||||
|
|
||||||
from sim.rl.behavior_loader.models import (
|
try:
|
||||||
BehaviorModel,
|
from sim.rl.behavior_loader.models import (
|
||||||
AgentBehaviorModel,
|
BehaviorModel,
|
||||||
aggregate_event_transitions,
|
AgentBehaviorModel,
|
||||||
)
|
aggregate_event_transitions,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
BehaviorModel = None
|
||||||
|
AgentBehaviorModel = None
|
||||||
|
aggregate_event_transitions = None
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .demand import generate_demand_for_actor
|
from .demand import generate_demand_for_actor
|
||||||
@@ -20,6 +25,12 @@ _cache = {} # lazy cache for models and base pivots
|
|||||||
|
|
||||||
|
|
||||||
def _get_base_pivot(human: bool):
|
def _get_base_pivot(human: bool):
|
||||||
|
if (
|
||||||
|
BehaviorModel is None
|
||||||
|
or AgentBehaviorModel is None
|
||||||
|
or aggregate_event_transitions is None
|
||||||
|
):
|
||||||
|
raise ImportError("behavior loader dependencies are unavailable")
|
||||||
key = "human" if human else "agent"
|
key = "human" if human else "agent"
|
||||||
if key not in _cache:
|
if key not in _cache:
|
||||||
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
|
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
|
||||||
@@ -34,6 +45,13 @@ def get_transition_models():
|
|||||||
returns:
|
returns:
|
||||||
tuple: (human_transitions, agent_transitions) as dicts of event->event->prob
|
tuple: (human_transitions, agent_transitions) as dicts of event->event->prob
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
BehaviorModel is None
|
||||||
|
or AgentBehaviorModel is None
|
||||||
|
or aggregate_event_transitions is None
|
||||||
|
):
|
||||||
|
raise ImportError("behavior loader dependencies are unavailable")
|
||||||
|
|
||||||
human_model = BehaviorModel(human_dir)
|
human_model = BehaviorModel(human_dir)
|
||||||
agent_model = AgentBehaviorModel(agent_dir)
|
agent_model = AgentBehaviorModel(agent_dir)
|
||||||
|
|
||||||
|
|||||||
@@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict:
|
|||||||
"JAX backend requested but JAX is not installed. "
|
"JAX backend requested but JAX is not installed. "
|
||||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||||
)
|
)
|
||||||
if algo == "qtable":
|
|
||||||
raise ValueError("qtable is not supported in JAX backend")
|
|
||||||
try:
|
try:
|
||||||
from .jax.train import train_jax
|
from .jax.train import train_jax
|
||||||
except Exception as exc: # pragma: no cover
|
except Exception as exc: # pragma: no cover
|
||||||
@@ -409,20 +407,25 @@ def run_wandb(
|
|||||||
init_kwargs = {"mode": mode}
|
init_kwargs = {"mode": mode}
|
||||||
if sweep_mode:
|
if sweep_mode:
|
||||||
run = wandb.init(**init_kwargs)
|
run = wandb.init(**init_kwargs)
|
||||||
cfg = _cfg(_wandb_cfg_dict())
|
|
||||||
for k, v in overrides.items():
|
|
||||||
if k not in wandb.config:
|
|
||||||
cfg[k] = v
|
|
||||||
else:
|
else:
|
||||||
run = wandb.init(project=project, config=overrides, **init_kwargs)
|
run = wandb.init(project=project, config=overrides, **init_kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
cfg = _cfg(_wandb_cfg_dict())
|
cfg = _cfg(_wandb_cfg_dict())
|
||||||
metrics = train_once(cfg)
|
if sweep_mode:
|
||||||
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
for k, v in overrides.items():
|
||||||
wandb.log(metrics, step=step)
|
if k not in wandb.config:
|
||||||
for k, v in metrics.items():
|
cfg[k] = v
|
||||||
run.summary[k] = v
|
|
||||||
wandb.finish()
|
metrics = train_once(cfg)
|
||||||
return metrics
|
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
||||||
|
wandb.log(metrics, step=step)
|
||||||
|
for k, v in metrics.items():
|
||||||
|
run.summary[k] = v
|
||||||
|
return metrics
|
||||||
|
finally:
|
||||||
|
if wandb.run is not None:
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
|
|
||||||
def run_local(overrides: dict) -> dict:
|
def run_local(overrides: dict) -> dict:
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ The architecture of this platform begins with the deployed web-apps posting inte
|
|||||||
|
|
||||||
\paragraph{Public Web Artifact} We transition the Kappa-like architecture of the data collection to a Lambda system for actual learning in a surrogate environment. This allows us to move faster on data which is provided and helps us create a feedback loop for production deployment. To support further research in this intersection of fields we release P4P \footnote{\url{https://github.com/velocitatem/p4p}} as a public repository providing the interaction layer of the PHANTOM framework. This provides a configurable storefront which can be tailored to any commercial setting with standardized session-level event tracking. We document the API adapters or what the framework expects in terms of schemas for pricing providers and log ingestion services. The repository is intended for controlled experimentation and method replication rather than production commerce deployment.
|
\paragraph{Public Web Artifact} We transition the Kappa-like architecture of the data collection to a Lambda system for actual learning in a surrogate environment. This allows us to move faster on data which is provided and helps us create a feedback loop for production deployment. To support further research in this intersection of fields we release P4P \footnote{\url{https://github.com/velocitatem/p4p}} as a public repository providing the interaction layer of the PHANTOM framework. This provides a configurable storefront which can be tailored to any commercial setting with standardized session-level event tracking. We document the API adapters or what the framework expects in terms of schemas for pricing providers and log ingestion services. The repository is intended for controlled experimentation and method replication rather than production commerce deployment.
|
||||||
|
|
||||||
|
|
||||||
\subsubsection{DevOps Principles}
|
\subsubsection{DevOps Principles}
|
||||||
|
|
||||||
Reproducible results are key to quality research platforms; this is kept in mind when deploying and working with our research platform. From a deployment standpoint, the platform can be deployed across a large variety of providers and can be run locally. When developing a new interaction modality apart from the ones that come out of the box, a simple template pattern can be followed. The middleware of the framework is designed to properly render the chosen modality from environment variables, thus deployment of different or parallel versions of the software can be easily parametrized.
|
Reproducible results are key to quality research platforms; this is kept in mind when deploying and working with our research platform. From a deployment standpoint, the platform can be deployed across a large variety of providers and can be run locally. When developing a new interaction modality apart from the ones that come out of the box, a simple template pattern can be followed. The middleware of the framework is designed to properly render the chosen modality from environment variables, thus deployment of different or parallel versions of the software can be easily parametrized.
|
||||||
@@ -235,7 +236,11 @@ v4 & 64 (32 + 32) & us-central2-b & 32 Spot + 32 On-demand \\
|
|||||||
\end{tabular}
|
\end{tabular}
|
||||||
\end{table}
|
\end{table}
|
||||||
|
|
||||||
For interactive monitoring from Madrid, we prioritize the europe-west4 allocation for latency-sensitive runs. All sweep metadata, model checkpoints, and reward traces are logged in Weights \& Biases. Hardware specifications are from the official Google Cloud TPU documentation \parencite{noauthor_tpu_2026,noauthor_tpu_2025-1,noauthor_tpu_2025}.
|
For connections from Madrid, we prioritize the europe-west4 allocation for latency-sensitive runs with the benefit of having the most grouped chips within a single region. This regional grouping is important for the deployment of our Kubernetes cluster which cannot span multiple regions. All sweep metadata, model checkpoints, and reward traces are logged in Weights \& Biases. Hardware specifications are from the official Google Cloud TPU documentation \parencite{noauthor_tpu_2026,noauthor_tpu_2025-1,noauthor_tpu_2025}.
|
||||||
|
|
||||||
|
Design of training processes: we build the Docker image with layer caching in mind; to speed up Docker rebuilds as much as possible, we place the most volatile steps towards the end of the image build. What this means in practice is that any dependency installations are isolated, so edits to source code do not trigger rebuilds. Only if we update the entry point for training a sweep will Docker also rebuild the source-code copy stage.
|
||||||
|
|
||||||
|
Due to the preemptive nature of the current demand for TPU chips, we settle for running on our on-demand allocation as the primary source of compute. The on-demand TPU pod of 32 chips spread across 4 virtual hosts creates a relatively unique parallelization setup. Despite our desire to use a traditional clustering approach, perhaps deploying our sweep agent as SLURM jobs, the lack of predictability in provisioning each instance of a compute resource makes this a high-friction layer we do not want to add.
|
||||||
|
|
||||||
\subsubsection{Interaction Schema}
|
\subsubsection{Interaction Schema}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user