mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
fixing models for gcp
This commit is contained in:
116
Makefile
116
Makefile
@@ -13,8 +13,8 @@ TPU_ZONE ?= us-central2-b
|
|||||||
TPU_TYPE ?= v4-32
|
TPU_TYPE ?= v4-32
|
||||||
TPU_RUNTIME ?= tpu-vm-v4-base
|
TPU_RUNTIME ?= tpu-vm-v4-base
|
||||||
TPU_PROJECT ?= phantom-trc
|
TPU_PROJECT ?= phantom-trc
|
||||||
TPU_NETWORK ?= default
|
TPU_NETWORK ?= tpu-network
|
||||||
TPU_SUBNETWORK ?= default-us-central2
|
TPU_SUBNETWORK ?= tpu-network
|
||||||
TPU_USE_SPOT ?= 0
|
TPU_USE_SPOT ?= 0
|
||||||
TPU_EXTRA_CREATE_FLAGS ?=
|
TPU_EXTRA_CREATE_FLAGS ?=
|
||||||
TPU_WORKDIR ?= ~/PHANTOM
|
TPU_WORKDIR ?= ~/PHANTOM
|
||||||
@@ -22,7 +22,29 @@ TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env
|
|||||||
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
|
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
|
||||||
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||||
TPU_VENV ?= .venv-tpu
|
TPU_VENV ?= .venv-tpu
|
||||||
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=offline
|
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=online
|
||||||
|
SWEEP_ID ?=
|
||||||
|
SWEEP_COUNT ?= 5
|
||||||
|
QUEUE_SCRIPT ?= scripts/queue_sweep.sh
|
||||||
|
TPU_QUEUE_TYPE ?=
|
||||||
|
TPU_QUEUE_ZONES ?= europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b
|
||||||
|
TPU_QUEUE_REUSE_EXISTING ?= 1
|
||||||
|
TPU_QUEUE_KEEP_ALIVE ?= 1
|
||||||
|
TPU_QUEUE_STRICT_QUOTA ?= 0
|
||||||
|
TPU_QUEUE_DOWNSHIFT_ON_QUOTA ?= 1
|
||||||
|
TPU_QUEUE_FILTER_ZONE ?=
|
||||||
|
TPU_QUEUE_FILTER_TYPE ?=
|
||||||
|
TPU_QUEUE_EXECUTION_MODE ?= venv
|
||||||
|
TPU_QUEUE_SYNC_METHOD ?= tar
|
||||||
|
TPU_QUEUE_SKIP_SYNC ?= 0
|
||||||
|
TPU_QUEUE_DOCKER_IMAGE ?=
|
||||||
|
TPU_QUEUE_DOCKER_PULL ?= 1
|
||||||
|
TPU_QUEUE_DOCKER_AUTO_INSTALL ?= 1
|
||||||
|
TPU_QUEUE_SSH_BATCH_MODE ?= 1
|
||||||
|
TPU_QUEUE_SSH_CONNECT_TIMEOUT ?= 12
|
||||||
|
TPU_QUEUE_SSH_KEY_FILE ?= $(HOME)/.ssh/google_compute_engine
|
||||||
|
TPU_QUEUE_REQUIRE_SSH_AGENT ?= 1
|
||||||
|
TPU_QUEUE_AUTO_SSH_ADD ?= 1
|
||||||
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)
|
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)
|
||||||
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
|
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
|
||||||
|
|
||||||
@@ -30,8 +52,13 @@ TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$
|
|||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help:
|
help:
|
||||||
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.*"
|
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.* | tpu.queue.*"
|
||||||
@echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot"
|
@echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot"
|
||||||
|
@echo "Queued sweep: SWEEP_ID=entity/project/id make tpu.queue.sweep"
|
||||||
|
@echo "Queued sweep filters: TPU_QUEUE_FILTER_TYPE=v6e TPU_QUEUE_FILTER_ZONE=europe-west4-a"
|
||||||
|
@echo "Docker queue: make tpu.queue.sweep.docker TPU_QUEUE_DOCKER_IMAGE=gcr.io/<project>/<image>:tag"
|
||||||
|
@echo "Docker queue without sync: add TPU_QUEUE_SKIP_SYNC=1"
|
||||||
|
@echo "If SSH key is encrypted: run ssh-add ~/.ssh/google_compute_engine first"
|
||||||
|
|
||||||
$(BUILDDIR):
|
$(BUILDDIR):
|
||||||
mkdir -p paper/$(BUILDDIR)
|
mkdir -p paper/$(BUILDDIR)
|
||||||
@@ -104,11 +131,11 @@ tpu.check.zone:
|
|||||||
|
|
||||||
.PHONY: tpu.create.v4.ondemand
|
.PHONY: tpu.create.v4.ondemand
|
||||||
tpu.create.v4.ondemand:
|
tpu.create.v4.ondemand:
|
||||||
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=default-us-central2
|
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=tpu-network
|
||||||
|
|
||||||
.PHONY: tpu.create.v4.spot
|
.PHONY: tpu.create.v4.spot
|
||||||
tpu.create.v4.spot:
|
tpu.create.v4.spot:
|
||||||
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=default-us-central2
|
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=tpu-network
|
||||||
|
|
||||||
.PHONY: tpu.create
|
.PHONY: tpu.create
|
||||||
tpu.create: tpu.check.zone
|
tpu.create: tpu.check.zone
|
||||||
@@ -179,6 +206,83 @@ tpu.bootstrap: tpu.ensure tpu.deploy tpu.install
|
|||||||
tpu.delete:
|
tpu.delete:
|
||||||
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet
|
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.sweep
|
||||||
|
tpu.queue.sweep:
|
||||||
|
@set -e; \
|
||||||
|
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
|
||||||
|
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
|
||||||
|
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
|
||||||
|
if ! ssh-add -l >/dev/null 2>&1; then \
|
||||||
|
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
|
||||||
|
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
|
||||||
|
fi; \
|
||||||
|
fi; \
|
||||||
|
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_QUEUE_FILTER_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_FILTER_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.worker
|
||||||
|
tpu.queue.worker:
|
||||||
|
@set -e; \
|
||||||
|
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
|
||||||
|
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
|
||||||
|
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
|
||||||
|
if ! ssh-add -l >/dev/null 2>&1; then \
|
||||||
|
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
|
||||||
|
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
|
||||||
|
fi; \
|
||||||
|
fi; \
|
||||||
|
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.sweep.docker
|
||||||
|
tpu.queue.sweep.docker:
|
||||||
|
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
||||||
|
@$(MAKE) tpu.queue.sweep TPU_QUEUE_EXECUTION_MODE=docker
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.worker.docker
|
||||||
|
tpu.queue.worker.docker:
|
||||||
|
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
||||||
|
@$(MAKE) tpu.queue.worker TPU_QUEUE_EXECUTION_MODE=docker
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.docker.build
|
||||||
|
tpu.queue.docker.build:
|
||||||
|
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
||||||
|
docker build -f docker/TPUSweep.Dockerfile -t "$(TPU_QUEUE_DOCKER_IMAGE)" .
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.docker.push
|
||||||
|
tpu.queue.docker.push:
|
||||||
|
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
|
||||||
|
docker push "$(TPU_QUEUE_DOCKER_IMAGE)"
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.status
|
||||||
|
tpu.queue.status:
|
||||||
|
@set -e; \
|
||||||
|
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
|
||||||
|
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
|
||||||
|
else \
|
||||||
|
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
|
||||||
|
fi; \
|
||||||
|
for ZONE in $(TPU_QUEUE_ZONES); do \
|
||||||
|
echo "--- $$ZONE ---"; \
|
||||||
|
if ! $$QCMD list --zone="$$ZONE"; then \
|
||||||
|
echo "Skipping $$ZONE (unavailable or no permission)"; \
|
||||||
|
fi; \
|
||||||
|
done
|
||||||
|
|
||||||
|
.PHONY: tpu.queue.clean
|
||||||
|
tpu.queue.clean:
|
||||||
|
@set -e; \
|
||||||
|
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
|
||||||
|
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
|
||||||
|
else \
|
||||||
|
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
|
||||||
|
fi; \
|
||||||
|
for ZONE in $(TPU_QUEUE_ZONES); do \
|
||||||
|
$$QCMD list --zone="$$ZONE" --format='value(name)' 2>/dev/null | while read -r NAME; do \
|
||||||
|
case "$$NAME" in \
|
||||||
|
qr-*) echo "Deleting $$NAME ($$ZONE)"; $$QCMD delete "$$NAME" --zone="$$ZONE" --quiet ;; \
|
||||||
|
esac; \
|
||||||
|
done; \
|
||||||
|
done
|
||||||
|
|
||||||
.PHONY: stats.lines
|
.PHONY: stats.lines
|
||||||
stats.lines:
|
stats.lines:
|
||||||
@find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \
|
@find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \
|
||||||
|
|||||||
@@ -7,6 +7,19 @@ from typing import Any, NamedTuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
HAS_WANDB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_WANDB = False
|
||||||
|
|
||||||
|
from ..wandb_checkpoint import (
|
||||||
|
checkpoint_artifact_name,
|
||||||
|
download_latest_checkpoint,
|
||||||
|
log_checkpoint_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -142,6 +155,7 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
|||||||
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
||||||
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
||||||
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
||||||
|
"checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)),
|
||||||
}
|
}
|
||||||
rollout = out["num_envs"] * out["num_steps"]
|
rollout = out["num_envs"] * out["num_steps"]
|
||||||
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
||||||
@@ -185,11 +199,6 @@ def make_train(config: dict[str, Any]):
|
|||||||
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
|
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
|
||||||
return cfg["learning_rate"] * frac
|
return cfg["learning_rate"] * frac
|
||||||
|
|
||||||
def train(rng: jax.Array):
|
|
||||||
rng, init_key = jax.random.split(rng)
|
|
||||||
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
|
|
||||||
params = network.init(init_key, init_obs)
|
|
||||||
|
|
||||||
if cfg["anneal_lr"]:
|
if cfg["anneal_lr"]:
|
||||||
tx = optax.chain(
|
tx = optax.chain(
|
||||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||||
@@ -200,11 +209,17 @@ def make_train(config: dict[str, Any]):
|
|||||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||||
optax.adam(cfg["learning_rate"], eps=1e-5),
|
optax.adam(cfg["learning_rate"], eps=1e-5),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def init_runner_state(rng: jax.Array):
|
||||||
|
rng, init_key = jax.random.split(rng)
|
||||||
|
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
|
||||||
|
params = network.init(init_key, init_obs)
|
||||||
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
|
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
|
||||||
|
|
||||||
rng, reset_key = jax.random.split(rng)
|
rng, reset_key = jax.random.split(rng)
|
||||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||||
obs, env_state = jax.vmap(env.reset)(reset_keys)
|
obs, env_state = jax.vmap(env.reset)(reset_keys)
|
||||||
|
return train_state, env_state, obs, rng
|
||||||
|
|
||||||
def _update_step(runner_state, _):
|
def _update_step(runner_state, _):
|
||||||
def _env_step(runner_state, _):
|
def _env_step(runner_state, _):
|
||||||
@@ -261,10 +276,7 @@ def make_train(config: dict[str, Any]):
|
|||||||
)
|
)
|
||||||
gae = (
|
gae = (
|
||||||
delta
|
delta
|
||||||
+ cfg["gamma"]
|
+ cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae
|
||||||
* cfg["gae_lambda"]
|
|
||||||
* (1.0 - transition.done)
|
|
||||||
* gae
|
|
||||||
)
|
)
|
||||||
return (gae, transition.value), gae
|
return (gae, transition.value), gae
|
||||||
|
|
||||||
@@ -342,9 +354,7 @@ def make_train(config: dict[str, Any]):
|
|||||||
),
|
),
|
||||||
shuffled,
|
shuffled,
|
||||||
)
|
)
|
||||||
train_state, _ = jax.lax.scan(
|
train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches)
|
||||||
_update_minibatch, train_state, minibatches
|
|
||||||
)
|
|
||||||
return (train_state, traj_batch, advantages, targets, rng), None
|
return (train_state, traj_batch, advantages, targets, rng), None
|
||||||
|
|
||||||
update_state = (train_state, traj_batch, advantages, targets, rng)
|
update_state = (train_state, traj_batch, advantages, targets, rng)
|
||||||
@@ -364,22 +374,23 @@ def make_train(config: dict[str, Any]):
|
|||||||
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
|
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
|
||||||
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
|
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
|
||||||
}
|
}
|
||||||
runner_state = (train_state, env_state, last_obs, rng)
|
next_runner_state = (train_state, env_state, last_obs, rng)
|
||||||
return runner_state, metric
|
return next_runner_state, metric
|
||||||
|
|
||||||
runner_state = (train_state, env_state, obs, rng)
|
def run_updates(runner_state, *, num_updates: int):
|
||||||
|
updates = max(1, int(num_updates))
|
||||||
runner_state, metric = jax.lax.scan(
|
runner_state, metric = jax.lax.scan(
|
||||||
_update_step,
|
_update_step,
|
||||||
runner_state,
|
runner_state,
|
||||||
None,
|
None,
|
||||||
length=cfg["num_updates"],
|
length=updates,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"runner_state": runner_state,
|
"runner_state": runner_state,
|
||||||
"metrics": metric,
|
"metrics": metric,
|
||||||
}
|
}
|
||||||
|
|
||||||
return train, network, env, cfg
|
return init_runner_state, run_updates, network, env, cfg
|
||||||
|
|
||||||
|
|
||||||
def evaluate_policy(
|
def evaluate_policy(
|
||||||
@@ -436,22 +447,103 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
|
|||||||
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
|
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_fn, network, env, run_cfg = make_train(run_cfg)
|
init_runner_state, run_updates, network, env, run_cfg = make_train(run_cfg)
|
||||||
train_jit = jax.jit(train_fn)
|
run_updates_jit = jax.jit(run_updates, static_argnames=("num_updates",))
|
||||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
|
||||||
out = train_jit(rng)
|
total_updates = int(run_cfg["num_updates"])
|
||||||
|
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
|
||||||
|
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
|
||||||
|
|
||||||
train_state = out["runner_state"][0]
|
rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||||
|
runner_state = init_runner_state(rng)
|
||||||
|
updates_done = 0
|
||||||
|
|
||||||
|
artifact_name = None
|
||||||
|
if HAS_WANDB and wandb.run is not None:
|
||||||
|
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||||
|
artifact_name = checkpoint_artifact_name(
|
||||||
|
run_cfg,
|
||||||
|
backend="jax",
|
||||||
|
sweep_id=sweep_id,
|
||||||
|
)
|
||||||
|
restored = download_latest_checkpoint(
|
||||||
|
artifact_name,
|
||||||
|
file_name="jax_runner_state.msgpack",
|
||||||
|
)
|
||||||
|
if restored is not None:
|
||||||
|
checkpoint_path, metadata = restored
|
||||||
|
template = {
|
||||||
|
"runner_state": runner_state,
|
||||||
|
"updates_done": 0,
|
||||||
|
}
|
||||||
|
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
|
||||||
|
runner_state = payload["runner_state"]
|
||||||
|
updates_done = int(payload.get("updates_done", 0))
|
||||||
|
if updates_done <= 0:
|
||||||
|
updates_done = int(metadata.get("updates_done", 0))
|
||||||
|
updates_done = max(0, min(updates_done, total_updates))
|
||||||
|
|
||||||
|
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
|
||||||
|
metric_sums = {k: 0.0 for k in metric_keys}
|
||||||
|
metric_count = 0
|
||||||
|
|
||||||
|
while updates_done < total_updates:
|
||||||
|
updates_this_segment = min(segment_updates, total_updates - updates_done)
|
||||||
|
out = run_updates_jit(runner_state, num_updates=updates_this_segment)
|
||||||
|
runner_state = out["runner_state"]
|
||||||
metric = out["metrics"]
|
metric = out["metrics"]
|
||||||
|
|
||||||
|
segment_values = {
|
||||||
|
k: np.asarray(metric[k], dtype=np.float64) for k in metric_keys
|
||||||
|
}
|
||||||
|
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
|
||||||
|
metric_count += segment_count
|
||||||
|
for key in metric_keys:
|
||||||
|
metric_sums[key] += float(segment_values[key].sum())
|
||||||
|
|
||||||
|
updates_done += int(updates_this_segment)
|
||||||
|
global_step = int(updates_done * rollout_steps)
|
||||||
|
|
||||||
|
if HAS_WANDB and wandb.run is not None:
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"train/reward": float(segment_values["reward"].mean()),
|
||||||
|
"train/revenue": float(segment_values["revenue"].mean()),
|
||||||
|
"train/agent_prob": float(segment_values["agent_prob"].mean()),
|
||||||
|
"train/alpha_adv": float(segment_values["alpha_adv"].mean()),
|
||||||
|
"train/coi_leakage": float(segment_values["coi_leakage"].mean()),
|
||||||
|
"train/global_step": global_step,
|
||||||
|
},
|
||||||
|
step=global_step,
|
||||||
|
)
|
||||||
|
if artifact_name is not None:
|
||||||
|
checkpoint_payload = serialization.to_bytes(
|
||||||
|
{
|
||||||
|
"runner_state": runner_state,
|
||||||
|
"updates_done": updates_done,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
log_checkpoint_bytes(
|
||||||
|
artifact_name,
|
||||||
|
file_name="jax_runner_state.msgpack",
|
||||||
|
payload=checkpoint_payload,
|
||||||
|
metadata={
|
||||||
|
"step": global_step,
|
||||||
|
"updates_done": updates_done,
|
||||||
|
"rollout_steps": rollout_steps,
|
||||||
|
"algo": "ppo",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
train_state = runner_state[0]
|
||||||
|
denom = float(metric_count) if metric_count > 0 else 1.0
|
||||||
metrics = {
|
metrics = {
|
||||||
"train/reward": float(np.mean(np.asarray(metric["reward"]))),
|
"train/reward": float(metric_sums["reward"] / denom),
|
||||||
"train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
|
"train/revenue": float(metric_sums["revenue"] / denom),
|
||||||
"train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
|
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||||
"train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
|
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||||
"train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
|
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||||
"train/global_step": int(
|
"train/global_step": int(updates_done * rollout_steps),
|
||||||
run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_metrics = evaluate_policy(
|
eval_metrics = evaluate_policy(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from .demand import estimate_demand, estimate_weighted_demand, generate_demand_f
|
|||||||
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
|
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
|
||||||
from .render import DashboardRenderer, style_axis
|
from .render import DashboardRenderer, style_axis
|
||||||
from .wrappers import EconomicMetricsWrapper
|
from .wrappers import EconomicMetricsWrapper
|
||||||
from .callbacks import MetricsCallback, EvalMetricsCallback
|
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
|
||||||
from .providers import (
|
from .providers import (
|
||||||
ProviderBenchmark,
|
ProviderBenchmark,
|
||||||
ProviderResult,
|
ProviderResult,
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
|
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
@@ -80,6 +84,65 @@ class MetricsCallback(BaseCallback):
|
|||||||
self._episode_revenues = []
|
self._episode_revenues = []
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointArtifactCallback(BaseCallback):
|
||||||
|
"""Periodic SB3 checkpoint uploader backed by W&B artifacts."""
|
||||||
|
|
||||||
|
def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
|
||||||
|
super().__init__(verbose)
|
||||||
|
self.cfg = dict(cfg)
|
||||||
|
self.interval = max(1, int(interval))
|
||||||
|
self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
|
||||||
|
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._next_checkpoint = self.interval
|
||||||
|
self._last_saved_step = -1
|
||||||
|
|
||||||
|
def _artifact_name(self) -> str:
|
||||||
|
sweep_id = (
|
||||||
|
getattr(wandb.run, "sweep_id", None)
|
||||||
|
if HAS_WANDB and wandb.run is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)
|
||||||
|
|
||||||
|
def _checkpoint_file(self) -> Path:
|
||||||
|
algo = str(self.cfg.get("algo", "model"))
|
||||||
|
base = self.model_dir / f"phantom_{algo}_checkpoint"
|
||||||
|
self.model.save(str(base))
|
||||||
|
return base.with_suffix(".zip")
|
||||||
|
|
||||||
|
def _save_checkpoint(self) -> None:
|
||||||
|
if not HAS_WANDB or wandb.run is None:
|
||||||
|
return
|
||||||
|
step = int(self.num_timesteps)
|
||||||
|
if step <= self._last_saved_step:
|
||||||
|
return
|
||||||
|
checkpoint_path = self._checkpoint_file()
|
||||||
|
metadata = {
|
||||||
|
"step": step,
|
||||||
|
"algo": str(self.cfg.get("algo", "unknown")),
|
||||||
|
"sweep_id": getattr(wandb.run, "sweep_id", None),
|
||||||
|
}
|
||||||
|
saved = log_checkpoint_file(
|
||||||
|
self._artifact_name(),
|
||||||
|
file_path=checkpoint_path,
|
||||||
|
artifact_file_name=checkpoint_path.name,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
if saved:
|
||||||
|
self._last_saved_step = step
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
if self.num_timesteps < self._next_checkpoint:
|
||||||
|
return True
|
||||||
|
self._save_checkpoint()
|
||||||
|
while self._next_checkpoint <= self.num_timesteps:
|
||||||
|
self._next_checkpoint += self.interval
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _on_training_end(self) -> None:
|
||||||
|
self._save_checkpoint()
|
||||||
|
|
||||||
|
|
||||||
class EvalMetricsCallback(EvalCallback):
|
class EvalMetricsCallback(EvalCallback):
|
||||||
"""Deterministic evaluation - true performance without exploration noise."""
|
"""Deterministic evaluation - true performance without exploration noise."""
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
@@ -78,6 +80,7 @@ DEFAULT_CFG = {
|
|||||||
"jax_num_minibatches": 4,
|
"jax_num_minibatches": 4,
|
||||||
"jax_update_epochs": 4,
|
"jax_update_epochs": 4,
|
||||||
"jax_anneal_lr": True,
|
"jax_anneal_lr": True,
|
||||||
|
"checkpoint_interval": 10_000,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -262,6 +265,16 @@ def build_model(cfg: dict, env):
|
|||||||
raise ValueError(f"unsupported algo '{algo}'")
|
raise ValueError(f"unsupported algo '{algo}'")
|
||||||
|
|
||||||
|
|
||||||
|
def _sb3_model_cls(algo: str):
|
||||||
|
if algo == "ppo":
|
||||||
|
return PPO
|
||||||
|
if algo == "a2c":
|
||||||
|
return A2C
|
||||||
|
if algo == "dqn":
|
||||||
|
return DQN
|
||||||
|
raise ValueError(f"unsupported algo '{algo}'")
|
||||||
|
|
||||||
|
|
||||||
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||||
from .lib.discrete import EventQTable
|
from .lib.discrete import EventQTable
|
||||||
|
|
||||||
@@ -305,14 +318,36 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
|||||||
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||||
if not HAS_SB3:
|
if not HAS_SB3:
|
||||||
raise ImportError("stable-baselines3 is required for SB3 models")
|
raise ImportError("stable-baselines3 is required for SB3 models")
|
||||||
from .lib.callbacks import MetricsCallback
|
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
||||||
|
|
||||||
env = make_env(cfg)
|
env = make_env(cfg)
|
||||||
eval_env = make_env(cfg)
|
eval_env = make_env(cfg)
|
||||||
env = Monitor(env)
|
env = Monitor(env)
|
||||||
eval_env = Monitor(eval_env)
|
eval_env = Monitor(eval_env)
|
||||||
model = build_model(cfg, env)
|
model = build_model(cfg, env)
|
||||||
|
resume_step = 0
|
||||||
|
if HAS_WANDB and wandb.run is not None:
|
||||||
|
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||||
|
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
||||||
|
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
||||||
|
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
||||||
|
if restored is not None:
|
||||||
|
checkpoint_path, metadata = restored
|
||||||
|
model = _sb3_model_cls(cfg["algo"]).load(
|
||||||
|
checkpoint_path.as_posix(), env=env
|
||||||
|
)
|
||||||
|
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
|
||||||
|
model.num_timesteps = max(
|
||||||
|
int(getattr(model, "num_timesteps", 0)), resume_step
|
||||||
|
)
|
||||||
|
|
||||||
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
|
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
|
||||||
|
cbs.append(
|
||||||
|
CheckpointArtifactCallback(
|
||||||
|
cfg,
|
||||||
|
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
||||||
|
)
|
||||||
|
)
|
||||||
cbs.append(
|
cbs.append(
|
||||||
EvalCallback(
|
EvalCallback(
|
||||||
eval_env,
|
eval_env,
|
||||||
@@ -322,7 +357,15 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
|
|||||||
verbose=0,
|
verbose=0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
model.learn(total_timesteps=int(cfg["total_timesteps"]), callback=cbs)
|
target_steps = int(cfg["total_timesteps"])
|
||||||
|
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||||
|
if remaining_steps > 0:
|
||||||
|
model.learn(
|
||||||
|
total_timesteps=remaining_steps,
|
||||||
|
callback=cbs,
|
||||||
|
reset_num_timesteps=False,
|
||||||
|
)
|
||||||
|
|
||||||
model_path = Path(cfg["model_dir"])
|
model_path = Path(cfg["model_dir"])
|
||||||
model_path.mkdir(parents=True, exist_ok=True)
|
model_path.mkdir(parents=True, exist_ok=True)
|
||||||
model.save(str(model_path / f"phantom_{cfg['algo']}"))
|
model.save(str(model_path / f"phantom_{cfg['algo']}"))
|
||||||
@@ -413,6 +456,7 @@ def main():
|
|||||||
p.add_argument("--jax-num-minibatches", type=int)
|
p.add_argument("--jax-num-minibatches", type=int)
|
||||||
p.add_argument("--jax-update-epochs", type=int)
|
p.add_argument("--jax-update-epochs", type=int)
|
||||||
p.add_argument("--jax-anneal-lr", type=str)
|
p.add_argument("--jax-anneal-lr", type=str)
|
||||||
|
p.add_argument("--checkpoint-interval", type=int)
|
||||||
p.add_argument("--sweep-agent", action="store_true")
|
p.add_argument("--sweep-agent", action="store_true")
|
||||||
p.add_argument("--sweep-id", type=str)
|
p.add_argument("--sweep-id", type=str)
|
||||||
p.add_argument("--count", type=int, default=0)
|
p.add_argument("--count", type=int, default=0)
|
||||||
@@ -441,6 +485,7 @@ def main():
|
|||||||
"jax_num_steps": args.jax_num_steps,
|
"jax_num_steps": args.jax_num_steps,
|
||||||
"jax_num_minibatches": args.jax_num_minibatches,
|
"jax_num_minibatches": args.jax_num_minibatches,
|
||||||
"jax_update_epochs": args.jax_update_epochs,
|
"jax_update_epochs": args.jax_update_epochs,
|
||||||
|
"checkpoint_interval": args.checkpoint_interval,
|
||||||
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
||||||
if args.jax_anneal_lr is not None
|
if args.jax_anneal_lr is not None
|
||||||
else None,
|
else None,
|
||||||
|
|||||||
Reference in New Issue
Block a user