refactored training approaches

This commit is contained in:
2026-02-19 18:23:08 +01:00
parent 5912062dc0
commit 1a9901f118
8 changed files with 947 additions and 308 deletions

291
Makefile
View File

@@ -8,57 +8,44 @@ VENV := .venv
PYTHON := $(VENV)/bin/python PYTHON := $(VENV)/bin/python
PIP := $(VENV)/bin/pip PIP := $(VENV)/bin/pip
PYTEST := $(VENV)/bin/pytest PYTEST := $(VENV)/bin/pytest
TPU_NAME ?= phantom-tpu
TPU_ZONE ?= us-central2-b SWEEP_ENV_FILE ?= .env.sweep
TPU_TYPE ?= v4-32
TPU_RUNTIME ?= tpu-vm-v4-base WANDB_ENTITY ?=
TPU_PROJECT ?= phantom-trc WANDB_PROJECT ?= phantom-pricing
TPU_NETWORK ?= tpu-network
TPU_SUBNETWORK ?= tpu-network
TPU_USE_SPOT ?= 0
TPU_EXTRA_CREATE_FLAGS ?=
TPU_WORKDIR ?= ~/PHANTOM
TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html
TPU_VENV ?= .venv-tpu
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=online
SWEEP_ID ?= SWEEP_ID ?=
SWEEP_COUNT ?= 5 LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000
QUEUE_SCRIPT ?= scripts/queue_sweep.sh AGENT_COUNT ?= 0
TPU_QUEUE_TYPE ?=
TPU_QUEUE_ZONES ?= europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b REPO_URL ?=
TPU_QUEUE_REUSE_EXISTING ?= 1 BRANCH ?= main
TPU_QUEUE_KEEP_ALIVE ?= 1 WORKDIR ?= $(HOME)/PHANTOM-agent
TPU_QUEUE_STRICT_QUOTA ?= 0 AGENT_LOOP ?= 1
TPU_QUEUE_DOWNSHIFT_ON_QUOTA ?= 1 RETRY_SECONDS ?= 20
TPU_QUEUE_FILTER_ZONE ?=
TPU_QUEUE_FILTER_TYPE ?= TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
TPU_QUEUE_EXECUTION_MODE ?= venv TPU_NAME ?=
TPU_QUEUE_SYNC_METHOD ?= tar TPU_ZONE ?= us-central2-b
TPU_QUEUE_SKIP_SYNC ?= 0
TPU_QUEUE_DOCKER_IMAGE ?= SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
TPU_QUEUE_DOCKER_PULL ?= 1
TPU_QUEUE_DOCKER_AUTO_INSTALL ?= 1
TPU_QUEUE_SSH_BATCH_MODE ?= 1
TPU_QUEUE_SSH_CONNECT_TIMEOUT ?= 12
TPU_QUEUE_SSH_KEY_FILE ?= $(HOME)/.ssh/google_compute_engine
TPU_QUEUE_REQUIRE_SSH_AGENT ?= 1
TPU_QUEUE_AUTO_SSH_ADD ?= 1
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
.DEFAULT_GOAL := help .DEFAULT_GOAL := help
.PHONY: help .PHONY: help
help: help:
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.* | tpu.queue.*" @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines"
@echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot" @echo "docker.train.publish"
@echo "Queued sweep: SWEEP_ID=entity/project/id make tpu.queue.sweep" @echo ""
@echo "Queued sweep filters: TPU_QUEUE_FILTER_TYPE=v6e TPU_QUEUE_FILTER_ZONE=europe-west4-a" @echo "Local wandb run:"
@echo "Docker queue: make tpu.queue.sweep.docker TPU_QUEUE_DOCKER_IMAGE=gcr.io/<project>/<image>:tag" @echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'"
@echo "Docker queue without sync: add TPU_QUEUE_SKIP_SYNC=1" @echo ""
@echo "If SSH key is encrypted: run ssh-add ~/.ssh/google_compute_engine first" @echo "Local sweep agent from this repo:"
@echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5"
@echo ""
@echo "Bootstrap private repo worker from anywhere:"
@echo " make train.bootstrap REPO_URL=https://github.com/org/repo.git BRANCH=main SWEEP_ID=entity/project/id"
@echo ""
@echo "Config source: $(SWEEP_ENV_FILE) (auto-loaded)"
$(BUILDDIR): $(BUILDDIR):
mkdir -p paper/$(BUILDDIR) mkdir -p paper/$(BUILDDIR)
@@ -115,173 +102,39 @@ $(VENV):
install: $(VENV) install: $(VENV)
$(PIP) install -r requirements.txt $(PIP) install -r requirements.txt
.PHONY: tpu.setup .PHONY: train
tpu.setup: train: install
@command -v gcloud >/dev/null 2>&1 || (echo "gcloud CLI not found. Install from https://cloud.google.com/sdk/docs/install" && exit 1) @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
@gcloud auth login --update-adc @$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" \
@gcloud auth application-default login $(PYTHON) -m engine.train $(LOCAL_TRAIN_ARGS)
@gcloud config set project "$(TPU_PROJECT)"
.PHONY: tpu.check.zone .PHONY: train.agent
tpu.check.zone: train.agent: install
@case "$(TPU_ZONE)" in \ @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
europe-west4-a|us-central2-b|us-central1-a|us-east1-d|europe-west4-b) ;; \ @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
*) echo "Unsupported TPU_ZONE='$(TPU_ZONE)'. Allowed zones: europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b"; exit 1 ;; \ @$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" \
esac $(PYTHON) -m engine.train --sweep-agent --sweep-id "$(SWEEP_ID)" \
$(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT),)
.PHONY: tpu.create.v4.ondemand .PHONY: train.bootstrap
tpu.create.v4.ondemand: train.bootstrap:
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=tpu-network @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
@$(SWEEP_ENV_LOAD); test -n "$$GITHUB_TOKEN" || (echo "GITHUB_TOKEN required — set it in $(SWEEP_ENV_FILE)" && exit 1)
.PHONY: tpu.create.v4.spot @test -n "$(REPO_URL)" || (echo "REPO_URL required, e.g. REPO_URL=https://github.com/org/repo.git" && exit 1)
tpu.create.v4.spot: @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=tpu-network @$(SWEEP_ENV_LOAD); \
WANDB_API_KEY="$$WANDB_API_KEY" \
.PHONY: tpu.create WANDB_ENTITY="$(WANDB_ENTITY)" \
tpu.create: tpu.check.zone WANDB_PROJECT="$(WANDB_PROJECT)" \
@if gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" >/dev/null 2>&1; then \ GITHUB_TOKEN="$$GITHUB_TOKEN" \
STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)'); \ REPO_URL="$(REPO_URL)" \
echo "TPU VM $(TPU_NAME) already exists in $(TPU_ZONE) with state=$$STATE, skipping create"; \ BRANCH="$(BRANCH)" \
else \ WORKDIR="$(WORKDIR)" \
$(TPU_CREATE_CMD); \ SWEEP_ID="$(SWEEP_ID)" \
fi AGENT_COUNT="$(AGENT_COUNT)" \
AGENT_LOOP="$(AGENT_LOOP)" \
.PHONY: tpu.ensure RETRY_SECONDS="$(RETRY_SECONDS)" \
tpu.ensure: tpu.check.zone bash scripts/wandb_agent_bootstrap.sh
@set -e; \
STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)' 2>/dev/null || true); \
if [ -z "$$STATE" ]; then \
echo "TPU VM $(TPU_NAME) not found in $(TPU_ZONE), creating"; \
$(TPU_CREATE_CMD); \
elif [ "$$STATE" = "READY" ]; then \
echo "TPU VM $(TPU_NAME) is READY"; \
elif [ "$$STATE" = "PREEMPTED" ] || [ "$$STATE" = "TERMINATED" ] || [ "$$STATE" = "FAILED" ]; then \
echo "TPU VM $(TPU_NAME) is in terminal state $$STATE, recreating"; \
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet || true; \
$(TPU_CREATE_CMD); \
else \
echo "TPU VM $(TPU_NAME) is in state $$STATE; wait or recreate manually"; \
exit 1; \
fi
.PHONY: tpu.status
tpu.status:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)"
.PHONY: tpu.ssh
tpu.ssh:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)"
.PHONY: tpu.prepare
tpu.prepare: tpu.ensure
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command "mkdir -p $(TPU_WORKDIR)"
.PHONY: tpu.deploy
tpu.deploy: tpu.prepare
@for p in $(TPU_SYNC_PATHS); do \
if [ ! -e "$$p" ]; then continue; fi; \
if [ -d "$$p" ]; then \
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp --recurse "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
else \
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
fi; \
done
.PHONY: tpu.install
tpu.install: tpu.ensure
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && PYBIN=$$(command -v python3.11 || command -v python3.10 || command -v python3) && $$PYBIN -m venv $(TPU_VENV) && $(TPU_VENV)/bin/pip install --upgrade pip setuptools wheel && $(TPU_VENV)/bin/pip install -r requirements.txt && $(TPU_VENV)/bin/pip install -r engine/jax/requirements.txt && $(TPU_VENV)/bin/pip install "jax[tpu]" -f $(TPU_JAX_WHEEL_URL)'
.PHONY: tpu.check.remote
tpu.check.remote: tpu.ensure
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'set -e; mkdir -p $(TPU_WORKDIR); cd $(TPU_WORKDIR); test -f engine/train.py || (echo "Missing code on TPU VM. Run: make tpu.deploy" && exit 2); test -x $(TPU_VENV)/bin/python || (echo "Missing TPU venv. Run: make tpu.install" && exit 3)'
.PHONY: tpu.train
tpu.train: tpu.check.remote
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && if [ -f .env ]; then set -a && . ./.env && set +a; fi && $(TPU_TRAIN_ENV) $(TPU_VENV)/bin/python -m engine.train $(TPU_TRAIN_ARGS)'
.PHONY: tpu.bootstrap
tpu.bootstrap: tpu.ensure tpu.deploy tpu.install
.PHONY: tpu.delete
tpu.delete:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet
.PHONY: tpu.queue.sweep
tpu.queue.sweep:
@set -e; \
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
if ! ssh-add -l >/dev/null 2>&1; then \
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
fi; \
fi; \
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_QUEUE_FILTER_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_FILTER_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
.PHONY: tpu.queue.worker
tpu.queue.worker:
@set -e; \
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
if ! ssh-add -l >/dev/null 2>&1; then \
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
fi; \
fi; \
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
.PHONY: tpu.queue.sweep.docker
tpu.queue.sweep.docker:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
@$(MAKE) tpu.queue.sweep TPU_QUEUE_EXECUTION_MODE=docker
.PHONY: tpu.queue.worker.docker
tpu.queue.worker.docker:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
@$(MAKE) tpu.queue.worker TPU_QUEUE_EXECUTION_MODE=docker
.PHONY: tpu.queue.docker.build
tpu.queue.docker.build:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
docker build -f docker/TPUSweep.Dockerfile -t "$(TPU_QUEUE_DOCKER_IMAGE)" .
.PHONY: tpu.queue.docker.push
tpu.queue.docker.push:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
docker push "$(TPU_QUEUE_DOCKER_IMAGE)"
.PHONY: tpu.queue.status
tpu.queue.status:
@set -e; \
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
else \
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
fi; \
for ZONE in $(TPU_QUEUE_ZONES); do \
echo "--- $$ZONE ---"; \
if ! $$QCMD list --zone="$$ZONE"; then \
echo "Skipping $$ZONE (unavailable or no permission)"; \
fi; \
done
.PHONY: tpu.queue.clean
tpu.queue.clean:
@set -e; \
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
else \
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
fi; \
for ZONE in $(TPU_QUEUE_ZONES); do \
$$QCMD list --zone="$$ZONE" --format='value(name)' 2>/dev/null | while read -r NAME; do \
case "$$NAME" in \
qr-*) echo "Deleting $$NAME ($$ZONE)"; $$QCMD delete "$$NAME" --zone="$$ZONE" --quiet ;; \
esac; \
done; \
done
.PHONY: stats.lines .PHONY: stats.lines
stats.lines: stats.lines:
@@ -299,6 +152,24 @@ wordcount:
$(SRCDIR)/chapters/05-discussion.tex \ $(SRCDIR)/chapters/05-discussion.tex \
$(SRCDIR)/chapters/06-conclusion.tex $(SRCDIR)/chapters/06-conclusion.tex
.PHONY: docker.train.publish
docker.train.publish:
docker build -f docker/Trainer.dockerfile --target gpu -t $(TRAIN_IMAGE_REF):gpu-latest .
docker push $(TRAIN_IMAGE_REF):gpu-latest
docker build -f docker/Trainer.dockerfile --target tpu -t $(TRAIN_IMAGE_REF):tpu-latest .
docker push $(TRAIN_IMAGE_REF):tpu-latest
.PHONY: train.tpu.pod
train.tpu.pod:
@test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1)
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all
@$(SWEEP_ENV_LOAD); \
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all \
--command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh"
.PHONY: pdf clean watch run.webapp test count-lines all .PHONY: pdf clean watch run.webapp test count-lines all
pdf: pdf.build pdf: pdf.build

View File

@@ -37,7 +37,6 @@ COPY engine /app/engine
ENV PYTHONPATH=/app \ ENV PYTHONPATH=/app \
PHANTOM_USE_JAX=1 \ PHANTOM_USE_JAX=1 \
PHANTOM_DEFAULT_AGENT_ARGS="--jax" \ PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
JAX_PLATFORMS=tpu,cpu \
XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_PREALLOCATE=false
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"] ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]

View File

@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
n_states: int, n_states: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
k_actor, k_product, k_step = jax.random.split(key, 3) k_actor, k_product, k_step = jax.random.split(key, 3)
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
actor_draw = jax.random.uniform(k_actor, (n_sessions,)) actor_draw = jax.random.uniform(k_actor, (n_sessions,))
actors = (actor_draw < alpha).astype(jnp.int32) actors = (actor_draw < alpha).astype(jnp.int32)
products = jax.random.randint( products = jax.random.randint(
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
) )
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_) active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32) state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
def _scan_step(carry, _): def _scan_step(carry, _):
states, active, rng = carry states, active, rng = carry
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
probs_a = agent_T[states] probs_a = agent_T[states]
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a) probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1) next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
next_state = jnp.where(active, next_state, int(term_idx)) next_state = jnp.where(active, next_state, term_idx_i32)
emitted = jnp.where(active, next_state, -1) emitted = jnp.where(active, next_state, -1)
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)] is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
next_active = active & (~is_terminal) next_active = active & (~is_terminal)
carry_states = jnp.where(next_active, next_state, int(term_idx)) carry_states = jnp.where(next_active, next_state, term_idx_i32)
return (carry_states, next_active, rng), emitted return (carry_states, next_active, rng), emitted
_, state_t = jax.lax.scan( _, state_t = jax.lax.scan(

View File

@@ -1,5 +1,5 @@
flax>=0.8.0 flax==0.10.7
optax>=0.2.0 optax==0.2.7
distrax>=0.1.5 distrax==0.1.5
orbax-checkpoint>=0.5.0 orbax-checkpoint==0.11.32
chex>=0.1.8 chex==0.1.90

File diff suppressed because it is too large Load Diff

View File

@@ -3,11 +3,16 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).parents[2])) sys.path.insert(0, str(Path(__file__).parents[2]))
from sim.rl.behavior_loader.models import ( try:
BehaviorModel, from sim.rl.behavior_loader.models import (
AgentBehaviorModel, BehaviorModel,
aggregate_event_transitions, AgentBehaviorModel,
) aggregate_event_transitions,
)
except ImportError:
BehaviorModel = None
AgentBehaviorModel = None
aggregate_event_transitions = None
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from .demand import generate_demand_for_actor from .demand import generate_demand_for_actor
@@ -20,6 +25,12 @@ _cache = {} # lazy cache for models and base pivots
def _get_base_pivot(human: bool): def _get_base_pivot(human: bool):
if (
BehaviorModel is None
or AgentBehaviorModel is None
or aggregate_event_transitions is None
):
raise ImportError("behavior loader dependencies are unavailable")
key = "human" if human else "agent" key = "human" if human else "agent"
if key not in _cache: if key not in _cache:
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir) model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
@@ -34,6 +45,13 @@ def get_transition_models():
returns: returns:
tuple: (human_transitions, agent_transitions) as dicts of event->event->prob tuple: (human_transitions, agent_transitions) as dicts of event->event->prob
""" """
if (
BehaviorModel is None
or AgentBehaviorModel is None
or aggregate_event_transitions is None
):
raise ImportError("behavior loader dependencies are unavailable")
human_model = BehaviorModel(human_dir) human_model = BehaviorModel(human_dir)
agent_model = AgentBehaviorModel(agent_dir) agent_model = AgentBehaviorModel(agent_dir)

View File

@@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict:
"JAX backend requested but JAX is not installed. " "JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs." "Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
) )
if algo == "qtable":
raise ValueError("qtable is not supported in JAX backend")
try: try:
from .jax.train import train_jax from .jax.train import train_jax
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
@@ -409,20 +407,25 @@ def run_wandb(
init_kwargs = {"mode": mode} init_kwargs = {"mode": mode}
if sweep_mode: if sweep_mode:
run = wandb.init(**init_kwargs) run = wandb.init(**init_kwargs)
cfg = _cfg(_wandb_cfg_dict())
for k, v in overrides.items():
if k not in wandb.config:
cfg[k] = v
else: else:
run = wandb.init(project=project, config=overrides, **init_kwargs) run = wandb.init(project=project, config=overrides, **init_kwargs)
try:
cfg = _cfg(_wandb_cfg_dict()) cfg = _cfg(_wandb_cfg_dict())
metrics = train_once(cfg) if sweep_mode:
step = int(metrics.get("train/global_step", cfg["total_timesteps"])) for k, v in overrides.items():
wandb.log(metrics, step=step) if k not in wandb.config:
for k, v in metrics.items(): cfg[k] = v
run.summary[k] = v
wandb.finish() metrics = train_once(cfg)
return metrics step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
wandb.log(metrics, step=step)
for k, v in metrics.items():
run.summary[k] = v
return metrics
finally:
if wandb.run is not None:
wandb.finish()
def run_local(overrides: dict) -> dict: def run_local(overrides: dict) -> dict:

View File

@@ -140,6 +140,7 @@ The architecture of this platform begins with the deployed web-apps posting inte
\paragraph{Public Web Artifact} We transition the Kappa like architecture of the data collection to a Lambda system for actual learning in a surrogate environment. This allows us to move faster on data which is provided and helps us create a feedback loop for production deployment. To support further research in this intersection of fields we release P4P \footnote{\url{https://github.com/velocitatem/p4p}} as a public repository providing the interaction layer of the PHANTOM framework. This provides a configurable storefront which can be tailored to any commercial setting with a standardized session-level event tracking. We document the API adapters or what the framework expects in terms of schemas for pricing providers and log ingestion servicse. The repository is intended for controlled experimentation and method replication rather than production commerce deployment. \paragraph{Public Web Artifact} We transition the Kappa like architecture of the data collection to a Lambda system for actual learning in a surrogate environment. This allows us to move faster on data which is provided and helps us create a feedback loop for production deployment. To support further research in this intersection of fields we release P4P \footnote{\url{https://github.com/velocitatem/p4p}} as a public repository providing the interaction layer of the PHANTOM framework. This provides a configurable storefront which can be tailored to any commercial setting with a standardized session-level event tracking. We document the API adapters or what the framework expects in terms of schemas for pricing providers and log ingestion servicse. The repository is intended for controlled experimentation and method replication rather than production commerce deployment.
\subsubsection{DevOps Principles} \subsubsection{DevOps Principles}
Reproducible results are key to quality research platforms, this is taken into mind when deploying and working with our research platform. From a deployment standpoint the platform can be deployed across a large variety of providers and can be run locally. When developing a new interaction modality apart from the ones that come out of the box, a simple template pattern can be followed. The middleware of the framework is designed to properly render the chosen modality from environmental variables, thus deployment of different or parallel version of the software can be easily parametrized. Reproducible results are key to quality research platforms, this is taken into mind when deploying and working with our research platform. From a deployment standpoint the platform can be deployed across a large variety of providers and can be run locally. When developing a new interaction modality apart from the ones that come out of the box, a simple template pattern can be followed. The middleware of the framework is designed to properly render the chosen modality from environmental variables, thus deployment of different or parallel version of the software can be easily parametrized.
@@ -235,7 +236,11 @@ v4 & 64 (32 + 32) & us-central2-b & 32 Spot + 32 On-demand \\
\end{tabular} \end{tabular}
\end{table} \end{table}
For interactive monitoring from Madrid, we prioritize the europe-west4 allocation for latency-sensitive runs. All sweep metadata, model checkpoints, and reward traces are logged in Weights \& Biases. Hardware specifications are from the official Google Cloud TPU documentation \parencite{noauthor_tpu_2026,noauthor_tpu_2025-1,noauthor_tpu_2025}. For connections from Madrid, we prioritize the europe-west4 allocation for latency-sensitive runs with the benefit of having the most grouped chips within a single region. This regional grouping is important for the deployment of our Kubernetes cluster which cannot span multiple regions. All sweep metadata, model checkpoints, and reward traces are logged in Weights \& Biases. Hardware specifications are from the official Google Cloud TPU documentation \parencite{noauthor_tpu_2026,noauthor_tpu_2025-1,noauthor_tpu_2025}.
Design of training processes: we build docker image with the fact in mind of different caching over layers in order to most speed up docker re-building and such we place the most volatile steps towards the end of the image building. What is means in practice is that any dependency installations are isolated so edits to source code do no trigger rebuilds. Only if we update our entry point of training a sweep, Docker will also rebuild the source-code copy stage.
Due to the preemptive nature of the current demand of TPU chips we sttle for running our on demeaned as the primary source of compute. The on demand TPU pod of 32 chips spread across 4 virtual hosts creates a relatively unique parallelization setup. Despite our desire to use a traditional approach of clustering and perhaps deploying SLURM jobs of our sweep agent, the lack of predictability in provisioning each instance of a compute resource makes this an high friction layer we do not want to add.
\subsubsection{Interaction Schema} \subsubsection{Interaction Schema}