diff --git a/Makefile b/Makefile index 43c55ee..6a24577 100644 --- a/Makefile +++ b/Makefile @@ -8,57 +8,44 @@ VENV := .venv PYTHON := $(VENV)/bin/python PIP := $(VENV)/bin/pip PYTEST := $(VENV)/bin/pytest -TPU_NAME ?= phantom-tpu -TPU_ZONE ?= us-central2-b -TPU_TYPE ?= v4-32 -TPU_RUNTIME ?= tpu-vm-v4-base -TPU_PROJECT ?= phantom-trc -TPU_NETWORK ?= tpu-network -TPU_SUBNETWORK ?= tpu-network -TPU_USE_SPOT ?= 0 -TPU_EXTRA_CREATE_FLAGS ?= -TPU_WORKDIR ?= ~/PHANTOM -TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env -TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000 -TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html -TPU_VENV ?= .venv-tpu -TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=online + +SWEEP_ENV_FILE ?= .env.sweep + +WANDB_ENTITY ?= +WANDB_PROJECT ?= phantom-pricing SWEEP_ID ?= -SWEEP_COUNT ?= 5 -QUEUE_SCRIPT ?= scripts/queue_sweep.sh -TPU_QUEUE_TYPE ?= -TPU_QUEUE_ZONES ?= europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b -TPU_QUEUE_REUSE_EXISTING ?= 1 -TPU_QUEUE_KEEP_ALIVE ?= 1 -TPU_QUEUE_STRICT_QUOTA ?= 0 -TPU_QUEUE_DOWNSHIFT_ON_QUOTA ?= 1 -TPU_QUEUE_FILTER_ZONE ?= -TPU_QUEUE_FILTER_TYPE ?= -TPU_QUEUE_EXECUTION_MODE ?= venv -TPU_QUEUE_SYNC_METHOD ?= tar -TPU_QUEUE_SKIP_SYNC ?= 0 -TPU_QUEUE_DOCKER_IMAGE ?= -TPU_QUEUE_DOCKER_PULL ?= 1 -TPU_QUEUE_DOCKER_AUTO_INSTALL ?= 1 -TPU_QUEUE_SSH_BATCH_MODE ?= 1 -TPU_QUEUE_SSH_CONNECT_TIMEOUT ?= 12 -TPU_QUEUE_SSH_KEY_FILE ?= $(HOME)/.ssh/google_compute_engine -TPU_QUEUE_REQUIRE_SSH_AGENT ?= 1 -TPU_QUEUE_AUTO_SSH_ADD ?= 1 -TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,) -TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS) +LOCAL_TRAIN_ARGS ?= --algo ppo --total-timesteps 50000 +AGENT_COUNT ?= 0 + 
+REPO_URL ?= +BRANCH ?= main +WORKDIR ?= $(HOME)/PHANTOM-agent +AGENT_LOOP ?= 1 +RETRY_SECONDS ?= 20 + +TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer +TPU_NAME ?= +TPU_ZONE ?= us-central2-b + +SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a .DEFAULT_GOAL := help .PHONY: help help: - @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.* | tpu.queue.*" - @echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot" - @echo "Queued sweep: SWEEP_ID=entity/project/id make tpu.queue.sweep" - @echo "Queued sweep filters: TPU_QUEUE_FILTER_TYPE=v6e TPU_QUEUE_FILTER_ZONE=europe-west4-a" - @echo "Docker queue: make tpu.queue.sweep.docker TPU_QUEUE_DOCKER_IMAGE=gcr.io//:tag" - @echo "Docker queue without sync: add TPU_QUEUE_SKIP_SYNC=1" - @echo "If SSH key is encrypted: run ssh-add ~/.ssh/google_compute_engine first" + @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines" + @echo "docker.train.publish" + @echo "" + @echo "Local wandb run:" + @echo " make train LOCAL_TRAIN_ARGS='--algo ppo --total-timesteps 50000'" + @echo "" + @echo "Local sweep agent from this repo:" + @echo " make train.agent SWEEP_ID=entity/project/id AGENT_COUNT=5" + @echo "" + @echo "Bootstrap private repo worker from anywhere:" + @echo " make train.bootstrap REPO_URL=https://github.com/org/repo.git BRANCH=main SWEEP_ID=entity/project/id" + @echo "" + @echo "Config source: $(SWEEP_ENV_FILE) (auto-loaded)" $(BUILDDIR): mkdir -p paper/$(BUILDDIR) @@ -115,173 +102,39 @@ $(VENV): install: $(VENV) $(PIP) install -r requirements.txt -.PHONY: tpu.setup -tpu.setup: - @command -v gcloud >/dev/null 2>&1 || (echo "gcloud CLI not found. 
Install from https://cloud.google.com/sdk/docs/install" && exit 1) - @gcloud auth login --update-adc - @gcloud auth application-default login - @gcloud config set project "$(TPU_PROJECT)" +.PHONY: train +train: install + @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1) + @$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" \ + $(PYTHON) -m engine.train $(LOCAL_TRAIN_ARGS) -.PHONY: tpu.check.zone -tpu.check.zone: - @case "$(TPU_ZONE)" in \ - europe-west4-a|us-central2-b|us-central1-a|us-east1-d|europe-west4-b) ;; \ - *) echo "Unsupported TPU_ZONE='$(TPU_ZONE)'. Allowed zones: europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b"; exit 1 ;; \ - esac +.PHONY: train.agent +train.agent: install + @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1) + @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. 
SWEEP_ID=entity/project/id" && exit 1) + @$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" WANDB_ENTITY="$(WANDB_ENTITY)" WANDB_PROJECT="$(WANDB_PROJECT)" \ + $(PYTHON) -m engine.train --sweep-agent --sweep-id "$(SWEEP_ID)" \ + $(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT),) -.PHONY: tpu.create.v4.ondemand -tpu.create.v4.ondemand: - $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=tpu-network - -.PHONY: tpu.create.v4.spot -tpu.create.v4.spot: - $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=tpu-network - -.PHONY: tpu.create -tpu.create: tpu.check.zone - @if gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" >/dev/null 2>&1; then \ - STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)'); \ - echo "TPU VM $(TPU_NAME) already exists in $(TPU_ZONE) with state=$$STATE, skipping create"; \ - else \ - $(TPU_CREATE_CMD); \ - fi - -.PHONY: tpu.ensure -tpu.ensure: tpu.check.zone - @set -e; \ - STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)' 2>/dev/null || true); \ - if [ -z "$$STATE" ]; then \ - echo "TPU VM $(TPU_NAME) not found in $(TPU_ZONE), creating"; \ - $(TPU_CREATE_CMD); \ - elif [ "$$STATE" = "READY" ]; then \ - echo "TPU VM $(TPU_NAME) is READY"; \ - elif [ "$$STATE" = "PREEMPTED" ] || [ "$$STATE" = "TERMINATED" ] || [ "$$STATE" = "FAILED" ]; then \ - echo "TPU VM $(TPU_NAME) is in terminal state $$STATE, recreating"; \ - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet || true; \ - $(TPU_CREATE_CMD); \ - else \ - echo "TPU VM $(TPU_NAME) is in state $$STATE; wait or recreate manually"; \ - exit 1; \ - fi - -.PHONY: tpu.status -tpu.status: - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe 
"$(TPU_NAME)" --zone="$(TPU_ZONE)" - -.PHONY: tpu.ssh -tpu.ssh: - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" - -.PHONY: tpu.prepare -tpu.prepare: tpu.ensure - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command "mkdir -p $(TPU_WORKDIR)" - -.PHONY: tpu.deploy -tpu.deploy: tpu.prepare - @for p in $(TPU_SYNC_PATHS); do \ - if [ ! -e "$$p" ]; then continue; fi; \ - if [ -d "$$p" ]; then \ - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp --recurse "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \ - else \ - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \ - fi; \ - done - -.PHONY: tpu.install -tpu.install: tpu.ensure - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && PYBIN=$$(command -v python3.11 || command -v python3.10 || command -v python3) && $$PYBIN -m venv $(TPU_VENV) && $(TPU_VENV)/bin/pip install --upgrade pip setuptools wheel && $(TPU_VENV)/bin/pip install -r requirements.txt && $(TPU_VENV)/bin/pip install -r engine/jax/requirements.txt && $(TPU_VENV)/bin/pip install "jax[tpu]" -f $(TPU_JAX_WHEEL_URL)' - -.PHONY: tpu.check.remote -tpu.check.remote: tpu.ensure - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'set -e; mkdir -p $(TPU_WORKDIR); cd $(TPU_WORKDIR); test -f engine/train.py || (echo "Missing code on TPU VM. Run: make tpu.deploy" && exit 2); test -x $(TPU_VENV)/bin/python || (echo "Missing TPU venv. Run: make tpu.install" && exit 3)' - -.PHONY: tpu.train -tpu.train: tpu.check.remote - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && if [ -f .env ]; then set -a && . 
./.env && set +a; fi && $(TPU_TRAIN_ENV) $(TPU_VENV)/bin/python -m engine.train $(TPU_TRAIN_ARGS)' - -.PHONY: tpu.bootstrap -tpu.bootstrap: tpu.ensure tpu.deploy tpu.install - -.PHONY: tpu.delete -tpu.delete: - gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet - -.PHONY: tpu.queue.sweep -tpu.queue.sweep: - @set -e; \ - test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \ - test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \ - if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \ - if ! ssh-add -l >/dev/null 2>&1; then \ - if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \ - ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \ - fi; \ - fi; \ - AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_QUEUE_FILTER_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_FILTER_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)" - -.PHONY: tpu.queue.worker -tpu.queue.worker: - @set -e; \ - test -n 
"$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \ - test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \ - if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \ - if ! ssh-add -l >/dev/null 2>&1; then \ - if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \ - ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \ - fi; \ - fi; \ - AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)" - -.PHONY: tpu.queue.sweep.docker -tpu.queue.sweep.docker: - @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) - @$(MAKE) tpu.queue.sweep TPU_QUEUE_EXECUTION_MODE=docker - -.PHONY: tpu.queue.worker.docker -tpu.queue.worker.docker: - @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) - @$(MAKE) tpu.queue.worker TPU_QUEUE_EXECUTION_MODE=docker - -.PHONY: tpu.queue.docker.build 
-tpu.queue.docker.build: - @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) - docker build -f docker/TPUSweep.Dockerfile -t "$(TPU_QUEUE_DOCKER_IMAGE)" . - -.PHONY: tpu.queue.docker.push -tpu.queue.docker.push: - @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) - docker push "$(TPU_QUEUE_DOCKER_IMAGE)" - -.PHONY: tpu.queue.status -tpu.queue.status: - @set -e; \ - if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \ - QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \ - else \ - QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \ - fi; \ - for ZONE in $(TPU_QUEUE_ZONES); do \ - echo "--- $$ZONE ---"; \ - if ! $$QCMD list --zone="$$ZONE"; then \ - echo "Skipping $$ZONE (unavailable or no permission)"; \ - fi; \ - done - -.PHONY: tpu.queue.clean -tpu.queue.clean: - @set -e; \ - if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \ - QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \ - else \ - QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \ - fi; \ - for ZONE in $(TPU_QUEUE_ZONES); do \ - $$QCMD list --zone="$$ZONE" --format='value(name)' 2>/dev/null | while read -r NAME; do \ - case "$$NAME" in \ - qr-*) echo "Deleting $$NAME ($$ZONE)"; $$QCMD delete "$$NAME" --zone="$$ZONE" --quiet ;; \ - esac; \ - done; \ - done +.PHONY: train.bootstrap +train.bootstrap: + @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1) + @$(SWEEP_ENV_LOAD); test -n "$$GITHUB_TOKEN" || (echo "GITHUB_TOKEN required — set it in $(SWEEP_ENV_FILE)" && exit 1) + @test -n "$(REPO_URL)" || (echo "REPO_URL required, e.g. REPO_URL=https://github.com/org/repo.git" && exit 1) + @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. 
SWEEP_ID=entity/project/id" && exit 1) + @$(SWEEP_ENV_LOAD); \ + WANDB_API_KEY="$$WANDB_API_KEY" \ + WANDB_ENTITY="$(WANDB_ENTITY)" \ + WANDB_PROJECT="$(WANDB_PROJECT)" \ + GITHUB_TOKEN="$$GITHUB_TOKEN" \ + REPO_URL="$(REPO_URL)" \ + BRANCH="$(BRANCH)" \ + WORKDIR="$(WORKDIR)" \ + SWEEP_ID="$(SWEEP_ID)" \ + AGENT_COUNT="$(AGENT_COUNT)" \ + AGENT_LOOP="$(AGENT_LOOP)" \ + RETRY_SECONDS="$(RETRY_SECONDS)" \ + bash scripts/wandb_agent_bootstrap.sh .PHONY: stats.lines stats.lines: @@ -299,6 +152,24 @@ wordcount: $(SRCDIR)/chapters/05-discussion.tex \ $(SRCDIR)/chapters/06-conclusion.tex +.PHONY: docker.train.publish +docker.train.publish: + docker build -f docker/Trainer.dockerfile --target gpu -t $(TRAIN_IMAGE_REF):gpu-latest . + docker push $(TRAIN_IMAGE_REF):gpu-latest + docker build -f docker/Trainer.dockerfile --target tpu -t $(TRAIN_IMAGE_REF):tpu-latest . + docker push $(TRAIN_IMAGE_REF):tpu-latest + +.PHONY: train.tpu.pod +train.tpu.pod: + @test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1) + @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. 
SWEEP_ID=entity/project/id" && exit 1) + @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1) + gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \ + --zone=$(TPU_ZONE) --project=phantom-trc --worker=all + @$(SWEEP_ENV_LOAD); \ + gcloud compute tpus tpu-vm ssh $(TPU_NAME) \ + --zone=$(TPU_ZONE) --project=phantom-trc --worker=all \ + --command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh" .PHONY: pdf clean watch run.webapp test count-lines all pdf: pdf.build diff --git a/docker/Trainer.dockerfile b/docker/Trainer.dockerfile index c6776ea..df50fed 100644 --- a/docker/Trainer.dockerfile +++ b/docker/Trainer.dockerfile @@ -37,7 +37,6 @@ COPY engine /app/engine ENV PYTHONPATH=/app \ PHANTOM_USE_JAX=1 \ PHANTOM_DEFAULT_AGENT_ARGS="--jax" \ - JAX_PLATFORMS=tpu,cpu \ XLA_PYTHON_CLIENT_PREALLOCATE=false ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"] diff --git a/engine/jax/primitives.py b/engine/jax/primitives.py index 8de4c2b..37bf326 100644 --- a/engine/jax/primitives.py +++ b/engine/jax/primitives.py @@ -308,6 +308,8 @@ if JAX_AVAILABLE: n_states: int, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: k_actor, k_product, k_step = jax.random.split(key, 3) + start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32) + term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32) actor_draw = jax.random.uniform(k_actor, (n_sessions,)) actors = (actor_draw < alpha).astype(jnp.int32) products = jax.random.randint( @@ -315,7 +317,7 @@ if JAX_AVAILABLE: ) active_init = jnp.ones((n_sessions,), dtype=jnp.bool_) - state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32) + state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32) def _scan_step(carry, _): states, active, rng = carry @@ -324,11 +326,11 @@ if JAX_AVAILABLE: probs_a = agent_T[states] probs = jnp.where(actors[:, None] == 0, 
probs_h, probs_a) next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1) - next_state = jnp.where(active, next_state, int(term_idx)) + next_state = jnp.where(active, next_state, term_idx_i32) emitted = jnp.where(active, next_state, -1) is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)] next_active = active & (~is_terminal) - carry_states = jnp.where(next_active, next_state, int(term_idx)) + carry_states = jnp.where(next_active, next_state, term_idx_i32) return (carry_states, next_active, rng), emitted _, state_t = jax.lax.scan( diff --git a/engine/jax/requirements.txt b/engine/jax/requirements.txt index 42ba457..7bde61c 100644 --- a/engine/jax/requirements.txt +++ b/engine/jax/requirements.txt @@ -1,5 +1,5 @@ -flax>=0.8.0 -optax>=0.2.0 -distrax>=0.1.5 -orbax-checkpoint>=0.5.0 -chex>=0.1.8 +flax==0.10.7 +optax==0.2.7 +distrax==0.1.5 +orbax-checkpoint==0.11.32 +chex==0.1.90 diff --git a/engine/jax/train.py b/engine/jax/train.py index 41678c1..408f9b3 100644 --- a/engine/jax/train.py +++ b/engine/jax/train.py @@ -1,12 +1,34 @@ -"""Pure JAX PPO trainer for the PHANTOM environment.""" +"""Pure JAX trainers for PHANTOM environment.""" from __future__ import annotations from pathlib import Path from typing import Any, NamedTuple +import signal +import threading + import numpy as np +_stop_requested = threading.Event() +_jax_dist_initialized = False + + +def _init_jax_distributed() -> None: + """Initialize JAX distributed if running on a multi-host TPU pod. 
+ Safe to call multiple times; no-op after first successful init or when JAX unavailable.""" + global _jax_dist_initialized + if _jax_dist_initialized: + return + _jax_dist_initialized = True + try: + import jax as _jax + + _jax.distributed.initialize() + except Exception: + pass + + try: import wandb @@ -108,6 +130,33 @@ class ActorCritic(nn.Module): return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1) +class QNetwork(nn.Module): + action_dim: int + activation: str = "relu" + + @nn.compact + def __call__(self, x): + activation_fn = nn.relu if self.activation == "relu" else nn.tanh + x = nn.Dense( + 128, + kernel_init=orthogonal(np.sqrt(2.0)), + bias_init=constant(0.0), + )(x) + x = activation_fn(x) + x = nn.Dense( + 128, + kernel_init=orthogonal(np.sqrt(2.0)), + bias_init=constant(0.0), + )(x) + x = activation_fn(x) + q_values = nn.Dense( + self.action_dim, + kernel_init=orthogonal(1.0), + bias_init=constant(0.0), + )(x) + return q_values + + class Transition(NamedTuple): done: jax.Array action: jax.Array @@ -118,6 +167,24 @@ class Transition(NamedTuple): info: dict[str, jax.Array] +class ReplayBatch(NamedTuple): + obs: jax.Array + actions: jax.Array + rewards: jax.Array + next_obs: jax.Array + dones: jax.Array + + +class ReplayBuffer(NamedTuple): + obs: jax.Array + actions: jax.Array + rewards: jax.Array + next_obs: jax.Array + dones: jax.Array + ptr: jax.Array + size: jax.Array + + def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]: out = { "algo": str(cfg.get("algo", "ppo")).lower(), @@ -133,6 +200,7 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]: "total_timesteps": int(cfg.get("total_timesteps", 50_000)), "eval_episodes": int(cfg.get("eval_episodes", 5)), "model_dir": str(cfg.get("model_dir", "engine/models")), + "log_freq": int(cfg.get("log_freq", 100)), "n_products": int(cfg.get("n_products", 10)), "N": int(cfg.get("N", 100)), "alpha": float(cfg.get("alpha", 0.3)), @@ -156,6 +224,18 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, 
Any]: "update_epochs": int(cfg.get("jax_update_epochs", 4)), "anneal_lr": bool(cfg.get("jax_anneal_lr", True)), "checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)), + "buffer_size": int(cfg.get("buffer_size", 50_000)), + "batch_size": int(cfg.get("batch_size", 256)), + "train_freq": int(cfg.get("train_freq", 1)), + "learning_starts": int(cfg.get("learning_starts", 1_000)), + "target_update_interval": int(cfg.get("target_update_interval", 1_000)), + "exploration_fraction": float(cfg.get("exploration_fraction", 0.2)), + "exploration_final_eps": float(cfg.get("exploration_final_eps", 0.05)), + "eps_start": float(cfg.get("eps_start", 1.0)), + "eps_end": float(cfg.get("eps_end", 0.05)), + "eps_decay": float(cfg.get("eps_decay", 0.9995)), + "q_lr": float(cfg.get("q_lr", 0.1)), + "q_bins": int(cfg.get("q_bins", 6)), } rollout = out["num_envs"] * out["num_steps"] out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1)) @@ -163,15 +243,15 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]: return out -def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array: - mask = done - while mask.ndim < keep.ndim: - mask = mask[..., None] - return jnp.where(mask, reset, keep) +def _scalar(value: Any) -> float: + return float(np.asarray(value)) -def make_train(config: dict[str, Any]): - cfg = _jax_cfg(config) +def _scalar_int(value: Any) -> int: + return int(np.asarray(value)) + + +def _make_env(cfg: dict[str, Any]) -> PHANTOMJAXEnv: env_params = make_env_params( n_products=cfg["n_products"], alpha=cfg["alpha"], @@ -191,7 +271,109 @@ def make_train(config: dict[str, Any]): margin_floor_patience=cfg["margin_floor_patience"], prefer_behavior_data=cfg["prefer_behavior_data"], ) - env = PHANTOMJAXEnv(env_params) + return PHANTOMJAXEnv(env_params) + + +def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array: + mask = done + while mask.ndim < keep.ndim: + mask = mask[..., None] + return 
jnp.where(mask, reset, keep) + + +def _epsilon_by_fraction(step: int, cfg: dict[str, Any]) -> float: + start = float(cfg["eps_start"]) + end = float(cfg["exploration_final_eps"]) + frac = float(cfg["exploration_fraction"]) + total = max(1, int(cfg["total_timesteps"])) + decay_steps = max(1, int(total * frac)) + if step >= decay_steps: + return end + slope = (end - start) / decay_steps + return float(start + slope * step) + + +def _digitize_scalar(value: jax.Array, bins: jax.Array) -> jax.Array: + return jnp.sum(value > bins).astype(jnp.int32) + + +def _encode_qtable_state( + obs: jax.Array, + *, + n_products: int, + demand_bins: jax.Array, + price_bins: jax.Array, +) -> tuple[jax.Array, jax.Array, jax.Array]: + demand = obs[:n_products] + prices = obs[n_products : 2 * n_products] + d_mean = jnp.mean(demand) + d_std = jnp.std(demand) + p_mean = jnp.mean(prices) + return ( + _digitize_scalar(d_mean, demand_bins), + _digitize_scalar(d_std, demand_bins), + _digitize_scalar(p_mean, price_bins), + ) + + +def _init_replay_buffer(capacity: int, obs_dim: int) -> ReplayBuffer: + cap = max(1, int(capacity)) + return ReplayBuffer( + obs=jnp.zeros((cap, obs_dim), dtype=jnp.float32), + actions=jnp.zeros((cap,), dtype=jnp.int32), + rewards=jnp.zeros((cap,), dtype=jnp.float32), + next_obs=jnp.zeros((cap, obs_dim), dtype=jnp.float32), + dones=jnp.zeros((cap,), dtype=jnp.float32), + ptr=jnp.asarray(0, dtype=jnp.int32), + size=jnp.asarray(0, dtype=jnp.int32), + ) + + +def _replay_size(buffer: ReplayBuffer) -> int: + return _scalar_int(buffer.size) + + +def _replay_add( + buffer: ReplayBuffer, + obs: jax.Array, + action: jax.Array, + reward: jax.Array, + next_obs: jax.Array, + done: jax.Array, +) -> ReplayBuffer: + capacity = int(buffer.obs.shape[0]) + idx = buffer.ptr % capacity + return ReplayBuffer( + obs=buffer.obs.at[idx].set(obs.astype(jnp.float32)), + actions=buffer.actions.at[idx].set(action.astype(jnp.int32)), + rewards=buffer.rewards.at[idx].set(reward.astype(jnp.float32)), 
+ next_obs=buffer.next_obs.at[idx].set(next_obs.astype(jnp.float32)), + dones=buffer.dones.at[idx].set(done.astype(jnp.float32)), + ptr=buffer.ptr + 1, + size=jnp.minimum(buffer.size + 1, jnp.asarray(capacity, dtype=jnp.int32)), + ) + + +def _replay_sample( + buffer: ReplayBuffer, key: jax.Array, batch_size: int +) -> ReplayBatch: + size = jnp.maximum(buffer.size, 1) + idx = jax.random.randint(key, shape=(batch_size,), minval=0, maxval=size) + return ReplayBatch( + obs=buffer.obs[idx], + actions=buffer.actions[idx], + rewards=buffer.rewards[idx], + next_obs=buffer.next_obs[idx], + dones=buffer.dones[idx], + ) + + +def _make_actor_critic_train( + config: dict[str, Any], *, algo: str, use_pmap: bool = False +): + cfg = dict(config) + cfg["algo"] = algo + env = _make_env(cfg) network = ActorCritic(env.action_space_n(), activation=cfg["activation"]) def linear_schedule(count: jax.Array) -> jax.Array: @@ -299,39 +481,45 @@ def make_train(config: dict[str, Any]): def _loss_fn(params, traj_b, adv_b, tgt_b): policy, value = network.apply(params, traj_b.obs) log_prob = policy.log_prob(traj_b.action) - - value_clipped = traj_b.value + (value - traj_b.value).clip( - -cfg["clip_range"], cfg["clip_range"] - ) - value_loss = ( - 0.5 - * jnp.maximum( - jnp.square(value - tgt_b), - jnp.square(value_clipped - tgt_b), - ).mean() - ) - adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8) - ratio = jnp.exp(log_prob - traj_b.log_prob) - loss_actor = -jnp.minimum( - ratio * adv_norm, - jnp.clip( - ratio, - 1.0 - cfg["clip_range"], - 1.0 + cfg["clip_range"], + + if algo == "ppo": + value_clipped = traj_b.value + (value - traj_b.value).clip( + -cfg["clip_range"], cfg["clip_range"] ) - * adv_norm, - ).mean() + value_loss = ( + 0.5 + * jnp.maximum( + jnp.square(value - tgt_b), + jnp.square(value_clipped - tgt_b), + ).mean() + ) + ratio = jnp.exp(log_prob - traj_b.log_prob) + policy_loss = -jnp.minimum( + ratio * adv_norm, + jnp.clip( + ratio, + 1.0 - cfg["clip_range"], + 1.0 + 
cfg["clip_range"], + ) + * adv_norm, + ).mean() + else: + value_loss = 0.5 * jnp.mean(jnp.square(value - tgt_b)) + policy_loss = -(log_prob * adv_norm).mean() + entropy = policy.entropy().mean() total_loss = ( - loss_actor + policy_loss + cfg["vf_coef"] * value_loss - cfg["ent_coef"] * entropy ) - return total_loss, (value_loss, loss_actor, entropy) + return total_loss, (value_loss, policy_loss, entropy) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b) + if use_pmap: + grads = jax.lax.pmean(grads, axis_name="devices") train_state = train_state.apply_gradients(grads=grads) return train_state, jnp.asarray(0.0, dtype=jnp.float32) @@ -339,6 +527,7 @@ def make_train(config: dict[str, Any]): rng, perm_key = jax.random.split(rng) batch_size = cfg["num_envs"] * cfg["num_steps"] permutation = jax.random.permutation(perm_key, batch_size) + batch = (traj_batch, advantages, targets) batch = jax.tree_util.tree_map( lambda x: x.reshape((batch_size,) + x.shape[2:]), @@ -377,7 +566,7 @@ def make_train(config: dict[str, Any]): next_runner_state = (train_state, env_state, last_obs, rng) return next_runner_state, metric - def run_updates(runner_state, *, num_updates: int): + def run_updates(runner_state, num_updates: int): updates = max(1, int(num_updates)) runner_state, metric = jax.lax.scan( _update_step, @@ -393,6 +582,14 @@ def make_train(config: dict[str, Any]): return init_runner_state, run_updates, network, env, cfg +def make_train(config: dict[str, Any]): + cfg = _jax_cfg(config) + algo = cfg["algo"] + if algo not in {"ppo", "a2c"}: + raise ValueError(f"make_train supports actor-critic algos only, got '{algo}'") + return _make_actor_critic_train(cfg, algo=algo) + + def evaluate_policy( *, network: ActorCritic, @@ -418,8 +615,8 @@ def evaluate_policy( action = jnp.argmax(policy.logits) key, step_key = jax.random.split(key) obs, state, reward, done_flag, info = env.step(step_key, state, action) - ep_reward += 
float(np.asarray(reward)) - ep_revenue += float(np.asarray(info["revenue"])) + ep_reward += _scalar(reward) + ep_revenue += _scalar(info["revenue"]) done = bool(np.asarray(done_flag)) steps += 1 @@ -434,32 +631,130 @@ def evaluate_policy( } -def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: - if not HAS_JAX_STACK: - raise ImportError( - "JAX PPO path requires jax, flax, optax, and distrax. " - "Install engine/jax/requirements.txt on this machine first." - ) +def _evaluate_q_network( + *, + network: QNetwork, + params: Any, + env: PHANTOMJAXEnv, + episodes: int, + seed: int, +) -> dict[str, float]: + rewards: list[float] = [] + revenues: list[float] = [] + key = jax.random.PRNGKey(seed) - run_cfg = _jax_cfg(cfg) - if run_cfg["algo"] != "ppo": - raise ValueError( - f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'" - ) + for _ in range(int(episodes)): + key, reset_key = jax.random.split(key) + obs, state = env.reset(reset_key) + ep_reward = 0.0 + ep_revenue = 0.0 + done = False + steps = 0 + + while not done and steps < int(env.params.max_episode_steps): + q_values = network.apply(params, obs) + action = jnp.argmax(q_values) + key, step_key = jax.random.split(key) + obs, state, reward, done_flag, info = env.step(step_key, state, action) + ep_reward += _scalar(reward) + ep_revenue += _scalar(info["revenue"]) + done = bool(np.asarray(done_flag)) + steps += 1 + + rewards.append(ep_reward) + revenues.append(ep_revenue) + + return { + "eval/reward": float(np.mean(rewards)), + "eval/revenue": float(np.mean(revenues)), + "eval/reward_std": float(np.std(rewards)), + "eval/revenue_std": float(np.std(revenues)), + } + + +def _evaluate_q_table( + *, + q_table: jax.Array, + env: PHANTOMJAXEnv, + episodes: int, + seed: int, + n_products: int, + demand_bins: jax.Array, + price_bins: jax.Array, +) -> dict[str, float]: + rewards: list[float] = [] + revenues: list[float] = [] + key = jax.random.PRNGKey(seed) + + for _ in 
range(int(episodes)): + key, reset_key = jax.random.split(key) + obs, state = env.reset(reset_key) + ep_reward = 0.0 + ep_revenue = 0.0 + done = False + steps = 0 + + while not done and steps < int(env.params.max_episode_steps): + s0, s1, s2 = _encode_qtable_state( + obs, + n_products=n_products, + demand_bins=demand_bins, + price_bins=price_bins, + ) + action = jnp.argmax(q_table[s0, s1, s2]) + key, step_key = jax.random.split(key) + obs, state, reward, done_flag, info = env.step(step_key, state, action) + ep_reward += _scalar(reward) + ep_revenue += _scalar(info["revenue"]) + done = bool(np.asarray(done_flag)) + steps += 1 + + rewards.append(ep_reward) + revenues.append(ep_revenue) + + return { + "eval/reward": float(np.mean(rewards)), + "eval/revenue": float(np.mean(revenues)), + "eval/reward_std": float(np.std(rewards)), + "eval/revenue_std": float(np.std(revenues)), + } + + +def _train_actor_critic( + cfg: dict[str, Any], + *, + algo: str, +) -> tuple[dict[str, Any], dict[str, float]]: + num_devices = jax.local_device_count() + use_pmap = num_devices > 1 + + init_runner_state, run_updates_raw, network, env, run_cfg = ( + _make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap) + ) + + if use_pmap: + run_fn = jax.pmap( + run_updates_raw, + axis_name="devices", + static_broadcasted_argnums=(1,), + devices=jax.local_devices(), + ) + else: + run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",)) - init_runner_state, run_updates, network, env, run_cfg = make_train(run_cfg) - run_updates_jit = jax.jit(run_updates, static_argnames=("num_updates",)) rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"]) total_updates = int(run_cfg["num_updates"]) checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000))) segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1)) rng = jax.random.PRNGKey(run_cfg["seed"]) - runner_state = init_runner_state(rng) + # single-device state used as template for serialization and eval + 
single_runner_state = init_runner_state(rng) updates_done = 0 + is_primary = jax.process_index() == 0 artifact_name = None - if HAS_WANDB and wandb.run is not None: + if is_primary and HAS_WANDB and wandb.run is not None: sweep_id = getattr(wandb.run, "sweep_id", None) artifact_name = checkpoint_artifact_name( run_cfg, @@ -468,34 +763,48 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: ) restored = download_latest_checkpoint( artifact_name, - file_name="jax_runner_state.msgpack", + file_name=f"jax_{algo}_runner_state.msgpack", ) if restored is not None: checkpoint_path, metadata = restored - template = { - "runner_state": runner_state, - "updates_done": 0, - } + template = {"runner_state": single_runner_state, "updates_done": 0} payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) - runner_state = payload["runner_state"] + single_runner_state = payload["runner_state"] updates_done = int(payload.get("updates_done", 0)) if updates_done <= 0: updates_done = int(metadata.get("updates_done", 0)) updates_done = max(0, min(updates_done, total_updates)) + if use_pmap: + runner_state = jax.device_put_replicated( + single_runner_state, jax.local_devices() + ) + else: + runner_state = single_runner_state + metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"] metric_sums = {k: 0.0 for k in metric_keys} metric_count = 0 while updates_done < total_updates: updates_this_segment = min(segment_updates, total_updates - updates_done) - out = run_updates_jit(runner_state, num_updates=updates_this_segment) + if use_pmap: + out = run_fn(runner_state, updates_this_segment) + else: + out = run_fn(runner_state, updates_this_segment) runner_state = out["runner_state"] metric = out["metrics"] - segment_values = { - k: np.asarray(metric[k], dtype=np.float64) for k in metric_keys - } + if use_pmap: + # take device-0 slice; shape is (n_devices, segment_updates) + segment_values = { + key: np.asarray(metric[key][0], 
dtype=np.float64) for key in metric_keys + } + else: + segment_values = { + key: np.asarray(metric[key], dtype=np.float64) for key in metric_keys + } + segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0 metric_count += segment_count for key in metric_keys: @@ -504,7 +813,7 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: updates_done += int(updates_this_segment) global_step = int(updates_done * rollout_steps) - if HAS_WANDB and wandb.run is not None: + if is_primary and HAS_WANDB and wandb.run is not None: wandb.log( { "train/reward": float(segment_values["reward"].mean()), @@ -517,25 +826,36 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: step=global_step, ) if artifact_name is not None: + # extract device-0 state for checkpoint portability + state_to_save = ( + jax.tree_util.tree_map(lambda x: x[0], runner_state) + if use_pmap + else runner_state + ) checkpoint_payload = serialization.to_bytes( - { - "runner_state": runner_state, - "updates_done": updates_done, - } + {"runner_state": state_to_save, "updates_done": updates_done} ) log_checkpoint_bytes( artifact_name, - file_name="jax_runner_state.msgpack", + file_name=f"jax_{algo}_runner_state.msgpack", payload=checkpoint_payload, metadata={ "step": global_step, "updates_done": updates_done, "rollout_steps": rollout_steps, - "algo": "ppo", + "algo": algo, }, ) + if _stop_requested.is_set(): + break - train_state = runner_state[0] + # extract device-0 params for eval and save + final_runner = ( + jax.tree_util.tree_map(lambda x: x[0], runner_state) + if use_pmap + else runner_state + ) + train_state = final_runner[0] denom = float(metric_count) if metric_count > 0 else 1.0 metrics = { "train/reward": float(metric_sums["reward"] / denom), @@ -555,9 +875,430 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: ) metrics.update(eval_metrics) + if is_primary: + model_dir = 
Path(run_cfg["model_dir"]) + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / f"phantom_{algo}_jax.msgpack" + model_path.write_bytes(serialization.to_bytes(train_state.params)) + metrics["model/path"] = str(model_path) + + return {"params": train_state.params}, metrics + + +def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: + run_cfg = dict(cfg) + env = _make_env(run_cfg) + action_dim = env.action_space_n() + obs_dim = env.observation_dim() + q_net = QNetwork(action_dim=action_dim, activation=run_cfg["activation"]) + + init_obs = jnp.zeros((obs_dim,), dtype=jnp.float32) + rng = jax.random.PRNGKey(run_cfg["seed"]) + rng, init_key = jax.random.split(rng) + params = q_net.init(init_key, init_obs) + tx = optax.adam(run_cfg["learning_rate"]) + train_state = TrainState.create(apply_fn=q_net.apply, params=params, tx=tx) + target_params = train_state.params + + buffer = _init_replay_buffer(run_cfg["buffer_size"], obs_dim) + + rng, reset_key = jax.random.split(rng) + obs, env_state = env.reset(reset_key) + + start_step = 0 + epsilon_value = float(run_cfg["eps_start"]) + artifact_name = None + + if HAS_WANDB and wandb.run is not None: + sweep_id = getattr(wandb.run, "sweep_id", None) + artifact_name = checkpoint_artifact_name( + run_cfg, + backend="jax", + sweep_id=sweep_id, + ) + restored = download_latest_checkpoint( + artifact_name, + file_name="jax_dqn_state.msgpack", + ) + if restored is not None: + checkpoint_path, metadata = restored + template = { + "params": train_state.params, + "target_params": target_params, + "opt_state": train_state.opt_state, + "global_step": 0, + "epsilon": epsilon_value, + } + payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) + train_state = train_state.replace( + params=payload["params"], + opt_state=payload["opt_state"], + ) + target_params = payload["target_params"] + start_step = int(payload.get("global_step", metadata.get("step", 0))) + start_step = max(0, 
min(start_step, int(run_cfg["total_timesteps"]))) + epsilon_value = float(payload.get("epsilon", epsilon_value)) + + @jax.jit + def dqn_update( + state: TrainState, + target: Any, + batch: ReplayBatch, + ) -> tuple[TrainState, jax.Array]: + def loss_fn(model_params): + q_values = q_net.apply(model_params, batch.obs) + chosen = jnp.take_along_axis( + q_values, + batch.actions[:, None], + axis=1, + ).squeeze(-1) + next_q = q_net.apply(target, batch.next_obs) + next_max = jnp.max(next_q, axis=1) + td_target = ( + batch.rewards + run_cfg["gamma"] * (1.0 - batch.dones) * next_max + ) + td_error = chosen - jax.lax.stop_gradient(td_target) + return jnp.mean(jnp.square(td_error)) + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + next_state = state.apply_gradients(grads=grads) + return next_state, loss + + metric_sums = { + "reward": 0.0, + "revenue": 0.0, + "agent_prob": 0.0, + "alpha_adv": 0.0, + "coi_leakage": 0.0, + "loss": 0.0, + } + metric_count = 0 + loss_count = 0 + + total_steps = int(run_cfg["total_timesteps"]) + checkpoint_interval = max(1, int(run_cfg["checkpoint_interval"])) + batch_size = max(1, int(run_cfg["batch_size"])) + + for global_step in range(start_step + 1, total_steps + 1): + epsilon_value = _epsilon_by_fraction(global_step - 1, run_cfg) + + rng, eps_key, action_key, step_key, reset_key, sample_key = jax.random.split( + rng, 6 + ) + do_explore = bool(np.asarray(jax.random.uniform(eps_key) < epsilon_value)) + if do_explore: + action = jax.random.randint( + action_key, shape=(), minval=0, maxval=action_dim + ) + else: + q_values = q_net.apply(train_state.params, obs) + action = jnp.argmax(q_values) + + next_obs, next_state, reward, done, info = env.step(step_key, env_state, action) + buffer = _replay_add( + buffer, + obs, + action, + reward, + next_obs, + done.astype(jnp.float32), + ) + + metric_count += 1 + metric_sums["reward"] += _scalar(reward) + metric_sums["revenue"] += _scalar(info["revenue"]) + metric_sums["agent_prob"] += 
_scalar(info["agent_prob"]) + metric_sums["alpha_adv"] += _scalar(info["alpha_adv"]) + metric_sums["coi_leakage"] += _scalar(info["coi_leakage"]) + + if bool(np.asarray(done)): + obs, env_state = env.reset(reset_key) + else: + obs, env_state = next_obs, next_state + + ready = ( + global_step >= int(run_cfg["learning_starts"]) + and global_step % int(run_cfg["train_freq"]) == 0 + and _replay_size(buffer) >= batch_size + ) + if ready: + batch = _replay_sample(buffer, sample_key, batch_size) + train_state, loss = dqn_update(train_state, target_params, batch) + metric_sums["loss"] += _scalar(loss) + loss_count += 1 + + if global_step % int(run_cfg["target_update_interval"]) == 0: + target_params = train_state.params + + if ( + HAS_WANDB + and wandb.run is not None + and global_step % int(run_cfg["log_freq"]) == 0 + ): + wandb.log( + { + "train/reward": metric_sums["reward"] / max(metric_count, 1), + "train/revenue": metric_sums["revenue"] / max(metric_count, 1), + "train/agent_prob": metric_sums["agent_prob"] + / max(metric_count, 1), + "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), + "train/coi_leakage": metric_sums["coi_leakage"] + / max(metric_count, 1), + "train/dqn_loss": metric_sums["loss"] / max(loss_count, 1), + "train/epsilon": epsilon_value, + "train/global_step": global_step, + }, + step=global_step, + ) + + if artifact_name is not None and global_step % checkpoint_interval == 0: + payload = serialization.to_bytes( + { + "params": train_state.params, + "target_params": target_params, + "opt_state": train_state.opt_state, + "global_step": global_step, + "epsilon": epsilon_value, + } + ) + log_checkpoint_bytes( + artifact_name, + file_name="jax_dqn_state.msgpack", + payload=payload, + metadata={ + "step": global_step, + "algo": "dqn", + }, + ) + if _stop_requested.is_set(): + break + + denom = float(metric_count) if metric_count > 0 else 1.0 + metrics = { + "train/reward": float(metric_sums["reward"] / denom), + "train/revenue": 
float(metric_sums["revenue"] / denom), + "train/agent_prob": float(metric_sums["agent_prob"] / denom), + "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), + "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), + "train/dqn_loss": float(metric_sums["loss"] / max(loss_count, 1)), + "train/global_step": total_steps, + } + + eval_metrics = _evaluate_q_network( + network=q_net, + params=train_state.params, + env=env, + episodes=run_cfg["eval_episodes"], + seed=run_cfg["seed"] + 7, + ) + metrics.update(eval_metrics) + model_dir = Path(run_cfg["model_dir"]) model_dir.mkdir(parents=True, exist_ok=True) - model_path = model_dir / "phantom_ppo_jax.msgpack" + model_path = model_dir / "phantom_dqn_jax.msgpack" model_path.write_bytes(serialization.to_bytes(train_state.params)) metrics["model/path"] = str(model_path) - return {"params": train_state.params}, metrics + return { + "params": train_state.params, + "target_params": target_params, + }, metrics + + +def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: + run_cfg = dict(cfg) + env = _make_env(run_cfg) + action_dim = env.action_space_n() + n_bins = max(2, int(run_cfg["q_bins"])) + n_products = int(run_cfg["n_products"]) + + q_table = jnp.zeros((n_bins, n_bins, n_bins, action_dim), dtype=jnp.float32) + demand_bins = jnp.linspace(0.0, 100.0, n_bins + 1, dtype=jnp.float32)[1:-1] + price_bins = jnp.linspace( + float(run_cfg["price_low"]), + float(run_cfg["price_high"]), + n_bins + 1, + dtype=jnp.float32, + )[1:-1] + + rng = jax.random.PRNGKey(run_cfg["seed"]) + rng, reset_key = jax.random.split(rng) + obs, env_state = env.reset(reset_key) + + epsilon_value = float(run_cfg["eps_start"]) + start_step = 0 + artifact_name = None + + if HAS_WANDB and wandb.run is not None: + sweep_id = getattr(wandb.run, "sweep_id", None) + artifact_name = checkpoint_artifact_name( + run_cfg, + backend="jax", + sweep_id=sweep_id, + ) + restored = download_latest_checkpoint( + artifact_name, + 
file_name="jax_qtable_state.msgpack", + ) + if restored is not None: + checkpoint_path, metadata = restored + template = { + "q_table": q_table, + "global_step": 0, + "epsilon": epsilon_value, + } + payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) + q_table = payload["q_table"] + start_step = int(payload.get("global_step", metadata.get("step", 0))) + start_step = max(0, min(start_step, int(run_cfg["total_timesteps"]))) + epsilon_value = float(payload.get("epsilon", epsilon_value)) + + metric_sums = { + "reward": 0.0, + "revenue": 0.0, + "agent_prob": 0.0, + "alpha_adv": 0.0, + "coi_leakage": 0.0, + } + metric_count = 0 + + total_steps = int(run_cfg["total_timesteps"]) + checkpoint_interval = max(1, int(run_cfg["checkpoint_interval"])) + + for global_step in range(start_step + 1, total_steps + 1): + s0, s1, s2 = _encode_qtable_state( + obs, + n_products=n_products, + demand_bins=demand_bins, + price_bins=price_bins, + ) + state_q = q_table[s0, s1, s2] + + rng, eps_key, action_key, step_key, reset_key = jax.random.split(rng, 5) + do_explore = bool(np.asarray(jax.random.uniform(eps_key) < epsilon_value)) + if do_explore: + action = jax.random.randint( + action_key, shape=(), minval=0, maxval=action_dim + ) + else: + action = jnp.argmax(state_q) + + next_obs, next_state, reward, done, info = env.step(step_key, env_state, action) + ns0, ns1, ns2 = _encode_qtable_state( + next_obs, + n_products=n_products, + demand_bins=demand_bins, + price_bins=price_bins, + ) + + best_next = jnp.max(q_table[ns0, ns1, ns2]) + done_f = done.astype(jnp.float32) + td_target = reward + run_cfg["gamma"] * (1.0 - done_f) * best_next + old_value = q_table[s0, s1, s2, action] + new_value = old_value + run_cfg["q_lr"] * (td_target - old_value) + q_table = q_table.at[s0, s1, s2, action].set(new_value) + + epsilon_value = max( + float(run_cfg["eps_end"]), + epsilon_value * float(run_cfg["eps_decay"]), + ) + + metric_count += 1 + metric_sums["reward"] += _scalar(reward) + 
metric_sums["revenue"] += _scalar(info["revenue"]) + metric_sums["agent_prob"] += _scalar(info["agent_prob"]) + metric_sums["alpha_adv"] += _scalar(info["alpha_adv"]) + metric_sums["coi_leakage"] += _scalar(info["coi_leakage"]) + + if bool(np.asarray(done)): + obs, env_state = env.reset(reset_key) + else: + obs, env_state = next_obs, next_state + + if ( + HAS_WANDB + and wandb.run is not None + and global_step % int(run_cfg["log_freq"]) == 0 + ): + wandb.log( + { + "train/reward": metric_sums["reward"] / max(metric_count, 1), + "train/revenue": metric_sums["revenue"] / max(metric_count, 1), + "train/agent_prob": metric_sums["agent_prob"] + / max(metric_count, 1), + "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), + "train/coi_leakage": metric_sums["coi_leakage"] + / max(metric_count, 1), + "train/epsilon": epsilon_value, + "train/global_step": global_step, + }, + step=global_step, + ) + + if artifact_name is not None and global_step % checkpoint_interval == 0: + payload = serialization.to_bytes( + { + "q_table": q_table, + "global_step": global_step, + "epsilon": epsilon_value, + } + ) + log_checkpoint_bytes( + artifact_name, + file_name="jax_qtable_state.msgpack", + payload=payload, + metadata={ + "step": global_step, + "algo": "qtable", + }, + ) + + denom = float(metric_count) if metric_count > 0 else 1.0 + metrics = { + "train/reward": float(metric_sums["reward"] / denom), + "train/revenue": float(metric_sums["revenue"] / denom), + "train/agent_prob": float(metric_sums["agent_prob"] / denom), + "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), + "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), + "train/global_step": total_steps, + } + + eval_metrics = _evaluate_q_table( + q_table=q_table, + env=env, + episodes=run_cfg["eval_episodes"], + seed=run_cfg["seed"] + 7, + n_products=n_products, + demand_bins=demand_bins, + price_bins=price_bins, + ) + metrics.update(eval_metrics) + + model_dir = Path(run_cfg["model_dir"]) + 
model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / "phantom_qtable_jax.msgpack" + model_path.write_bytes(serialization.to_bytes(q_table)) + metrics["model/path"] = str(model_path) + return {"q_table": q_table}, metrics + + +def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: + if not HAS_JAX_STACK: + raise ImportError( + "JAX path requires jax, flax, optax, and distrax. " + "Install engine/jax/requirements.txt on this machine first." + ) + + _init_jax_distributed() + _stop_requested.clear() + run_cfg = _jax_cfg(cfg) + algo = run_cfg["algo"] + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGTERM, lambda *_: _stop_requested.set()) + + if algo in {"ppo", "a2c"}: + return _train_actor_critic(run_cfg, algo=algo) + if algo == "dqn": + return _train_dqn(run_cfg) + if algo == "qtable": + return _train_qtable(run_cfg) + raise ValueError(f"Unsupported JAX algo '{algo}'") diff --git a/engine/lib/behavior.py b/engine/lib/behavior.py index e8fe2be..6a3a411 100644 --- a/engine/lib/behavior.py +++ b/engine/lib/behavior.py @@ -3,11 +3,16 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parents[2])) -from sim.rl.behavior_loader.models import ( - BehaviorModel, - AgentBehaviorModel, - aggregate_event_transitions, -) +try: + from sim.rl.behavior_loader.models import ( + BehaviorModel, + AgentBehaviorModel, + aggregate_event_transitions, + ) +except ImportError: + BehaviorModel = None + AgentBehaviorModel = None + aggregate_event_transitions = None import pandas as pd import numpy as np from .demand import generate_demand_for_actor @@ -20,6 +25,12 @@ _cache = {} # lazy cache for models and base pivots def _get_base_pivot(human: bool): + if ( + BehaviorModel is None + or AgentBehaviorModel is None + or aggregate_event_transitions is None + ): + raise ImportError("behavior loader dependencies are unavailable") key = "human" if human else "agent" if key not in _cache: model = 
BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir) @@ -34,6 +45,13 @@ def get_transition_models(): returns: tuple: (human_transitions, agent_transitions) as dicts of event->event->prob """ + if ( + BehaviorModel is None + or AgentBehaviorModel is None + or aggregate_event_transitions is None + ): + raise ImportError("behavior loader dependencies are unavailable") + human_model = BehaviorModel(human_dir) agent_model = AgentBehaviorModel(agent_dir) diff --git a/engine/train.py b/engine/train.py index 35ca582..f6b256d 100644 --- a/engine/train.py +++ b/engine/train.py @@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict: "JAX backend requested but JAX is not installed. " "Install engine/jax/requirements.txt and jax[tpu] for TPU runs." ) - if algo == "qtable": - raise ValueError("qtable is not supported in JAX backend") try: from .jax.train import train_jax except Exception as exc: # pragma: no cover @@ -409,20 +407,25 @@ def run_wandb( init_kwargs = {"mode": mode} if sweep_mode: run = wandb.init(**init_kwargs) - cfg = _cfg(_wandb_cfg_dict()) - for k, v in overrides.items(): - if k not in wandb.config: - cfg[k] = v else: run = wandb.init(project=project, config=overrides, **init_kwargs) + + try: cfg = _cfg(_wandb_cfg_dict()) - metrics = train_once(cfg) - step = int(metrics.get("train/global_step", cfg["total_timesteps"])) - wandb.log(metrics, step=step) - for k, v in metrics.items(): - run.summary[k] = v - wandb.finish() - return metrics + if sweep_mode: + for k, v in overrides.items(): + if k not in wandb.config: + cfg[k] = v + + metrics = train_once(cfg) + step = int(metrics.get("train/global_step", cfg["total_timesteps"])) + wandb.log(metrics, step=step) + for k, v in metrics.items(): + run.summary[k] = v + return metrics + finally: + if wandb.run is not None: + wandb.finish() def run_local(overrides: dict) -> dict: diff --git a/paper/src/chapters/03-methodology.tex b/paper/src/chapters/03-methodology.tex index 19c5997..f20f59c 100644 --- 
a/paper/src/chapters/03-methodology.tex +++ b/paper/src/chapters/03-methodology.tex @@ -140,6 +140,7 @@ The architecture of this platform begins with the deployed web-apps posting inte \paragraph{Public Web Artifact} We transition the Kappa like architecture of the data collection to a Lambda system for actual learning in a surrogate environment. This allows us to move faster on data which is provided and helps us create a feedback loop for production deployment. To support further research in this intersection of fields we release P4P \footnote{\url{https://github.com/velocitatem/p4p}} as a public repository providing the interaction layer of the PHANTOM framework. This provides a configurable storefront which can be tailored to any commercial setting with a standardized session-level event tracking. We document the API adapters or what the framework expects in terms of schemas for pricing providers and log ingestion servicse. The repository is intended for controlled experimentation and method replication rather than production commerce deployment. + \subsubsection{DevOps Principles} Reproducible results are key to quality research platforms, this is taken into mind when deploying and working with our research platform. From a deployment standpoint the platform can be deployed across a large variety of providers and can be run locally. When developing a new interaction modality apart from the ones that come out of the box, a simple template pattern can be followed. The middleware of the framework is designed to properly render the chosen modality from environmental variables, thus deployment of different or parallel version of the software can be easily parametrized. @@ -235,7 +236,11 @@ v4 & 64 (32 + 32) & us-central2-b & 32 Spot + 32 On-demand \\ \end{tabular} \end{table} -For interactive monitoring from Madrid, we prioritize the europe-west4 allocation for latency-sensitive runs. 
All sweep metadata, model checkpoints, and reward traces are logged in Weights \& Biases. Hardware specifications are from the official Google Cloud TPU documentation \parencite{noauthor_tpu_2026,noauthor_tpu_2025-1,noauthor_tpu_2025}. +For connections from Madrid, we prioritize the europe-west4 allocation for latency-sensitive runs, with the added benefit that it has the most chips grouped within a single region. This regional grouping is important for the deployment of our Kubernetes cluster, which cannot span multiple regions. All sweep metadata, model checkpoints, and reward traces are logged in Weights \& Biases. Hardware specifications are from the official Google Cloud TPU documentation \parencite{noauthor_tpu_2026,noauthor_tpu_2025-1,noauthor_tpu_2025}. + +Design of training processes: we build the Docker image with layer caching in mind in order to speed up Docker rebuilds as much as possible, and as such we place the most volatile steps towards the end of the image build. What this means in practice is that any dependency installations are isolated, so edits to source code do not trigger rebuilds. Only if we update our entry point for training a sweep will Docker also rebuild the source-code copy stage. + +Due to the preemptive nature of the current demand for TPU chips, we settle for running our on-demand pod as the primary source of compute. The on-demand TPU pod of 32 chips spread across 4 virtual hosts creates a relatively unique parallelization setup. Despite our desire to use a traditional approach of clustering and perhaps deploying SLURM jobs of our sweep agent, the lack of predictability in provisioning each instance of a compute resource makes this a high-friction layer we do not want to add. \subsubsection{Interaction Schema}