fixing models for gcp

This commit is contained in:
2026-02-17 16:54:55 +01:00
parent 802f31b4a1
commit 9acc998cc9
5 changed files with 497 additions and 193 deletions

116
Makefile
View File

@@ -13,8 +13,8 @@ TPU_ZONE ?= us-central2-b
TPU_TYPE ?= v4-32
TPU_RUNTIME ?= tpu-vm-v4-base
TPU_PROJECT ?= phantom-trc
TPU_NETWORK ?= default
TPU_SUBNETWORK ?= default-us-central2
TPU_NETWORK ?= tpu-network
TPU_SUBNETWORK ?= tpu-network
TPU_USE_SPOT ?= 0
TPU_EXTRA_CREATE_FLAGS ?=
TPU_WORKDIR ?= ~/PHANTOM
@@ -22,7 +22,29 @@ TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html
TPU_VENV ?= .venv-tpu
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=offline
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=online
SWEEP_ID ?=
SWEEP_COUNT ?= 5
QUEUE_SCRIPT ?= scripts/queue_sweep.sh
TPU_QUEUE_TYPE ?=
TPU_QUEUE_ZONES ?= europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b
TPU_QUEUE_REUSE_EXISTING ?= 1
TPU_QUEUE_KEEP_ALIVE ?= 1
TPU_QUEUE_STRICT_QUOTA ?= 0
TPU_QUEUE_DOWNSHIFT_ON_QUOTA ?= 1
TPU_QUEUE_FILTER_ZONE ?=
TPU_QUEUE_FILTER_TYPE ?=
TPU_QUEUE_EXECUTION_MODE ?= venv
TPU_QUEUE_SYNC_METHOD ?= tar
TPU_QUEUE_SKIP_SYNC ?= 0
TPU_QUEUE_DOCKER_IMAGE ?=
TPU_QUEUE_DOCKER_PULL ?= 1
TPU_QUEUE_DOCKER_AUTO_INSTALL ?= 1
TPU_QUEUE_SSH_BATCH_MODE ?= 1
TPU_QUEUE_SSH_CONNECT_TIMEOUT ?= 12
TPU_QUEUE_SSH_KEY_FILE ?= $(HOME)/.ssh/google_compute_engine
TPU_QUEUE_REQUIRE_SSH_AGENT ?= 1
TPU_QUEUE_AUTO_SSH_ADD ?= 1
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
@@ -30,8 +52,13 @@ TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$
.PHONY: help
help:
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.*"
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.* | tpu.queue.*"
@echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot"
@echo "Queued sweep: SWEEP_ID=entity/project/id make tpu.queue.sweep"
@echo "Queued sweep filters: TPU_QUEUE_FILTER_TYPE=v6e TPU_QUEUE_FILTER_ZONE=europe-west4-a"
@echo "Docker queue: make tpu.queue.sweep.docker TPU_QUEUE_DOCKER_IMAGE=gcr.io/<project>/<image>:tag"
@echo "Docker queue without sync: add TPU_QUEUE_SKIP_SYNC=1"
@echo "If SSH key is encrypted: run ssh-add ~/.ssh/google_compute_engine first"
$(BUILDDIR):
mkdir -p paper/$(BUILDDIR)
@@ -104,11 +131,11 @@ tpu.check.zone:
.PHONY: tpu.create.v4.ondemand
tpu.create.v4.ondemand:
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=default-us-central2
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=tpu-network
.PHONY: tpu.create.v4.spot
tpu.create.v4.spot:
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=default-us-central2
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=tpu-network
.PHONY: tpu.create
tpu.create: tpu.check.zone
@@ -179,6 +206,83 @@ tpu.bootstrap: tpu.ensure tpu.deploy tpu.install
tpu.delete:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet
.PHONY: tpu.queue.sweep
tpu.queue.sweep:
@set -e; \
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
if ! ssh-add -l >/dev/null 2>&1; then \
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
fi; \
fi; \
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_QUEUE_FILTER_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_FILTER_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
.PHONY: tpu.queue.worker
tpu.queue.worker:
@set -e; \
test -n "$(SWEEP_ID)" || (echo "SWEEP_ID is required, e.g. SWEEP_ID=entity/project/id" && exit 1); \
test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY is required in your shell" && exit 1); \
if [ "$(TPU_QUEUE_AUTO_SSH_ADD)" = "1" ] && [ "$(TPU_QUEUE_SSH_BATCH_MODE)" != "0" ] && command -v ssh-add >/dev/null 2>&1 && [ -f "$(TPU_QUEUE_SSH_KEY_FILE)" ]; then \
if ! ssh-add -l >/dev/null 2>&1; then \
if [ -z "$$SSH_AUTH_SOCK" ] && command -v ssh-agent >/dev/null 2>&1; then eval "$$(ssh-agent -s)" >/dev/null; fi; \
ssh-add "$(TPU_QUEUE_SSH_KEY_FILE)"; \
fi; \
fi; \
AGENT_COUNT="$(SWEEP_COUNT)" PROJECT_ID="$(TPU_PROJECT)" TPU_NETWORK="$(TPU_NETWORK)" TPU_SUBNETWORK="$(TPU_SUBNETWORK)" TPU_REUSE_EXISTING="$(TPU_QUEUE_REUSE_EXISTING)" TPU_KEEP_ALIVE="$(TPU_QUEUE_KEEP_ALIVE)" TPU_STRICT_QUOTA="$(TPU_QUEUE_STRICT_QUOTA)" TPU_DOWNSHIFT_ON_QUOTA="$(TPU_QUEUE_DOWNSHIFT_ON_QUOTA)" TPU_EXECUTION_MODE="$(TPU_QUEUE_EXECUTION_MODE)" TPU_SYNC_METHOD="$(TPU_QUEUE_SYNC_METHOD)" TPU_SKIP_SYNC="$(TPU_QUEUE_SKIP_SYNC)" TPU_DOCKER_IMAGE="$(TPU_QUEUE_DOCKER_IMAGE)" TPU_DOCKER_PULL="$(TPU_QUEUE_DOCKER_PULL)" TPU_DOCKER_AUTO_INSTALL="$(TPU_QUEUE_DOCKER_AUTO_INSTALL)" TPU_SSH_BATCH_MODE="$(TPU_QUEUE_SSH_BATCH_MODE)" TPU_SSH_CONNECT_TIMEOUT="$(TPU_QUEUE_SSH_CONNECT_TIMEOUT)" TPU_SSH_KEY_FILE="$(TPU_QUEUE_SSH_KEY_FILE)" TPU_REQUIRE_SSH_AGENT="$(TPU_QUEUE_REQUIRE_SSH_AGENT)" TPU_QUEUE_FILTER_ZONE="$(TPU_ZONE)" TPU_QUEUE_FILTER_TYPE="$(TPU_QUEUE_TYPE)" WANDB_API_KEY="$$WANDB_API_KEY" "$(QUEUE_SCRIPT)" "$(SWEEP_ID)"
.PHONY: tpu.queue.sweep.docker
tpu.queue.sweep.docker:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
@$(MAKE) tpu.queue.sweep TPU_QUEUE_EXECUTION_MODE=docker
.PHONY: tpu.queue.worker.docker
tpu.queue.worker.docker:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
@$(MAKE) tpu.queue.worker TPU_QUEUE_EXECUTION_MODE=docker
.PHONY: tpu.queue.docker.build
tpu.queue.docker.build:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
docker build -f docker/TPUSweep.Dockerfile -t "$(TPU_QUEUE_DOCKER_IMAGE)" .
.PHONY: tpu.queue.docker.push
tpu.queue.docker.push:
@test -n "$(TPU_QUEUE_DOCKER_IMAGE)" || (echo "TPU_QUEUE_DOCKER_IMAGE is required" && exit 1)
docker push "$(TPU_QUEUE_DOCKER_IMAGE)"
.PHONY: tpu.queue.status
tpu.queue.status:
@set -e; \
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
else \
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
fi; \
for ZONE in $(TPU_QUEUE_ZONES); do \
echo "--- $$ZONE ---"; \
if ! $$QCMD list --zone="$$ZONE"; then \
echo "Skipping $$ZONE (unavailable or no permission)"; \
fi; \
done
.PHONY: tpu.queue.clean
tpu.queue.clean:
@set -e; \
if gcloud compute tpus queued-resources list --help >/dev/null 2>&1; then \
QCMD='gcloud --project=$(TPU_PROJECT) compute tpus queued-resources'; \
else \
QCMD='gcloud --project=$(TPU_PROJECT) alpha compute tpus queued-resources'; \
fi; \
for ZONE in $(TPU_QUEUE_ZONES); do \
$$QCMD list --zone="$$ZONE" --format='value(name)' 2>/dev/null | while read -r NAME; do \
case "$$NAME" in \
qr-*) echo "Deleting $$NAME ($$ZONE)"; $$QCMD delete "$$NAME" --zone="$$ZONE" --quiet ;; \
esac; \
done; \
done
.PHONY: stats.lines
stats.lines:
@find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \

View File

@@ -7,6 +7,19 @@ from typing import Any, NamedTuple
import numpy as np
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
from ..wandb_checkpoint import (
checkpoint_artifact_name,
download_latest_checkpoint,
log_checkpoint_bytes,
)
try:
import jax
import jax.numpy as jnp
@@ -142,6 +155,7 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
"checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)),
}
rollout = out["num_envs"] * out["num_steps"]
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
@@ -185,11 +199,6 @@ def make_train(config: dict[str, Any]):
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
return cfg["learning_rate"] * frac
def train(rng: jax.Array):
rng, init_key = jax.random.split(rng)
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
params = network.init(init_key, init_obs)
if cfg["anneal_lr"]:
tx = optax.chain(
optax.clip_by_global_norm(cfg["max_grad_norm"]),
@@ -200,11 +209,17 @@ def make_train(config: dict[str, Any]):
optax.clip_by_global_norm(cfg["max_grad_norm"]),
optax.adam(cfg["learning_rate"], eps=1e-5),
)
def init_runner_state(rng: jax.Array):
rng, init_key = jax.random.split(rng)
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
params = network.init(init_key, init_obs)
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
rng, reset_key = jax.random.split(rng)
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
obs, env_state = jax.vmap(env.reset)(reset_keys)
return train_state, env_state, obs, rng
def _update_step(runner_state, _):
def _env_step(runner_state, _):
@@ -261,10 +276,7 @@ def make_train(config: dict[str, Any]):
)
gae = (
delta
+ cfg["gamma"]
* cfg["gae_lambda"]
* (1.0 - transition.done)
* gae
+ cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae
)
return (gae, transition.value), gae
@@ -342,9 +354,7 @@ def make_train(config: dict[str, Any]):
),
shuffled,
)
train_state, _ = jax.lax.scan(
_update_minibatch, train_state, minibatches
)
train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches)
return (train_state, traj_batch, advantages, targets, rng), None
update_state = (train_state, traj_batch, advantages, targets, rng)
@@ -364,22 +374,23 @@ def make_train(config: dict[str, Any]):
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
}
runner_state = (train_state, env_state, last_obs, rng)
return runner_state, metric
next_runner_state = (train_state, env_state, last_obs, rng)
return next_runner_state, metric
runner_state = (train_state, env_state, obs, rng)
def run_updates(runner_state, *, num_updates: int):
updates = max(1, int(num_updates))
runner_state, metric = jax.lax.scan(
_update_step,
runner_state,
None,
length=cfg["num_updates"],
length=updates,
)
return {
"runner_state": runner_state,
"metrics": metric,
}
return train, network, env, cfg
return init_runner_state, run_updates, network, env, cfg
def evaluate_policy(
@@ -436,22 +447,103 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
)
train_fn, network, env, run_cfg = make_train(run_cfg)
train_jit = jax.jit(train_fn)
rng = jax.random.PRNGKey(run_cfg["seed"])
out = train_jit(rng)
init_runner_state, run_updates, network, env, run_cfg = make_train(run_cfg)
run_updates_jit = jax.jit(run_updates, static_argnames=("num_updates",))
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
total_updates = int(run_cfg["num_updates"])
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
train_state = out["runner_state"][0]
rng = jax.random.PRNGKey(run_cfg["seed"])
runner_state = init_runner_state(rng)
updates_done = 0
artifact_name = None
if HAS_WANDB and wandb.run is not None:
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(
run_cfg,
backend="jax",
sweep_id=sweep_id,
)
restored = download_latest_checkpoint(
artifact_name,
file_name="jax_runner_state.msgpack",
)
if restored is not None:
checkpoint_path, metadata = restored
template = {
"runner_state": runner_state,
"updates_done": 0,
}
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
runner_state = payload["runner_state"]
updates_done = int(payload.get("updates_done", 0))
if updates_done <= 0:
updates_done = int(metadata.get("updates_done", 0))
updates_done = max(0, min(updates_done, total_updates))
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
metric_sums = {k: 0.0 for k in metric_keys}
metric_count = 0
while updates_done < total_updates:
updates_this_segment = min(segment_updates, total_updates - updates_done)
out = run_updates_jit(runner_state, num_updates=updates_this_segment)
runner_state = out["runner_state"]
metric = out["metrics"]
segment_values = {
k: np.asarray(metric[k], dtype=np.float64) for k in metric_keys
}
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
metric_count += segment_count
for key in metric_keys:
metric_sums[key] += float(segment_values[key].sum())
updates_done += int(updates_this_segment)
global_step = int(updates_done * rollout_steps)
if HAS_WANDB and wandb.run is not None:
wandb.log(
{
"train/reward": float(segment_values["reward"].mean()),
"train/revenue": float(segment_values["revenue"].mean()),
"train/agent_prob": float(segment_values["agent_prob"].mean()),
"train/alpha_adv": float(segment_values["alpha_adv"].mean()),
"train/coi_leakage": float(segment_values["coi_leakage"].mean()),
"train/global_step": global_step,
},
step=global_step,
)
if artifact_name is not None:
checkpoint_payload = serialization.to_bytes(
{
"runner_state": runner_state,
"updates_done": updates_done,
}
)
log_checkpoint_bytes(
artifact_name,
file_name="jax_runner_state.msgpack",
payload=checkpoint_payload,
metadata={
"step": global_step,
"updates_done": updates_done,
"rollout_steps": rollout_steps,
"algo": "ppo",
},
)
train_state = runner_state[0]
denom = float(metric_count) if metric_count > 0 else 1.0
metrics = {
"train/reward": float(np.mean(np.asarray(metric["reward"]))),
"train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
"train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
"train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
"train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
"train/global_step": int(
run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
),
"train/reward": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
"train/global_step": int(updates_done * rollout_steps),
}
eval_metrics = evaluate_policy(

View File

@@ -2,7 +2,7 @@ from .demand import estimate_demand, estimate_weighted_demand, generate_demand_f
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
from .render import DashboardRenderer, style_axis
from .wrappers import EconomicMetricsWrapper
from .callbacks import MetricsCallback, EvalMetricsCallback
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
from .providers import (
ProviderBenchmark,
ProviderResult,

View File

@@ -1,8 +1,12 @@
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
from pathlib import Path
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
import numpy as np
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
try:
import wandb
@@ -80,6 +84,65 @@ class MetricsCallback(BaseCallback):
self._episode_revenues = []
class CheckpointArtifactCallback(BaseCallback):
"""Periodic SB3 checkpoint uploader backed by W&B artifacts."""
def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
super().__init__(verbose)
self.cfg = dict(cfg)
self.interval = max(1, int(interval))
self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
self.model_dir.mkdir(parents=True, exist_ok=True)
self._next_checkpoint = self.interval
self._last_saved_step = -1
def _artifact_name(self) -> str:
sweep_id = (
getattr(wandb.run, "sweep_id", None)
if HAS_WANDB and wandb.run is not None
else None
)
return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)
def _checkpoint_file(self) -> Path:
algo = str(self.cfg.get("algo", "model"))
base = self.model_dir / f"phantom_{algo}_checkpoint"
self.model.save(str(base))
return base.with_suffix(".zip")
def _save_checkpoint(self) -> None:
if not HAS_WANDB or wandb.run is None:
return
step = int(self.num_timesteps)
if step <= self._last_saved_step:
return
checkpoint_path = self._checkpoint_file()
metadata = {
"step": step,
"algo": str(self.cfg.get("algo", "unknown")),
"sweep_id": getattr(wandb.run, "sweep_id", None),
}
saved = log_checkpoint_file(
self._artifact_name(),
file_path=checkpoint_path,
artifact_file_name=checkpoint_path.name,
metadata=metadata,
)
if saved:
self._last_saved_step = step
def _on_step(self) -> bool:
if self.num_timesteps < self._next_checkpoint:
return True
self._save_checkpoint()
while self._next_checkpoint <= self.num_timesteps:
self._next_checkpoint += self.interval
return True
def _on_training_end(self) -> None:
self._save_checkpoint()
class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation - true performance without exploration noise."""

View File

@@ -6,6 +6,8 @@ import os
from pathlib import Path
import numpy as np
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
try:
import wandb
@@ -78,6 +80,7 @@ DEFAULT_CFG = {
"jax_num_minibatches": 4,
"jax_update_epochs": 4,
"jax_anneal_lr": True,
"checkpoint_interval": 10_000,
}
@@ -262,6 +265,16 @@ def build_model(cfg: dict, env):
raise ValueError(f"unsupported algo '{algo}'")
def _sb3_model_cls(algo: str):
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
from .lib.discrete import EventQTable
@@ -305,14 +318,36 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
def train_sb3(cfg: dict) -> tuple[object, dict]:
if not HAS_SB3:
raise ImportError("stable-baselines3 is required for SB3 models")
from .lib.callbacks import MetricsCallback
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
env = make_env(cfg)
eval_env = make_env(cfg)
env = Monitor(env)
eval_env = Monitor(eval_env)
model = build_model(cfg, env)
resume_step = 0
if HAS_WANDB and wandb.run is not None:
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is not None:
checkpoint_path, metadata = restored
model = _sb3_model_cls(cfg["algo"]).load(
checkpoint_path.as_posix(), env=env
)
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
model.num_timesteps = max(
int(getattr(model, "num_timesteps", 0)), resume_step
)
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
cbs.append(
CheckpointArtifactCallback(
cfg,
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
)
cbs.append(
EvalCallback(
eval_env,
@@ -322,7 +357,15 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
verbose=0,
)
)
model.learn(total_timesteps=int(cfg["total_timesteps"]), callback=cbs)
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
if remaining_steps > 0:
model.learn(
total_timesteps=remaining_steps,
callback=cbs,
reset_num_timesteps=False,
)
model_path = Path(cfg["model_dir"])
model_path.mkdir(parents=True, exist_ok=True)
model.save(str(model_path / f"phantom_{cfg['algo']}"))
@@ -413,6 +456,7 @@ def main():
p.add_argument("--jax-num-minibatches", type=int)
p.add_argument("--jax-update-epochs", type=int)
p.add_argument("--jax-anneal-lr", type=str)
p.add_argument("--checkpoint-interval", type=int)
p.add_argument("--sweep-agent", action="store_true")
p.add_argument("--sweep-id", type=str)
p.add_argument("--count", type=int, default=0)
@@ -441,6 +485,7 @@ def main():
"jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs,
"checkpoint_interval": args.checkpoint_interval,
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
if args.jax_anneal_lr is not None
else None,