From 9acc998cc90744c936f582e85db343ec0e3bcff4 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Tue, 17 Feb 2026 16:54:55 +0100 Subject: [PATCH] fixing models for gcp --- Makefile | 116 +++++++++- engine/jax/train.py | 460 ++++++++++++++++++++++++---------------- engine/lib/__init__.py | 2 +- engine/lib/callbacks.py | 63 ++++++ engine/train.py | 49 ++++- 5 files changed, 497 insertions(+), 193 deletions(-) diff --git a/Makefile b/Makefile index 27ce523..43c55ee 100644 --- a/Makefile +++ b/Makefile @@ -13,8 +13,8 @@ TPU_ZONE ?= us-central2-b TPU_TYPE ?= v4-32 TPU_RUNTIME ?= tpu-vm-v4-base TPU_PROJECT ?= phantom-trc -TPU_NETWORK ?= default -TPU_SUBNETWORK ?= default-us-central2 +TPU_NETWORK ?= tpu-network +TPU_SUBNETWORK ?= tpu-network TPU_USE_SPOT ?= 0 TPU_EXTRA_CREATE_FLAGS ?= TPU_WORKDIR ?= ~/PHANTOM @@ -22,7 +22,29 @@ TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000 TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html TPU_VENV ?= .venv-tpu -TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=offline +TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=online +SWEEP_ID ?= +SWEEP_COUNT ?= 5 +QUEUE_SCRIPT ?= scripts/queue_sweep.sh +TPU_QUEUE_TYPE ?= +TPU_QUEUE_ZONES ?= europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b +TPU_QUEUE_REUSE_EXISTING ?= 1 +TPU_QUEUE_KEEP_ALIVE ?= 1 +TPU_QUEUE_STRICT_QUOTA ?= 0 +TPU_QUEUE_DOWNSHIFT_ON_QUOTA ?= 1 +TPU_QUEUE_FILTER_ZONE ?= +TPU_QUEUE_FILTER_TYPE ?= +TPU_QUEUE_EXECUTION_MODE ?= venv +TPU_QUEUE_SYNC_METHOD ?= tar +TPU_QUEUE_SKIP_SYNC ?= 0 +TPU_QUEUE_DOCKER_IMAGE ?= +TPU_QUEUE_DOCKER_PULL ?= 1 +TPU_QUEUE_DOCKER_AUTO_INSTALL ?= 1 +TPU_QUEUE_SSH_BATCH_MODE ?= 1 +TPU_QUEUE_SSH_CONNECT_TIMEOUT ?= 12 +TPU_QUEUE_SSH_KEY_FILE ?= $(HOME)/.ssh/google_compute_engine +TPU_QUEUE_REQUIRE_SSH_AGENT ?= 1 +TPU_QUEUE_AUTO_SSH_ADD ?= 1 TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,) TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS) @@ -30,8 +52,13 @@ TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$ .PHONY: help help: - @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.*" + @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.* | tpu.queue.*" @echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot" + @echo "Queued sweep: SWEEP_ID=entity/project/id make tpu.queue.sweep" + @echo "Queued sweep filters: TPU_QUEUE_FILTER_TYPE=v6e TPU_QUEUE_FILTER_ZONE=europe-west4-a" + @echo "Docker queue: make tpu.queue.sweep.docker TPU_QUEUE_DOCKER_IMAGE=gcr.io//:tag" + @echo "Docker queue without sync: add TPU_QUEUE_SKIP_SYNC=1" + @echo "If SSH key is encrypted: run ssh-add ~/.ssh/google_compute_engine first" $(BUILDDIR): mkdir -p paper/$(BUILDDIR) @@ -104,11 +131,11 @@ tpu.check.zone: .PHONY: tpu.create.v4.ondemand tpu.create.v4.ondemand: - $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=default-us-central2 + $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=tpu-network .PHONY: tpu.create.v4.spot tpu.create.v4.spot: - $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=default-us-central2 + $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=tpu-network .PHONY: tpu.create tpu.create: tpu.check.zone @@ -179,6 +206,83 @@ tpu.bootstrap: tpu.ensure tpu.deploy tpu.install tpu.delete: gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet +.PHONY: tpu.queue.sweep +tpu.queue.sweep: + @set -e; \ + test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \ + test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \ + if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \ + if ! ssh-add -l >/dev/null 2>&1; then \ + if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \ + ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \ + fi; \ + fi; \ + AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_QUEUE_FILTER_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_FILTER_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)" + +.PHONY: tpu.queue.worker +tpu.queue.worker: + @set -e; \ + test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \ + test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \ + if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \ + if ! ssh-add -l >/dev/null 2>&1; then \ + if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \ + ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \ + fi; \ + fi; \ + AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)" + +.PHONY: tpu.queue.sweep.docker +tpu.queue.sweep.docker: + @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) + @$(MAKE) tpu.queue.sweep TPU_QUEUE_EXECUTION_MODE=docker + +.PHONY: tpu.queue.worker.docker +tpu.queue.worker.docker: + @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) + @$(MAKE) tpu.queue.worker TPU_QUEUE_EXECUTION_MODE=docker + +.PHONY: tpu.queue.docker.build +tpu.queue.docker.build: + @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) + docker build -f docker/TPUSweep.Dockerfile -t "$(TPU_QUEUE_DOCKER_IMAGE)" . + +.PHONY: tpu.queue.docker.push +tpu.queue.docker.push: + @test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1) + docker push "$(TPU_QUEUE_DOCKER_IMAGE)" + +.PHONY: tpu.queue.status +tpu.queue.status: + @set -e; \ + if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \ + QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \ + else \ + QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \ + fi; \ + for ZONE in $(TPU_QUEUE_ZONES); do \ + echo "--- $$ZONE ---"; \ + if ! $$QCMD list --zone="$$ZONE"; then \ + echo "Skipping $$ZONE (unavailable or no permission)"; \ + fi; \ + done + +.PHONY: tpu.queue.clean +tpu.queue.clean: + @set -e; \ + if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \ + QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \ + else \ + QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \ + fi; \ + for ZONE in $(TPU_QUEUE_ZONES); do \ + $$QCMD list --zone="$$ZONE" --format='value(name)' 2>/dev/null | while read -r NAME; do \ + case "$$NAME" in \ + qr-*) echo "Deleting $$NAME ($$ZONE)"; $$QCMD delete "$$NAME" --zone="$$ZONE" --quiet ;; \ + esac; \ + done; \ + done + .PHONY: stats.lines stats.lines: @find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \ diff --git a/engine/jax/train.py b/engine/jax/train.py index f2f4168..41678c1 100644 --- a/engine/jax/train.py +++ b/engine/jax/train.py @@ -7,6 +7,19 @@ from typing import Any, NamedTuple import numpy as np +try: + import wandb + + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + +from ..wandb_checkpoint import ( + checkpoint_artifact_name, + download_latest_checkpoint, + log_checkpoint_bytes, +) + try: import jax import jax.numpy as jnp @@ -142,6 +155,7 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]: "num_minibatches": int(cfg.get("jax_num_minibatches", 4)), "update_epochs": int(cfg.get("jax_update_epochs", 4)), "anneal_lr": bool(cfg.get("jax_anneal_lr", True)), + "checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)), } rollout = out["num_envs"] * out["num_steps"] out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1)) @@ -185,201 +199,198 @@ def make_train(config: dict[str, Any]): frac = 1.0 - updates_done / max(cfg["num_updates"], 1) return cfg["learning_rate"] * frac - def train(rng: jax.Array): + if cfg["anneal_lr"]: + tx = optax.chain( + optax.clip_by_global_norm(cfg["max_grad_norm"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(cfg["max_grad_norm"]), + optax.adam(cfg["learning_rate"], eps=1e-5), + ) + + def init_runner_state(rng: jax.Array): rng, init_key = jax.random.split(rng) init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32) params = network.init(init_key, init_obs) - - if cfg["anneal_lr"]: - tx = optax.chain( - optax.clip_by_global_norm(cfg["max_grad_norm"]), - optax.adam(learning_rate=linear_schedule, eps=1e-5), - ) - else: - tx = optax.chain( - optax.clip_by_global_norm(cfg["max_grad_norm"]), - optax.adam(cfg["learning_rate"], eps=1e-5), - ) train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx) rng, reset_key = jax.random.split(rng) reset_keys = jax.random.split(reset_key, cfg["num_envs"]) obs, env_state = jax.vmap(env.reset)(reset_keys) + return train_state, env_state, obs, rng - def _update_step(runner_state, _): - def _env_step(runner_state, _): - train_state, env_state, last_obs, rng = runner_state - rng, action_key = jax.random.split(rng) - policy, value = network.apply(train_state.params, last_obs) - action = policy.sample(seed=action_key) - log_prob = policy.log_prob(action) - - rng, step_key = jax.random.split(rng) - step_keys = jax.random.split(step_key, cfg["num_envs"]) - nxt_obs, nxt_state, reward, done, info = jax.vmap( - env.step, - in_axes=(0, 0, 0), - )(step_keys, env_state, action) - - rng, reset_key = jax.random.split(rng) - reset_keys = jax.random.split(reset_key, cfg["num_envs"]) - rst_obs, rst_state = jax.vmap(env.reset)(reset_keys) - obs_next = jnp.where(done[:, None], rst_obs, nxt_obs) - env_next = jax.tree_util.tree_map( - lambda keep, reset: _select_env_state(done, keep, reset), - nxt_state, - rst_state, - ) - transition = Transition( - done=done, - action=action, - value=value, - reward=reward, - log_prob=log_prob, - obs=last_obs, - info=info, - ) - return (train_state, env_next, obs_next, rng), transition - - runner_state, traj_batch = jax.lax.scan( - _env_step, - runner_state, - None, - length=cfg["num_steps"], - ) - + def _update_step(runner_state, _): + def _env_step(runner_state, _): train_state, env_state, last_obs, rng = runner_state - _, last_value = network.apply(train_state.params, last_obs) + rng, action_key = jax.random.split(rng) + policy, value = network.apply(train_state.params, last_obs) + action = policy.sample(seed=action_key) + log_prob = policy.log_prob(action) - def _compute_gae(traj_batch, last_value): - def _gae_step(carry, transition): - gae, next_value = carry - delta = ( - transition.reward - + cfg["gamma"] * next_value * (1.0 - transition.done) - - transition.value - ) - gae = ( - delta - + cfg["gamma"] - * cfg["gae_lambda"] - * (1.0 - transition.done) - * gae - ) - return (gae, transition.value), gae + rng, step_key = jax.random.split(rng) + step_keys = jax.random.split(step_key, cfg["num_envs"]) + nxt_obs, nxt_state, reward, done, info = jax.vmap( + env.step, + in_axes=(0, 0, 0), + )(step_keys, env_state, action) - _, advantages = jax.lax.scan( - _gae_step, - (jnp.zeros_like(last_value), last_value), - traj_batch, - reverse=True, - unroll=16, - ) - targets = advantages + traj_batch.value - return advantages, targets - - advantages, targets = _compute_gae(traj_batch, last_value) - - def _update_epoch(update_state, _): - def _update_minibatch(train_state, batch_info): - traj_b, adv_b, tgt_b = batch_info - - def _loss_fn(params, traj_b, adv_b, tgt_b): - policy, value = network.apply(params, traj_b.obs) - log_prob = policy.log_prob(traj_b.action) - - value_clipped = traj_b.value + (value - traj_b.value).clip( - -cfg["clip_range"], cfg["clip_range"] - ) - value_loss = ( - 0.5 - * jnp.maximum( - jnp.square(value - tgt_b), - jnp.square(value_clipped - tgt_b), - ).mean() - ) - - adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8) - ratio = jnp.exp(log_prob - traj_b.log_prob) - loss_actor = -jnp.minimum( - ratio * adv_norm, - jnp.clip( - ratio, - 1.0 - cfg["clip_range"], - 1.0 + cfg["clip_range"], - ) - * adv_norm, - ).mean() - entropy = policy.entropy().mean() - total_loss = ( - loss_actor - + cfg["vf_coef"] * value_loss - - cfg["ent_coef"] * entropy - ) - return total_loss, (value_loss, loss_actor, entropy) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b) - train_state = train_state.apply_gradients(grads=grads) - return train_state, jnp.asarray(0.0, dtype=jnp.float32) - - train_state, traj_batch, advantages, targets, rng = update_state - rng, perm_key = jax.random.split(rng) - batch_size = cfg["num_envs"] * cfg["num_steps"] - permutation = jax.random.permutation(perm_key, batch_size) - batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), - batch, - ) - shuffled = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), - batch, - ) - minibatches = jax.tree_util.tree_map( - lambda x: x.reshape( - (cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:] - ), - shuffled, - ) - train_state, _ = jax.lax.scan( - _update_minibatch, train_state, minibatches - ) - return (train_state, traj_batch, advantages, targets, rng), None - - update_state = (train_state, traj_batch, advantages, targets, rng) - update_state, _ = jax.lax.scan( - _update_epoch, - update_state, - None, - length=cfg["update_epochs"], + rng, reset_key = jax.random.split(rng) + reset_keys = jax.random.split(reset_key, cfg["num_envs"]) + rst_obs, rst_state = jax.vmap(env.reset)(reset_keys) + obs_next = jnp.where(done[:, None], rst_obs, nxt_obs) + env_next = jax.tree_util.tree_map( + lambda keep, reset: _select_env_state(done, keep, reset), + nxt_state, + rst_state, ) - train_state = update_state[0] - rng = update_state[-1] + transition = Transition( + done=done, + action=action, + value=value, + reward=reward, + log_prob=log_prob, + obs=last_obs, + info=info, + ) + return (train_state, env_next, obs_next, rng), transition - metric = { - "reward": jnp.mean(traj_batch.reward), - "revenue": jnp.mean(traj_batch.info["revenue"]), - "agent_prob": jnp.mean(traj_batch.info["agent_prob"]), - "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]), - "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]), - } - runner_state = (train_state, env_state, last_obs, rng) - return runner_state, metric + runner_state, traj_batch = jax.lax.scan( + _env_step, + runner_state, + None, + length=cfg["num_steps"], + ) - runner_state = (train_state, env_state, obs, rng) + train_state, env_state, last_obs, rng = runner_state + _, last_value = network.apply(train_state.params, last_obs) + + def _compute_gae(traj_batch, last_value): + def _gae_step(carry, transition): + gae, next_value = carry + delta = ( + transition.reward + + cfg["gamma"] * next_value * (1.0 - transition.done) + - transition.value + ) + gae = ( + delta + + cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae + ) + return (gae, transition.value), gae + + _, advantages = jax.lax.scan( + _gae_step, + (jnp.zeros_like(last_value), last_value), + traj_batch, + reverse=True, + unroll=16, + ) + targets = advantages + traj_batch.value + return advantages, targets + + advantages, targets = _compute_gae(traj_batch, last_value) + + def _update_epoch(update_state, _): + def _update_minibatch(train_state, batch_info): + traj_b, adv_b, tgt_b = batch_info + + def _loss_fn(params, traj_b, adv_b, tgt_b): + policy, value = network.apply(params, traj_b.obs) + log_prob = policy.log_prob(traj_b.action) + + value_clipped = traj_b.value + (value - traj_b.value).clip( + -cfg["clip_range"], cfg["clip_range"] + ) + value_loss = ( + 0.5 + * jnp.maximum( + jnp.square(value - tgt_b), + jnp.square(value_clipped - tgt_b), + ).mean() + ) + + adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8) + ratio = jnp.exp(log_prob - traj_b.log_prob) + loss_actor = -jnp.minimum( + ratio * adv_norm, + jnp.clip( + ratio, + 1.0 - cfg["clip_range"], + 1.0 + cfg["clip_range"], + ) + * adv_norm, + ).mean() + entropy = policy.entropy().mean() + total_loss = ( + loss_actor + + cfg["vf_coef"] * value_loss + - cfg["ent_coef"] * entropy + ) + return total_loss, (value_loss, loss_actor, entropy) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b) + train_state = train_state.apply_gradients(grads=grads) + return train_state, jnp.asarray(0.0, dtype=jnp.float32) + + train_state, traj_batch, advantages, targets, rng = update_state + rng, perm_key = jax.random.split(rng) + batch_size = cfg["num_envs"] * cfg["num_steps"] + permutation = jax.random.permutation(perm_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), + batch, + ) + shuffled = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), + batch, + ) + minibatches = jax.tree_util.tree_map( + lambda x: x.reshape( + (cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:] + ), + shuffled, + ) + train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches) + return (train_state, traj_batch, advantages, targets, rng), None + + update_state = (train_state, traj_batch, advantages, targets, rng) + update_state, _ = jax.lax.scan( + _update_epoch, + update_state, + None, + length=cfg["update_epochs"], + ) + train_state = update_state[0] + rng = update_state[-1] + + metric = { + "reward": jnp.mean(traj_batch.reward), + "revenue": jnp.mean(traj_batch.info["revenue"]), + "agent_prob": jnp.mean(traj_batch.info["agent_prob"]), + "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]), + "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]), + } + next_runner_state = (train_state, env_state, last_obs, rng) + return next_runner_state, metric + + def run_updates(runner_state, *, num_updates: int): + updates = max(1, int(num_updates)) runner_state, metric = jax.lax.scan( _update_step, runner_state, None, - length=cfg["num_updates"], + length=updates, ) return { "runner_state": runner_state, "metrics": metric, } - return train, network, env, cfg + return init_runner_state, run_updates, network, env, cfg def evaluate_policy( @@ -436,22 +447,103 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'" ) - train_fn, network, env, run_cfg = make_train(run_cfg) - train_jit = jax.jit(train_fn) - rng = jax.random.PRNGKey(run_cfg["seed"]) - out = train_jit(rng) + init_runner_state, run_updates, network, env, run_cfg = make_train(run_cfg) + run_updates_jit = jax.jit(run_updates, static_argnames=("num_updates",)) + rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"]) + total_updates = int(run_cfg["num_updates"]) + checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000))) + segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1)) - train_state = out["runner_state"][0] - metric = out["metrics"] + rng = jax.random.PRNGKey(run_cfg["seed"]) + runner_state = init_runner_state(rng) + updates_done = 0 + + artifact_name = None + if HAS_WANDB and wandb.run is not None: + sweep_id = getattr(wandb.run, "sweep_id", None) + artifact_name = checkpoint_artifact_name( + run_cfg, + backend="jax", + sweep_id=sweep_id, + ) + restored = download_latest_checkpoint( + artifact_name, + file_name="jax_runner_state.msgpack", + ) + if restored is not None: + checkpoint_path, metadata = restored + template = { + "runner_state": runner_state, + "updates_done": 0, + } + payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) + runner_state = payload["runner_state"] + updates_done = int(payload.get("updates_done", 0)) + if updates_done <= 0: + updates_done = int(metadata.get("updates_done", 0)) + updates_done = max(0, min(updates_done, total_updates)) + + metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"] + metric_sums = {k: 0.0 for k in metric_keys} + metric_count = 0 + + while updates_done < total_updates: + updates_this_segment = min(segment_updates, total_updates - updates_done) + out = run_updates_jit(runner_state, num_updates=updates_this_segment) + runner_state = out["runner_state"] + metric = out["metrics"] + + segment_values = { + k: np.asarray(metric[k], dtype=np.float64) for k in metric_keys + } + segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0 + metric_count += segment_count + for key in metric_keys: + metric_sums[key] += float(segment_values[key].sum()) + + updates_done += int(updates_this_segment) + global_step = int(updates_done * rollout_steps) + + if HAS_WANDB and wandb.run is not None: + wandb.log( + { + "train/reward": float(segment_values["reward"].mean()), + "train/revenue": float(segment_values["revenue"].mean()), + "train/agent_prob": float(segment_values["agent_prob"].mean()), + "train/alpha_adv": float(segment_values["alpha_adv"].mean()), + "train/coi_leakage": float(segment_values["coi_leakage"].mean()), + "train/global_step": global_step, + }, + step=global_step, + ) + if artifact_name is not None: + checkpoint_payload = serialization.to_bytes( + { + "runner_state": runner_state, + "updates_done": updates_done, + } + ) + log_checkpoint_bytes( + artifact_name, + file_name="jax_runner_state.msgpack", + payload=checkpoint_payload, + metadata={ + "step": global_step, + "updates_done": updates_done, + "rollout_steps": rollout_steps, + "algo": "ppo", + }, + ) + + train_state = runner_state[0] + denom = float(metric_count) if metric_count > 0 else 1.0 metrics = { - "train/reward": float(np.mean(np.asarray(metric["reward"]))), - "train/revenue": float(np.mean(np.asarray(metric["revenue"]))), - "train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))), - "train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))), - "train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))), - "train/global_step": int( - run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"] - ), + "train/reward": float(metric_sums["reward"] / denom), + "train/revenue": float(metric_sums["revenue"] / denom), + "train/agent_prob": float(metric_sums["agent_prob"] / denom), + "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), + "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), + "train/global_step": int(updates_done * rollout_steps), } eval_metrics = evaluate_policy( diff --git a/engine/lib/__init__.py b/engine/lib/__init__.py index 874db63..c2fafc9 100644 --- a/engine/lib/__init__.py +++ b/engine/lib/__init__.py @@ -2,7 +2,7 @@ from .demand import estimate_demand, estimate_weighted_demand, generate_demand_f from .behavior import sample_behavior, get_transition_models, trajectory_to_events from .render import DashboardRenderer, style_axis from .wrappers import EconomicMetricsWrapper -from .callbacks import MetricsCallback, EvalMetricsCallback +from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback from .providers import ( ProviderBenchmark, ProviderResult, diff --git a/engine/lib/callbacks.py b/engine/lib/callbacks.py index 9e16d4b..05e77a0 100644 --- a/engine/lib/callbacks.py +++ b/engine/lib/callbacks.py @@ -1,8 +1,12 @@ """Training callbacks for W&B/TensorBoard logging - reads from info dict.""" +from pathlib import Path + from stable_baselines3.common.callbacks import BaseCallback, EvalCallback import numpy as np +from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file + try: import wandb @@ -80,6 +84,65 @@ class MetricsCallback(BaseCallback): self._episode_revenues = [] +class CheckpointArtifactCallback(BaseCallback): + """Periodic SB3 checkpoint uploader backed by W&B artifacts.""" + + def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0): + super().__init__(verbose) + self.cfg = dict(cfg) + self.interval = max(1, int(interval)) + self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models"))) + self.model_dir.mkdir(parents=True, exist_ok=True) + self._next_checkpoint = self.interval + self._last_saved_step = -1 + + def _artifact_name(self) -> str: + sweep_id = ( + getattr(wandb.run, "sweep_id", None) + if HAS_WANDB and wandb.run is not None + else None + ) + return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id) + + def _checkpoint_file(self) -> Path: + algo = str(self.cfg.get("algo", "model")) + base = self.model_dir / f"phantom_{algo}_checkpoint" + self.model.save(str(base)) + return base.with_suffix(".zip") + + def _save_checkpoint(self) -> None: + if not HAS_WANDB or wandb.run is None: + return + step = int(self.num_timesteps) + if step <= self._last_saved_step: + return + checkpoint_path = self._checkpoint_file() + metadata = { + "step": step, + "algo": str(self.cfg.get("algo", "unknown")), + "sweep_id": getattr(wandb.run, "sweep_id", None), + } + saved = log_checkpoint_file( + self._artifact_name(), + file_path=checkpoint_path, + artifact_file_name=checkpoint_path.name, + metadata=metadata, + ) + if saved: + self._last_saved_step = step + + def _on_step(self) -> bool: + if self.num_timesteps < self._next_checkpoint: + return True + self._save_checkpoint() + while self._next_checkpoint <= self.num_timesteps: + self._next_checkpoint += self.interval + return True + + def _on_training_end(self) -> None: + self._save_checkpoint() + + class EvalMetricsCallback(EvalCallback): """Deterministic evaluation - true performance without exploration noise.""" diff --git a/engine/train.py b/engine/train.py index 8e4eb07..35ca582 100644 --- a/engine/train.py +++ b/engine/train.py @@ -6,6 +6,8 @@ import os from pathlib import Path import numpy as np +from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint + try: import wandb @@ -78,6 +80,7 @@ DEFAULT_CFG = { "jax_num_minibatches": 4, "jax_update_epochs": 4, "jax_anneal_lr": True, + "checkpoint_interval": 10_000, } @@ -262,6 +265,16 @@ def build_model(cfg: dict, env): raise ValueError(f"unsupported algo '{algo}'") +def _sb3_model_cls(algo: str): + if algo == "ppo": + return PPO + if algo == "a2c": + return A2C + if algo == "dqn": + return DQN + raise ValueError(f"unsupported algo '{algo}'") + + def train_qtable(cfg: dict) -> tuple[EventQTable, dict]: from .lib.discrete import EventQTable @@ -305,14 +318,36 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]: def train_sb3(cfg: dict) -> tuple[object, dict]: if not HAS_SB3: raise ImportError("stable-baselines3 is required for SB3 models") - from .lib.callbacks import MetricsCallback + from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback env = make_env(cfg) eval_env = make_env(cfg) env = Monitor(env) eval_env = Monitor(eval_env) model = build_model(cfg, env) + resume_step = 0 + if HAS_WANDB and wandb.run is not None: + sweep_id = getattr(wandb.run, "sweep_id", None) + artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id) + checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip" + restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file) + if restored is not None: + checkpoint_path, metadata = restored + model = _sb3_model_cls(cfg["algo"]).load( + checkpoint_path.as_posix(), env=env + ) + resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0))) + model.num_timesteps = max( + int(getattr(model, "num_timesteps", 0)), resume_step + ) + cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))] + cbs.append( + CheckpointArtifactCallback( + cfg, + interval=int(cfg.get("checkpoint_interval", 10_000)), + ) + ) cbs.append( EvalCallback( eval_env, @@ -322,7 +357,15 @@ def train_sb3(cfg: dict) -> tuple[object, dict]: verbose=0, ) ) - model.learn(total_timesteps=int(cfg["total_timesteps"]), callback=cbs) + target_steps = int(cfg["total_timesteps"]) + remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0))) + if remaining_steps > 0: + model.learn( + total_timesteps=remaining_steps, + callback=cbs, + reset_num_timesteps=False, + ) + model_path = Path(cfg["model_dir"]) model_path.mkdir(parents=True, exist_ok=True) model.save(str(model_path / f"phantom_{cfg['algo']}")) @@ -413,6 +456,7 @@ def main(): p.add_argument("--jax-num-minibatches", type=int) p.add_argument("--jax-update-epochs", type=int) p.add_argument("--jax-anneal-lr", type=str) + p.add_argument("--checkpoint-interval", type=int) p.add_argument("--sweep-agent", action="store_true") p.add_argument("--sweep-id", type=str) p.add_argument("--count", type=int, default=0) @@ -441,6 +485,7 @@ def main(): "jax_num_steps": args.jax_num_steps, "jax_num_minibatches": args.jax_num_minibatches, "jax_update_epochs": args.jax_update_epochs, + "checkpoint_interval": args.checkpoint_interval, "jax_anneal_lr": _truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None,