cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -27,10 +27,6 @@ AGENT_LOOP ?= 1
RETRY_SECONDS ?= 20 RETRY_SECONDS ?= 20
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
TPU_NAME ?=
TPU_ZONE ?= us-central2-b
TPU_PROJECT ?= phantom-trc
TPU_REPO_DIR ?= /tmp/PHANTOM
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
@@ -38,7 +34,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
.PHONY: help .PHONY: help
help: help:
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines" @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | stats.lines"
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish" @echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
@echo "" @echo ""
@echo "Build general public version:" @echo "Build general public version:"
@@ -137,26 +133,6 @@ wordcount:
docker.train.publish: docker.train.publish:
@TRAIN_IMAGE_REF="$(TRAIN_IMAGE_REF)" $(NX) run research:docker-train-publish @TRAIN_IMAGE_REF="$(TRAIN_IMAGE_REF)" $(NX) run research:docker-train-publish
.PHONY: train.tpu.pod
train.tpu.pod:
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-tpu-pod
.PHONY: train.tpu.vm.prepare
train.tpu.vm.prepare:
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" $(NX) run research:train-tpu-vm-prepare
.PHONY: train.tpu.vm.run
train.tpu.vm.run:
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train-tpu-vm-run
.PHONY: train.tpu.vm
train.tpu.vm:
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train-tpu-vm
.PHONY: train.tpu.vm.sweep
train.tpu.vm.sweep:
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-tpu-vm-sweep
.PHONY: backend.server backend.provider backend.worker platform.up platform.down platform.logs .PHONY: backend.server backend.provider backend.worker platform.up platform.down platform.logs
backend.server: backend.server:
@$(NX) run backend-server:dev @$(NX) run backend-server:dev

View File

@@ -7,36 +7,9 @@ WORKDIR /app
COPY docker/trainer.requirements.txt /tmp/requirements.txt COPY docker/trainer.requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt RUN pip install --no-cache-dir -r /tmp/requirements.txt
# Optional for JAX-on-GPU workflows.
ARG INSTALL_JAX_GPU=false
RUN if [ "${INSTALL_JAX_GPU}" = "true" ]; then \
pip install --no-cache-dir "jax[cuda12]==0.4.30" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html; \
fi
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
COPY engine /app/engine COPY engine /app/engine
ENV PYTHONPATH=/app \ ENV PYTHONPATH=/app
XLA_PYTHON_CLIENT_PREALLOCATE=false
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
FROM python:3.11-slim AS tpu
WORKDIR /app
COPY docker/trainer.requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt
RUN pip install --no-cache-dir "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
COPY engine /app/engine
ENV PYTHONPATH=/app \
PHANTOM_USE_JAX=1 \
PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
XLA_PYTHON_CLIENT_PREALLOCATE=false
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"] ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]

View File

@@ -5,9 +5,3 @@ gymnasium>=0.29.0
stable-baselines3>=2.2.0 stable-baselines3>=2.2.0
tensorboard>=2.15.0 tensorboard>=2.15.0
wandb>=0.17.0 wandb>=0.17.0
tensorflow-probability==0.24.0
flax==0.10.7
optax==0.2.7
distrax==0.1.5
orbax-checkpoint==0.11.32
chex==0.1.90

View File

@@ -1 +1 @@
__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"] __all__ = ["evaluate", "make_env", "train_qtable", "train_sb3"]

View File

@@ -1,18 +0,0 @@
from __future__ import annotations
from typing import Any, Mapping
from ..jax import JAX_AVAILABLE
def train_jax_backend(
cfg: Mapping[str, Any],
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
from ..jax.train import train_jax
return train_jax(dict(cfg))

View File

@@ -7,7 +7,9 @@ import numpy as np
from .common import evaluate, make_env from .common import evaluate, make_env
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]: def train_qtable(
cfg: Mapping[str, Any],
) -> tuple[object, dict[str, Any]]:
from ..lib.discrete import EventQTable from ..lib.discrete import EventQTable
np.random.seed(int(cfg["seed"])) np.random.seed(int(cfg["seed"]))
@@ -26,8 +28,19 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
total_revenue = 0.0 total_revenue = 0.0
steps = 0 steps = 0
epsilon = float(cfg["eps_start"]) epsilon = float(cfg["eps_start"])
log_freq = max(1, int(cfg.get("log_freq", 100)))
obs, _ = env.reset(seed=int(cfg["seed"])) obs, _ = env.reset(seed=int(cfg["seed"]))
interval_sums = {
"reward": 0.0,
"revenue": 0.0,
"agent_prob": 0.0,
"alpha_adv": 0.0,
"coi_leakage": 0.0,
}
interval_count = 0
train_events: list[dict[str, float | int]] = []
for _ in range(int(cfg["total_timesteps"])): for _ in range(int(cfg["total_timesteps"])):
action, state = agent.act(obs, epsilon) action, state = agent.act(obs, epsilon)
nxt, reward, term, trunc, info = env.step(action) nxt, reward, term, trunc, info = env.step(action)
@@ -35,18 +48,57 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
agent.update(state, action, float(reward), agent.encode(nxt), done) agent.update(state, action, float(reward), agent.encode(nxt), done)
total_reward += float(reward) total_reward += float(reward)
total_revenue += float(info.get("economics", {}).get("revenue", 0.0)) revenue = float(info.get("economics", {}).get("revenue", 0.0))
total_revenue += revenue
steps += 1 steps += 1
interval_sums["reward"] += float(reward)
interval_sums["revenue"] += revenue
interval_sums["agent_prob"] += float(info.get("agent_prob", 0.0))
interval_sums["alpha_adv"] += float(info.get("alpha_adv", 0.0))
interval_sums["coi_leakage"] += float(info.get("coi_leakage", 0.0))
interval_count += 1
if steps % log_freq == 0 and interval_count > 0:
denom = float(interval_count)
train_events.append(
{
"train/reward_mean": interval_sums["reward"] / denom,
"train/revenue_mean": interval_sums["revenue"] / denom,
"train/agent_prob": interval_sums["agent_prob"] / denom,
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
"train/epsilon": float(epsilon),
"train/global_step": int(steps),
}
)
interval_sums = {key: 0.0 for key in interval_sums}
interval_count = 0
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"])) epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
obs = env.reset()[0] if done else nxt obs = env.reset()[0] if done else nxt
metrics: dict[str, float | int] = { if interval_count > 0:
denom = float(interval_count)
train_events.append(
{
"train/reward_mean": interval_sums["reward"] / denom,
"train/revenue_mean": interval_sums["revenue"] / denom,
"train/agent_prob": interval_sums["agent_prob"] / denom,
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
"train/epsilon": float(epsilon),
"train/global_step": int(steps),
}
)
metrics: dict[str, Any] = {
"train/reward_mean": total_reward / max(steps, 1), "train/reward_mean": total_reward / max(steps, 1),
"train/revenue_mean": total_revenue / max(steps, 1), "train/revenue_mean": total_revenue / max(steps, 1),
"train/epsilon": float(epsilon), "train/epsilon": float(epsilon),
"train/global_step": int(cfg["total_timesteps"]), "train/global_step": int(cfg["total_timesteps"]),
} }
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"]))) metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
metrics["_train_events"] = train_events
env.close() env.close()
eval_env.close() eval_env.close()

View File

@@ -4,9 +4,7 @@ import json
from pathlib import Path from pathlib import Path
from typing import Any, Mapping from typing import Any, Mapping
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback from ..lib.callbacks import MetricsCallback
from ..telemetry.wandb import get_wandb_module
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
from .common import evaluate, make_env from .common import evaluate, make_env
@@ -52,21 +50,6 @@ def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
return kwargs return kwargs
def _sb3_model_cls(algo: str):
try:
from stable_baselines3 import A2C, DQN, PPO
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def build_model(cfg: Mapping[str, Any], env: Any): def build_model(cfg: Mapping[str, Any], env: Any):
try: try:
from stable_baselines3 import A2C, DQN, PPO from stable_baselines3 import A2C, DQN, PPO
@@ -132,29 +115,7 @@ def build_model(cfg: Mapping[str, Any], env: Any):
raise ValueError(f"unsupported algo '{algo}'") raise ValueError(f"unsupported algo '{algo}'")
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any): def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return model
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is None:
return model
checkpoint_path, metadata = restored
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
checkpoint_path.as_posix(),
env=env,
)
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
return resumed
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
try: try:
from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
@@ -182,15 +143,10 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
except Exception: except Exception:
pass pass
model = _maybe_resume_model(cfg, env, model) metrics_callback = MetricsCallback(
log_histograms=False, log_freq=int(cfg["log_freq"])
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
callbacks.append(
CheckpointArtifactCallback(
dict(cfg),
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
) )
callbacks = [metrics_callback]
callbacks.append( callbacks.append(
EvalCallback( EvalCallback(
eval_env, eval_env,
@@ -215,13 +171,14 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
model_path = model_dir / f"phantom_{cfg['algo']}" model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path)) model.save(str(model_path))
metrics: dict[str, float | int | str] = evaluate( metrics: dict[str, Any] = evaluate(
model, model,
eval_env, eval_env,
int(cfg["eval_episodes"]), int(cfg["eval_episodes"]),
) )
metrics["train/global_step"] = int(model.num_timesteps) metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip")) metrics["model/path"] = str(model_path.with_suffix(".zip"))
metrics["_train_events"] = list(metrics_callback.events)
env.close() env.close()
eval_env.close() eval_env.close()

View File

@@ -1,13 +0,0 @@
"""JAX-compatible training and environment modules for PHANTOM."""
from __future__ import annotations
try:
import jax # noqa: F401
import jax.numpy as jnp # noqa: F401
JAX_AVAILABLE = True
except ImportError:
JAX_AVAILABLE = False
__all__ = ["JAX_AVAILABLE"]

View File

@@ -1,49 +0,0 @@
"""Orbax checkpoint helpers for JAX training runs."""
from __future__ import annotations
from pathlib import Path
from typing import Any
try:
import orbax.checkpoint as ocp
HAS_ORBAX = True
except ImportError:
HAS_ORBAX = False
def _require_orbax() -> None:
if not HAS_ORBAX:
raise ImportError(
"orbax-checkpoint is required for checkpoint support. "
"Install engine/jax/requirements.txt first."
)
def create_manager(directory: str | Path, max_to_keep: int = 5):
_require_orbax()
root = Path(directory)
root.mkdir(parents=True, exist_ok=True)
options = ocp.CheckpointManagerOptions(
max_to_keep=max(1, int(max_to_keep)), create=True
)
return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options)
def save(manager, *, step: int, payload: Any) -> bool:
_require_orbax()
return bool(manager.save(int(step), payload))
def latest_step(manager) -> int | None:
_require_orbax()
return manager.latest_step()
def restore(manager, *, target: Any, step: int | None = None) -> Any:
_require_orbax()
step_to_restore = manager.latest_step() if step is None else int(step)
if step_to_restore is None:
return target
return manager.restore(step_to_restore, items=target)

View File

@@ -1,304 +0,0 @@
"""JAX-native PHANTOM environment with robust contamination step."""
from __future__ import annotations
from typing import NamedTuple
try:
import jax
import jax.numpy as jnp
except ImportError as exc: # pragma: no cover
raise ImportError("engine.jax.env requires JAX") from exc
from .primitives import (
_sample_sessions_jax,
agent_probability_from_kl,
batch_kl,
compute_session_transitions,
load_transition_data,
purchase_flags,
reward_with_coi_penalty,
revenue_from_demand,
weighted_demand,
)
class EnvParams(NamedTuple):
n_products: int
n_sessions: int
max_episode_steps: int
max_session_steps: int
price_low: float
price_high: float
lambda_coi: float
info_value: float
eta_ux: float
robust_radius: float
margin_floor: float
margin_floor_patience: int
action_scales: jax.Array
alpha_nominal: float
alpha_candidates: jax.Array
human_T: jax.Array
agent_T: jax.Array
terminal_mask: jax.Array
purchase_mask: jax.Array
event_weights: jax.Array
start_idx: int
term_idx: int
class EnvState(NamedTuple):
prices: jax.Array
demand: jax.Array
step_count: jax.Array
low_margin_streak: jax.Array
last_agent_prob: jax.Array
last_alpha_adv: jax.Array
class CandidateEval(NamedTuple):
reward: jax.Array
revenue: jax.Array
demand: jax.Array
agent_prob: jax.Array
leakage: jax.Array
discount: jax.Array
ux_penalty: jax.Array
n_purchases: jax.Array
n_agents: jax.Array
def make_env_params(
*,
n_products: int,
alpha: float,
n_sessions: int,
lambda_coi: float,
robust_radius: float,
robust_points: int,
info_value: float,
eta_ux: float = 0.5,
action_levels: int,
action_scale_low: float,
action_scale_high: float,
price_low: float,
price_high: float,
max_episode_steps: int,
max_session_steps: int = 40,
margin_floor: float = 0.05,
margin_floor_patience: int = 5,
prefer_behavior_data: bool = True,
) -> EnvParams:
transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()
if robust_radius <= 0.0 or robust_points <= 1:
alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32)
else:
lo = max(0.0, float(alpha) - float(robust_radius))
hi = min(1.0, float(alpha) + float(robust_radius))
alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32)
action_scales = jnp.linspace(
float(action_scale_low),
float(action_scale_high),
int(action_levels),
dtype=jnp.float32,
)
return EnvParams(
n_products=int(n_products),
n_sessions=int(n_sessions),
max_episode_steps=int(max_episode_steps),
max_session_steps=int(max_session_steps),
price_low=float(price_low),
price_high=float(price_high),
lambda_coi=float(lambda_coi),
info_value=float(info_value),
eta_ux=float(eta_ux),
robust_radius=float(robust_radius),
margin_floor=float(margin_floor),
margin_floor_patience=int(margin_floor_patience),
action_scales=action_scales,
alpha_nominal=float(alpha),
alpha_candidates=alpha_candidates,
human_T=jnp.asarray(transition.human_T),
agent_T=jnp.asarray(transition.agent_T),
terminal_mask=jnp.asarray(transition.terminal_mask),
purchase_mask=jnp.asarray(transition.purchase_mask),
event_weights=jnp.asarray(transition.event_weights),
start_idx=int(transition.start_idx),
term_idx=int(transition.term_idx),
)
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)])
def _decode_action(
prices: jax.Array, action: jax.Array, params: EnvParams
) -> jax.Array:
idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1)
scale = params.action_scales[idx]
next_prices = prices * scale
return jnp.clip(next_prices, params.price_low, params.price_high)
def _evaluate_candidate(
key: jax.Array,
alpha_candidate: jax.Array,
prices: jax.Array,
ux_volatility: jax.Array,
params: EnvParams,
) -> CandidateEval:
states, products, actors, lengths = _sample_sessions_jax(
key,
params.human_T,
params.agent_T,
params.terminal_mask,
params.start_idx,
params.term_idx,
alpha_candidate,
params.n_products,
params.n_sessions,
params.max_session_steps,
int(params.human_T.shape[0]),
)
session_trans = compute_session_transitions(
states, lengths, int(params.human_T.shape[0])
)
delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
agent_probs = agent_probability_from_kl(delta_h, delta_a)
agent_prob = jnp.mean(agent_probs)
demand = weighted_demand(states, products, params.n_products, params.event_weights)
revenue = revenue_from_demand(prices, demand)
reward, leakage, discount, ux_penalty = reward_with_coi_penalty(
revenue,
agent_prob,
params.lambda_coi,
params.info_value,
params.eta_ux,
ux_volatility,
)
purchases = purchase_flags(states, params.purchase_mask)
return CandidateEval(
reward=reward,
revenue=revenue,
demand=demand,
agent_prob=agent_prob,
leakage=leakage,
discount=discount,
ux_penalty=ux_penalty,
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
n_agents=jnp.sum(actors.astype(jnp.float32)),
)
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
prices = jax.random.uniform(
key,
shape=(params.n_products,),
minval=params.price_low,
maxval=params.price_high,
)
demand = jnp.zeros((params.n_products,), dtype=jnp.float32)
state = EnvState(
prices=prices,
demand=demand,
step_count=jnp.asarray(0, dtype=jnp.int32),
low_margin_streak=jnp.asarray(0, dtype=jnp.int32),
last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
)
return _flatten_obs(demand, prices), state
def step_env(
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams,
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
prices = _decode_action(state.prices, action, params)
baseline = jnp.maximum(state.prices, 1.0)
ux_volatility = jnp.where(
state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0
)
n_candidates = params.alpha_candidates.shape[0]
cand_keys = jax.random.split(key, n_candidates)
evals = jax.vmap(
lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params),
in_axes=(0, 0),
)(cand_keys, params.alpha_candidates)
idx = jnp.argmin(evals.reward)
demand = evals.demand[idx]
reward = evals.reward[idx]
revenue = evals.revenue[idx]
agent_prob = evals.agent_prob[idx]
leakage = evals.leakage[idx]
discount = evals.discount[idx]
ux_penalty = evals.ux_penalty[idx]
n_purchases = evals.n_purchases[idx]
n_agents = evals.n_agents[idx]
alpha_adv = params.alpha_candidates[idx]
step_count = state.step_count + 1
avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
avg_margin = (avg_price - params.price_low) / avg_price
next_streak = jnp.where(
avg_margin < params.margin_floor, state.low_margin_streak + 1, 0
)
margin_collapsed = next_streak >= params.margin_floor_patience
done = (step_count >= params.max_episode_steps) | margin_collapsed
next_state = EnvState(
prices=prices,
demand=demand,
step_count=step_count,
low_margin_streak=next_streak,
last_agent_prob=agent_prob,
last_alpha_adv=alpha_adv,
)
obs = _flatten_obs(demand, prices)
info = {
"revenue": revenue,
"agent_prob": agent_prob,
"alpha_adv": alpha_adv,
"coi_leakage": leakage,
"coi_discount": discount,
"ux_penalty": ux_penalty,
"volatility": ux_volatility,
"n_purchases": n_purchases,
"n_agents": n_agents,
"avg_margin": avg_margin,
}
return obs, next_state, reward, done, info
class PHANTOMJAXEnv:
def __init__(self, params: EnvParams):
self.params = params
def reset(self, key: jax.Array, params: EnvParams | None = None):
return reset_env(key, self.params if params is None else params)
def step(
self,
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams | None = None,
):
return step_env(key, state, action, self.params if params is None else params)
def action_space_n(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.action_scales.shape[0])
def observation_dim(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.n_products * 2)

View File

@@ -1,501 +0,0 @@
"""JAX-compatible primitives for PHANTOM session simulation and separability."""
from __future__ import annotations
from dataclasses import dataclass
from functools import partial
from typing import Mapping
import numpy as np
try:
import jax
import jax.numpy as jnp
JAX_AVAILABLE = True
except ImportError:
jax = None # type: ignore[assignment]
jnp = np # type: ignore[assignment]
JAX_AVAILABLE = False
STATE_START_KEYS = ("session_start", "start")
TERMINAL_EVENT_TOKENS = (
"session_end",
"end",
"purchase_complete",
"checkout_start",
"checkout",
)
PURCHASE_EVENT_TOKENS = (
"purchase_complete",
"purchase",
"checkout_start",
"checkout",
)
CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5}
ACTION_CATEGORIES = {
"cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"},
"dwell": {
"hover_title",
"hover_paragraph",
"hover_link",
"hover_over_title",
"hover_over_paragraph",
"hover_over_link",
"hover_over_button",
},
"nav": {
"page_view",
"view_item",
"view",
"learn_more",
"learn_more_about_item",
"view_item_page",
"session_start",
},
"filter": {
"search",
"filter_date",
"filter_price",
"sort",
"filter_for_date",
"filter_for_price",
"filter_for_amenities",
"sort_change",
},
}
DEFAULT_ACTION_WEIGHTS = {
action: CATEGORY_WEIGHTS[group]
for group, actions in ACTION_CATEGORIES.items()
for action in actions
}
@dataclass(frozen=True)
class TransitionData:
"""Dense transition kernels and per-state metadata."""
human_T: np.ndarray
agent_T: np.ndarray
terminal_mask: np.ndarray
purchase_mask: np.ndarray
event_weights: np.ndarray
event_names: tuple[str, ...]
start_idx: int
term_idx: int
def to_jax(self) -> "TransitionData":
if not JAX_AVAILABLE:
return self
return TransitionData(
human_T=jnp.asarray(self.human_T),
agent_T=jnp.asarray(self.agent_T),
terminal_mask=jnp.asarray(self.terminal_mask),
purchase_mask=jnp.asarray(self.purchase_mask),
event_weights=jnp.asarray(self.event_weights),
event_names=self.event_names,
start_idx=int(self.start_idx),
term_idx=int(self.term_idx),
)
@dataclass(frozen=True)
class SessionBatch:
states: np.ndarray
products: np.ndarray
actors: np.ndarray
lengths: np.ndarray
def _event_weight(name: str) -> float:
if name in DEFAULT_ACTION_WEIGHTS:
return float(DEFAULT_ACTION_WEIGHTS[name])
if name.startswith("hover"):
return float(CATEGORY_WEIGHTS["dwell"])
if name.startswith("filter") or name in {"search", "sort", "sort_change"}:
return float(CATEGORY_WEIGHTS["filter"])
if name.startswith("add") or name in {
"checkout",
"checkout_start",
"purchase",
"remove_item",
"purchase_complete",
}:
return float(CATEGORY_WEIGHTS["cart"])
if any(token in name for token in TERMINAL_EVENT_TOKENS):
return 0.0
return float(CATEGORY_WEIGHTS["nav"])
def _is_terminal(name: str) -> bool:
return any(token in name for token in TERMINAL_EVENT_TOKENS)
def _is_purchase(name: str) -> bool:
return any(token in name for token in PURCHASE_EVENT_TOKENS)
def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]:
names: set[str] = set()
for trans in transitions:
for src, dsts in trans.items():
names.add(src)
names.update(dsts.keys())
names.discard("__terminal__")
return tuple(sorted(names))
def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray:
row_sums = matrix.sum(axis=1, keepdims=True)
dead_rows = np.isclose(row_sums.squeeze(-1), 0.0)
if np.any(dead_rows):
matrix[dead_rows] = 0.0
matrix[dead_rows, term_idx] = 1.0
row_sums = matrix.sum(axis=1, keepdims=True)
return matrix / np.maximum(row_sums, 1e-8)
def _dense_from_dict(
transitions: Mapping[str, Mapping[str, float]],
event_to_idx: Mapping[str, int],
term_idx: int,
) -> np.ndarray:
n_states = len(event_to_idx)
matrix = np.zeros((n_states, n_states), dtype=np.float32)
for src, dsts in transitions.items():
i = event_to_idx.get(src)
if i is None:
continue
for dst, prob in dsts.items():
j = event_to_idx.get(dst)
if j is None:
continue
matrix[i, j] += float(prob)
return _normalize_rows(matrix, term_idx)
def compile_transition_data(
human_transitions: Mapping[str, Mapping[str, float]],
agent_transitions: Mapping[str, Mapping[str, float]],
) -> TransitionData:
event_names = _collect_events(human_transitions, agent_transitions)
if not event_names:
return fallback_transition_data()
event_names = tuple([*event_names, "__terminal__"])
term_idx = len(event_names) - 1
event_to_idx = {name: i for i, name in enumerate(event_names)}
human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx)
agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx)
terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool)
purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool)
event_weights = np.array(
[_event_weight(name) for name in event_names], dtype=np.float32
)
terminal_mask[term_idx] = True
for idx, is_term in enumerate(terminal_mask):
if not is_term:
continue
human_T[idx] = 0.0
agent_T[idx] = 0.0
human_T[idx, idx] = 1.0
agent_T[idx, idx] = 1.0
start_idx = 0
for key in STATE_START_KEYS:
if key in event_to_idx:
start_idx = int(event_to_idx[key])
break
return TransitionData(
human_T=human_T,
agent_T=agent_T,
terminal_mask=terminal_mask,
purchase_mask=purchase_mask,
event_weights=event_weights,
event_names=event_names,
start_idx=start_idx,
term_idx=term_idx,
)
def fallback_transition_data() -> TransitionData:
human = {
"session_start": {
"page_view": 0.80,
"view_item_page": 0.15,
"session_end": 0.05,
},
"page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20},
"view_item_page": {
"learn_more_about_item": 0.40,
"add_item_to_cart": 0.28,
"session_end": 0.32,
},
"learn_more_about_item": {
"add_item_to_cart": 0.50,
"view_item_page": 0.30,
"session_end": 0.20,
},
"add_item_to_cart": {
"checkout_start": 0.58,
"view_item_page": 0.24,
"session_end": 0.18,
},
"checkout_start": {"purchase_complete": 0.70, "session_end": 0.30},
"purchase_complete": {"session_end": 1.0},
}
agent = {
"session_start": {
"page_view": 0.90,
"view_item_page": 0.08,
"session_end": 0.02,
},
"page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25},
"view_item_page": {
"learn_more_about_item": 0.55,
"add_item_to_cart": 0.15,
"session_end": 0.30,
},
"learn_more_about_item": {
"view_item_page": 0.45,
"add_item_to_cart": 0.20,
"session_end": 0.35,
},
"add_item_to_cart": {
"checkout_start": 0.42,
"view_item_page": 0.28,
"session_end": 0.30,
},
"checkout_start": {"purchase_complete": 0.52, "session_end": 0.48},
"purchase_complete": {"session_end": 1.0},
}
return compile_transition_data(human, agent)
def load_transition_data(prefer_data: bool = True) -> TransitionData:
if not prefer_data:
return fallback_transition_data()
try:
from ..lib.behavior import get_transition_models
human_trans, agent_trans = get_transition_models()
return compile_transition_data(human_trans, agent_trans)
except Exception:
return fallback_transition_data()
if JAX_AVAILABLE:
@partial(jax.jit, static_argnums=(8, 9, 10))
def _sample_sessions_jax(
key: jax.Array,
human_T: jax.Array,
agent_T: jax.Array,
terminal_mask: jax.Array,
start_idx: int,
term_idx: int,
alpha: float,
n_products: int,
n_sessions: int,
max_steps: int,
n_states: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
k_actor, k_product, k_step = jax.random.split(key, 3)
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
actors = (actor_draw < alpha).astype(jnp.int32)
products = jax.random.randint(
k_product, (n_sessions,), 0, n_products, dtype=jnp.int32
)
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
def _scan_step(carry, _):
states, active, rng = carry
rng, k = jax.random.split(rng)
probs_h = human_T[states]
probs_a = agent_T[states]
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
next_state = jnp.where(active, next_state, term_idx_i32)
emitted = jnp.where(active, next_state, -1)
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
next_active = active & (~is_terminal)
carry_states = jnp.where(next_active, next_state, term_idx_i32)
return (carry_states, next_active, rng), emitted
_, state_t = jax.lax.scan(
_scan_step, (state_init, active_init, k_step), None, length=max_steps
)
states = state_t.T
lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32)
return states, products, actors, lengths
def sample_sessions(
key,
transition_data: TransitionData,
alpha: float,
n_products: int,
n_sessions: int,
max_steps: int,
) -> SessionBatch:
if JAX_AVAILABLE:
td = transition_data.to_jax()
states, products, actors, lengths = _sample_sessions_jax(
key,
td.human_T,
td.agent_T,
td.terminal_mask,
int(td.start_idx),
int(td.term_idx),
float(alpha),
int(n_products),
int(n_sessions),
int(max_steps),
int(td.human_T.shape[0]),
)
return SessionBatch(
states=states, products=products, actors=actors, lengths=lengths
)
rng = np.random.default_rng(int(np.asarray(key).reshape(-1)[0]))
n_states = transition_data.human_T.shape[0]
products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32)
actors = (rng.random(size=n_sessions) < alpha).astype(np.int32)
states = np.full((n_sessions, max_steps), -1, dtype=np.int32)
lengths = np.zeros((n_sessions,), dtype=np.int32)
for i in range(n_sessions):
current = int(transition_data.start_idx)
mat = transition_data.agent_T if actors[i] == 1 else transition_data.human_T
for t in range(max_steps):
nxt = int(rng.choice(n_states, p=mat[current]))
states[i, t] = nxt
if transition_data.terminal_mask[nxt]:
lengths[i] = t + 1
break
current = nxt
if lengths[i] == 0:
lengths[i] = max_steps
return SessionBatch(
states=states, products=products, actors=actors, lengths=lengths
)
if JAX_AVAILABLE:
@partial(jax.jit, static_argnums=(2,))
def compute_session_transitions(states, lengths, n_states: int):
src = states[:, :-1]
dst = states[:, 1:]
time_idx = jnp.arange(src.shape[1])[None, :]
valid = (src >= 0) & (dst >= 0) & (time_idx < (lengths[:, None] - 1))
src_clip = jnp.clip(src, 0, n_states - 1)
dst_clip = jnp.clip(dst, 0, n_states - 1)
src_oh = jax.nn.one_hot(src_clip, n_states)
dst_oh = jax.nn.one_hot(dst_clip, n_states)
counts = jnp.einsum(
"nti,ntj,nt->nij", src_oh, dst_oh, valid.astype(jnp.float32)
)
row_sums = jnp.sum(counts, axis=-1, keepdims=True)
return counts / (row_sums + 1e-10)
else:
def compute_session_transitions(states, lengths, n_states: int):
trans = np.zeros((states.shape[0], n_states, n_states), dtype=np.float32)
for i in range(states.shape[0]):
for t in range(max(int(lengths[i]) - 1, 0)):
s = int(states[i, t])
d = int(states[i, t + 1])
if s >= 0 and d >= 0:
trans[i, s, d] += 1.0
row_sums = trans.sum(axis=-1, keepdims=True)
return trans / (row_sums + 1e-10)
def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10):
p = P + eps
p = p / jnp.sum(p, axis=-1, keepdims=True)
qh = Q_human[None, ...] + eps
qa = Q_agent[None, ...] + eps
delta_h = jnp.sum(p * jnp.log(p / qh), axis=(1, 2))
delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2))
return delta_h, delta_a
if JAX_AVAILABLE:
batch_kl = jax.jit(batch_kl)
def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0):
t = jnp.maximum(float(temperature), 1e-6)
exp_h = jnp.exp(-delta_h / t)
exp_a = jnp.exp(-delta_a / t)
return exp_a / (exp_h + exp_a + 1e-10)
def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0):
logits = beta * (delta_h - delta_a)
return 1.0 / (1.0 + jnp.exp(-logits))
def weighted_demand(states, products, n_products: int, event_weights):
valid = states >= 0
state_clip = jnp.clip(states, 0, event_weights.shape[0] - 1)
weights = event_weights[state_clip] * valid
per_session = jnp.sum(weights, axis=1)
demand = jnp.zeros((n_products,), dtype=jnp.float32)
demand = demand.at[products].add(per_session)
total = jnp.sum(demand)
return jnp.where(total > 0.0, (demand / total) * 100.0, demand)
if JAX_AVAILABLE:
weighted_demand = jax.jit(weighted_demand, static_argnums=(2,))
def purchase_flags(states, purchase_mask):
state_clip = jnp.clip(states, 0, purchase_mask.shape[0] - 1)
hits = purchase_mask[state_clip] & (states >= 0)
return jnp.any(hits, axis=1)
if JAX_AVAILABLE:
purchase_flags = jax.jit(purchase_flags)
def revenue_from_demand(prices, demand):
return jnp.dot(prices, demand)
if JAX_AVAILABLE:
revenue_from_demand = jax.jit(revenue_from_demand)
def reward_with_coi_penalty(
revenue,
agent_prob: float,
lambda_coi: float,
info_value: float,
eta_ux: float = 0.0,
ux_volatility: float = 0.0,
):
leakage = agent_prob * info_value
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
ux_penalty = eta_ux * revenue * ux_volatility
return revenue * discount - ux_penalty, leakage, discount, ux_penalty
if JAX_AVAILABLE:
reward_with_coi_penalty = jax.jit(reward_with_coi_penalty)

View File

@@ -1,5 +0,0 @@
flax==0.10.7
optax==0.2.7
distrax==0.1.5
orbax-checkpoint==0.11.32
chex==0.1.90

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,6 @@ _EXPORTS: dict[str, tuple[str, str]] = {
"EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"), "EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
"MetricsCallback": (".callbacks", "MetricsCallback"), "MetricsCallback": (".callbacks", "MetricsCallback"),
"EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"), "EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
"CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"),
"ProviderBenchmark": (".providers", "ProviderBenchmark"), "ProviderBenchmark": (".providers", "ProviderBenchmark"),
"ProviderResult": (".providers", "ProviderResult"), "ProviderResult": (".providers", "ProviderResult"),
"BenchmarkConfig": (".providers", "BenchmarkConfig"), "BenchmarkConfig": (".providers", "BenchmarkConfig"),

View File

@@ -1,150 +1,96 @@
"""Training callbacks for W&B/TensorBoard logging - reads from info dict.""" """Training callbacks with algorithm-agnostic metric extraction."""
from pathlib import Path from typing import Any
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
import numpy as np import numpy as np
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
class MetricsCallback(BaseCallback): class MetricsCallback(BaseCallback):
"""Training metrics logger - reads info['economics'], logs to W&B.""" """Collects interval train metrics from env info dictionaries."""
def __init__( def __init__(
self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0 self,
log_histograms: bool = False,
log_freq: int = 100,
verbose: int = 0,
): ):
super().__init__(verbose) super().__init__(verbose)
self.log_histograms = log_histograms self.log_histograms = log_histograms
self.log_freq = log_freq self.log_freq = max(1, int(log_freq))
self._episode_revenues: list[float] = [] self._window_sums = {
"train/revenue_mean": 0.0,
def _on_step(self) -> bool: "train/margin_mean": 0.0,
if not HAS_WANDB or wandb.run is None: "train/coi_level_mean": 0.0,
return True "train/regret_mean": 0.0,
"train/coi_mix": 0.0,
for info in self.locals.get("infos", []): "train/coi_base": 0.0,
if "economics" not in info: "train/coi_leakage": 0.0,
continue "train/coi_penalty": 0.0,
econ = info["economics"]
t = self.num_timesteps
payload = {
"train/revenue_step": econ["revenue"],
"train/margin_step": econ["margin"],
"train/coi_level": econ["coi_level"],
"train/regret_step": econ["regret"],
}
if "coi_mix" in econ:
payload["train/coi_mix"] = econ["coi_mix"]
if "coi_base" in econ:
payload["train/coi_base"] = econ["coi_base"]
if "coi_leakage" in econ:
payload["train/coi_leakage"] = econ["coi_leakage"]
if "coi_penalty" in econ:
payload["train/coi_penalty"] = econ["coi_penalty"]
wandb.log(payload, step=t)
self._episode_revenues.append(econ["revenue"])
# histograms at log_freq intervals
if self.log_histograms and self.num_timesteps % self.log_freq == 0:
for info in self.locals.get("infos", []):
if "prices" in info:
wandb.log(
{"distributions/prices": wandb.Histogram(info["prices"])},
step=self.num_timesteps,
)
if "demand" in info:
wandb.log(
{"distributions/demand": wandb.Histogram(info["demand"])},
step=self.num_timesteps,
)
return True
def _on_rollout_end(self) -> None:
if not HAS_WANDB or wandb.run is None or not self._episode_revenues:
return
wandb.log(
{
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
"train/revenue_rollout_total": np.sum(self._episode_revenues),
},
step=self.num_timesteps,
)
self._episode_revenues = []
class CheckpointArtifactCallback(BaseCallback):
"""Periodic SB3 checkpoint uploader backed by W&B artifacts."""
def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
super().__init__(verbose)
self.cfg = dict(cfg)
self.interval = max(1, int(interval))
self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
self.model_dir.mkdir(parents=True, exist_ok=True)
self._next_checkpoint = self.interval
self._last_saved_step = -1
def _artifact_name(self) -> str:
sweep_id = (
getattr(wandb.run, "sweep_id", None)
if HAS_WANDB and wandb.run is not None
else None
)
return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)
def _checkpoint_file(self) -> Path:
algo = str(self.cfg.get("algo", "model"))
base = self.model_dir / f"phantom_{algo}_checkpoint"
self.model.save(str(base))
return base.with_suffix(".zip")
def _save_checkpoint(self) -> None:
if not HAS_WANDB or wandb.run is None:
return
step = int(self.num_timesteps)
if step <= self._last_saved_step:
return
checkpoint_path = self._checkpoint_file()
metadata = {
"step": step,
"algo": str(self.cfg.get("algo", "unknown")),
"sweep_id": getattr(wandb.run, "sweep_id", None),
} }
saved = log_checkpoint_file( self._window_count = 0
self._artifact_name(), self.events: list[dict[str, Any]] = []
file_path=checkpoint_path,
artifact_file_name=checkpoint_path.name, def _accumulate(self, info: dict[str, Any]) -> None:
metadata=metadata, econ = info.get("economics")
) if not isinstance(econ, dict):
if saved: return
self._last_saved_step = step self._window_sums["train/revenue_mean"] += float(econ.get("revenue", 0.0))
self._window_sums["train/margin_mean"] += float(econ.get("margin", 0.0))
self._window_sums["train/coi_level_mean"] += float(econ.get("coi_level", 0.0))
self._window_sums["train/regret_mean"] += float(econ.get("regret", 0.0))
if "coi_mix" in econ:
self._window_sums["train/coi_mix"] += float(econ.get("coi_mix", 0.0))
if "coi_base" in econ:
self._window_sums["train/coi_base"] += float(econ.get("coi_base", 0.0))
if "coi_leakage" in econ:
self._window_sums["train/coi_leakage"] += float(
econ.get("coi_leakage", 0.0)
)
if "coi_penalty" in econ:
self._window_sums["train/coi_penalty"] += float(
econ.get("coi_penalty", 0.0)
)
self._window_count += 1
def _flush(self, step: int) -> None:
if self._window_count <= 0:
return
denom = float(self._window_count)
payload = {
key: (value / denom)
for key, value in self._window_sums.items()
if value != 0.0
or key
in {
"train/revenue_mean",
"train/margin_mean",
"train/coi_level_mean",
"train/regret_mean",
}
}
payload["train/global_step"] = int(step)
self.events.append(payload)
for key in self._window_sums:
self._window_sums[key] = 0.0
self._window_count = 0
def _on_step(self) -> bool: def _on_step(self) -> bool:
if self.num_timesteps < self._next_checkpoint: for info in self.locals.get("infos", []):
return True if isinstance(info, dict):
self._save_checkpoint() self._accumulate(info)
while self._next_checkpoint <= self.num_timesteps:
self._next_checkpoint += self.interval if self.num_timesteps % self.log_freq == 0:
self._flush(step=self.num_timesteps)
return True return True
def _on_training_end(self) -> None: def _on_training_end(self) -> None:
self._save_checkpoint() self._flush(step=self.num_timesteps)
class EvalMetricsCallback(EvalCallback): class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation - true performance without exploration noise.""" """Deterministic evaluation collector detached from logging backends."""
def __init__( def __init__(
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
@@ -153,23 +99,19 @@ class EvalMetricsCallback(EvalCallback):
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
) )
self._eval_revenues: list[float] = [] self._eval_revenues: list[float] = []
self.events: list[dict[str, float | int]] = []
def _on_step(self) -> bool: def _on_step(self) -> bool:
result = super()._on_step() result = super()._on_step()
if not HAS_WANDB or wandb.run is None:
return result
# log eval metrics after evaluation runs
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
wandb.log( self.events.append(
{ {
"eval/reward_mean": self.last_mean_reward, "eval/reward_mean": float(self.last_mean_reward),
"eval/revenue_mean": np.mean(self._eval_revenues) "eval/revenue_mean": float(np.mean(self._eval_revenues))
if self._eval_revenues if self._eval_revenues
else 0, else 0.0,
}, "train/global_step": int(self.num_timesteps),
step=self.num_timesteps, }
) )
self._eval_revenues = [] self._eval_revenues = []

View File

@@ -31,26 +31,20 @@ def _print_local_metrics(metrics: dict[str, Any]) -> None:
print("PHANTOM_METRICS:" + json.dumps(metrics)) print("PHANTOM_METRICS:" + json.dumps(metrics))
def _should_print_local(spec: TrainSpec) -> bool: def _log_train_events(events: list[dict[str, Any]], log_freq: int) -> None:
if not spec.runtime.use_jax: if not events:
return True return
try: period = max(1, int(log_freq))
import jax last_logged_step = -period
for event in sorted(
return int(jax.process_index()) == 0 [evt for evt in events if isinstance(evt, dict)],
except Exception: key=lambda evt: int(evt.get("train/global_step", 0)),
return True ):
step = int(event.get("train/global_step", 0))
if step <= 0 or (step - last_logged_step) < period:
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool: continue
if not spec.runtime.use_jax: log_metrics(event, step=step)
return False last_logged_step = step
try:
import jax
return int(jax.process_count()) > 1 and int(jax.process_index()) != 0
except Exception:
return False
def run_train_once( def run_train_once(
@@ -65,10 +59,9 @@ def run_train_once(
extra_tags: Sequence[str], extra_tags: Sequence[str],
) -> dict[str, Any]: ) -> dict[str, Any]:
wandb = get_wandb_module() wandb = get_wandb_module()
if no_wandb or wandb is None or _is_non_primary_jax_worker(spec): if no_wandb or wandb is None:
result = run_train(spec) result = run_train(spec)
if _should_print_local(spec): _print_local_metrics(result.metrics)
_print_local_metrics(result.metrics)
return result.metrics return result.metrics
mode = "offline" if offline else "online" mode = "offline" if offline else "online"
@@ -95,6 +88,7 @@ def run_train_once(
try: try:
result = run_train(spec) result = run_train(spec)
_log_train_events(result.events, spec.runtime.log_freq)
metrics = result.metrics metrics = result.metrics
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps)) step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
log_metrics(metrics, step=step) log_metrics(metrics, step=step)
@@ -122,6 +116,7 @@ def run_with_active_sweep_run(
) )
update_run_config({**spec.to_flat_dict(), **metadata}) update_run_config({**spec.to_flat_dict(), **metadata})
result = run_train(spec) result = run_train(spec)
_log_train_events(result.events, spec.runtime.log_freq)
metrics = result.metrics metrics = result.metrics
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps)) step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
log_metrics(metrics, step=step) log_metrics(metrics, step=step)

View File

@@ -81,44 +81,6 @@
"command": "bash scripts/nx_research.sh docker-train-publish", "command": "bash scripts/nx_research.sh docker-train-publish",
"cwd": "." "cwd": "."
} }
},
"train-tpu-pod": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh train-tpu-pod",
"cwd": "."
}
},
"train-tpu-vm-prepare": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh train-tpu-vm-prepare",
"cwd": "."
}
},
"train-tpu-vm-run": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh train-tpu-vm-run",
"cwd": "."
}
},
"train-tpu-vm": {
"executor": "nx:run-commands",
"dependsOn": [
"train-tpu-vm-prepare"
],
"options": {
"command": "bash scripts/nx_research.sh train-tpu-vm-run",
"cwd": "."
}
},
"train-tpu-vm-sweep": {
"executor": "nx:run-commands",
"options": {
"command": "bash scripts/nx_research.sh train-tpu-vm-sweep",
"cwd": "."
}
} }
}, },
"tags": [ "tags": [

View File

@@ -106,11 +106,6 @@ class OptimizerSpec:
eps_decay: float = 0.9995 eps_decay: float = 0.9995
arch: str = "small" arch: str = "small"
activation: str = "relu" activation: str = "relu"
jax_num_envs: int = 16
jax_num_steps: int = 128
jax_num_minibatches: int = 4
jax_update_epochs: int = 4
jax_anneal_lr: bool = True
vf_coef: float = 0.5 vf_coef: float = 0.5
max_grad_norm: float = 0.5 max_grad_norm: float = 0.5
@@ -125,7 +120,6 @@ class RuntimeSpec:
checkpoint_interval: int = 200_000 checkpoint_interval: int = 200_000
model_dir: str = "engine/models" model_dir: str = "engine/models"
log_freq: int = 100 log_freq: int = 100
use_jax: bool = False
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -156,7 +150,6 @@ class TrainSpec:
"model_dir": self.runtime.model_dir, "model_dir": self.runtime.model_dir,
"backend": self.runtime.backend, "backend": self.runtime.backend,
"device": self.runtime.device, "device": self.runtime.device,
"use_jax": self.runtime.use_jax,
"checkpoint_interval": self.runtime.checkpoint_interval, "checkpoint_interval": self.runtime.checkpoint_interval,
"n_products": self.env.n_products, "n_products": self.env.n_products,
"N": self.env.n_sessions, "N": self.env.n_sessions,
@@ -197,11 +190,6 @@ class TrainSpec:
"eps_decay": self.optimizer.eps_decay, "eps_decay": self.optimizer.eps_decay,
"arch": self.optimizer.arch, "arch": self.optimizer.arch,
"activation": self.optimizer.activation, "activation": self.optimizer.activation,
"jax_num_envs": self.optimizer.jax_num_envs,
"jax_num_steps": self.optimizer.jax_num_steps,
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
"jax_update_epochs": self.optimizer.jax_update_epochs,
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
"vf_coef": self.optimizer.vf_coef, "vf_coef": self.optimizer.vf_coef,
"max_grad_norm": self.optimizer.max_grad_norm, "max_grad_norm": self.optimizer.max_grad_norm,
"robust_eval_enabled": self.eval.robust_eval_enabled, "robust_eval_enabled": self.eval.robust_eval_enabled,
@@ -223,14 +211,11 @@ class TrainSpec:
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto")) base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
) )
requested_jax = _truthy(base.get("use_jax")) or _truthy( backend = str(base.get("backend", "sb3")).lower()
runtime_env.get("PHANTOM_USE_JAX")
)
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
if backend == "auto": if backend == "auto":
backend = "jax" if requested_jax else "sb3" backend = "sb3"
if backend == "jax": if backend != "sb3":
requested_jax = True backend = "sb3"
no_robust = _truthy(base.get("no_robust")) no_robust = _truthy(base.get("no_robust"))
if no_robust: if no_robust:
@@ -284,11 +269,6 @@ class TrainSpec:
eps_decay=float(base["eps_decay"]), eps_decay=float(base["eps_decay"]),
arch=str(base["arch"]), arch=str(base["arch"]),
activation=str(base["activation"]), activation=str(base["activation"]),
jax_num_envs=int(base["jax_num_envs"]),
jax_num_steps=int(base["jax_num_steps"]),
jax_num_minibatches=int(base["jax_num_minibatches"]),
jax_update_epochs=int(base["jax_update_epochs"]),
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
vf_coef=float(base["vf_coef"]), vf_coef=float(base["vf_coef"]),
max_grad_norm=float(base["max_grad_norm"]), max_grad_norm=float(base["max_grad_norm"]),
), ),
@@ -301,7 +281,6 @@ class TrainSpec:
checkpoint_interval=int(base["checkpoint_interval"]), checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]), model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]), log_freq=int(base["log_freq"]),
use_jax=requested_jax,
), ),
eval=EvalSpec( eval=EvalSpec(
eval_freq=int(base["eval_freq"]), eval_freq=int(base["eval_freq"]),

View File

@@ -1,93 +0,0 @@
method: bayes
metric:
name: objective/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
# fixed: always use JAX backend so TPU chips are actually exercised
use_jax:
value: true
# all four algos have JAX implementations
algo:
values: [ppo, a2c, dqn, qtable]
total_timesteps:
values: [50000, 80000, 120000]
checkpoint_interval:
value: 200000
seed:
values: [13, 42, 77]
n_products:
values: [8, 10, 12]
# COI framework parameters -- primary research variables
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
robust_points:
values: [3, 5, 7]
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
# shared hyperparameters
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
# JAX parallelism -- key lever for TPU throughput
jax_num_envs:
values: [8, 16, 32]
jax_num_steps:
values: [64, 128, 256]
jax_num_minibatches:
values: [2, 4, 8]
jax_update_epochs:
values: [2, 4, 8]
# PPO/A2C specific
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]
# DQN specific
buffer_size:
values: [20000, 50000, 100000]
batch_size:
values: [128, 256, 512]
learning_starts:
values: [500, 1000, 3000]
exploration_fraction:
values: [0.1, 0.2, 0.3]
exploration_final_eps:
values: [0.01, 0.03, 0.05]
# QTable specific
q_lr:
values: [0.03, 0.05, 0.1, 0.2]
eps_end:
values: [0.02, 0.05, 0.1]
eps_decay:
values: [0.999, 0.9995, 0.9999]
# action space
action_levels:
values: [7, 9, 11]
action_scale_low:
values: [0.75, 0.8, 0.85]
action_scale_high:
values: [1.15, 1.2, 1.25]

View File

@@ -1,64 +0,0 @@
method: bayes
metric:
name: objective/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
use_jax:
value: true
# pmap requires all workers to compile the same computation graph shape,
# so structural params are fixed -- only research/scalar params are swept
algo:
values: [ppo, a2c]
jax_num_envs:
value: 32
jax_num_steps:
value: 128
jax_num_minibatches:
value: 4
jax_update_epochs:
value: 4
total_timesteps:
value: 100000
checkpoint_interval:
value: 200000
n_products:
value: 10
action_levels:
value: 9
# research parameters -- primary sweep targets
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
# training hyperparameters
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]

View File

@@ -7,14 +7,6 @@ from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once
from .spec import TrainSpec from .spec import TrainSpec
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _parse_tags(raw: str | None) -> list[str]: def _parse_tags(raw: str | None) -> list[str]:
if raw is None: if raw is None:
return [] return []
@@ -55,7 +47,7 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--group", type=str) parser.add_argument("--group", type=str)
parser.add_argument("--tags", type=str) parser.add_argument("--tags", type=str)
parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto") parser.add_argument("--backend", choices=["auto", "sb3"], default="auto")
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"]) parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
parser.add_argument("--seed", type=int) parser.add_argument("--seed", type=int)
parser.add_argument("--total-timesteps", type=int) parser.add_argument("--total-timesteps", type=int)
@@ -111,13 +103,6 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--eval-freq", type=int) parser.add_argument("--eval-freq", type=int)
parser.add_argument("--eval-episodes", type=int) parser.add_argument("--eval-episodes", type=int)
parser.add_argument("--jax", action="store_true")
parser.add_argument("--jax-num-envs", type=int)
parser.add_argument("--jax-num-steps", type=int)
parser.add_argument("--jax-num-minibatches", type=int)
parser.add_argument("--jax-update-epochs", type=int)
parser.add_argument("--jax-anneal-lr", type=str)
parser.add_argument("--sweep-agent", action="store_true") parser.add_argument("--sweep-agent", action="store_true")
parser.add_argument("--sweep-id", type=str) parser.add_argument("--sweep-id", type=str)
parser.add_argument("--count", type=int, default=0) parser.add_argument("--count", type=int, default=0)
@@ -127,9 +112,6 @@ def _build_parser() -> argparse.ArgumentParser:
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]: def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
jax_anneal_lr = (
_truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None
)
backend = None if args.backend == "auto" else args.backend backend = None if args.backend == "auto" else args.backend
overrides = { overrides = {
@@ -185,12 +167,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
"max_grad_norm": args.max_grad_norm, "max_grad_norm": args.max_grad_norm,
"eval_freq": args.eval_freq, "eval_freq": args.eval_freq,
"eval_episodes": args.eval_episodes, "eval_episodes": args.eval_episodes,
"use_jax": args.jax or None,
"jax_num_envs": args.jax_num_envs,
"jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs,
"jax_anneal_lr": jax_anneal_lr,
} }
return {key: value for key, value in overrides.items() if value is not None} return {key: value for key, value in overrides.items() if value is not None}

View File

@@ -12,17 +12,14 @@ class TrainResult:
spec: TrainSpec spec: TrainSpec
metrics: dict[str, Any] metrics: dict[str, Any]
artifacts: dict[str, str] artifacts: dict[str, str]
events: list[dict[str, Any]]
def run_train(spec: TrainSpec) -> TrainResult: def run_train(spec: TrainSpec) -> TrainResult:
cfg = spec.to_flat_dict() cfg = spec.to_flat_dict()
algo = spec.algorithm.name algo = spec.algorithm.name
if spec.runtime.use_jax or spec.runtime.backend == "jax": if algo == "qtable":
from .backends.jax import train_jax_backend
_, raw_metrics = train_jax_backend(cfg)
elif algo == "qtable":
from .backends.qtable import train_qtable from .backends.qtable import train_qtable
_, raw_metrics = train_qtable(cfg) _, raw_metrics = train_qtable(cfg)
@@ -31,10 +28,13 @@ def run_train(spec: TrainSpec) -> TrainResult:
_, raw_metrics = train_sb3(cfg) _, raw_metrics = train_sb3(cfg)
events_raw = raw_metrics.pop("_train_events", [])
events = [evt for evt in events_raw if isinstance(evt, dict)]
metrics = canonicalize_metrics(raw_metrics, spec) metrics = canonicalize_metrics(raw_metrics, spec)
artifacts: dict[str, str] = {} artifacts: dict[str, str] = {}
model_path = raw_metrics.get("model/path") model_path = raw_metrics.get("model/path")
if isinstance(model_path, str): if isinstance(model_path, str):
artifacts["model/path"] = model_path artifacts["model/path"] = model_path
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts) return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts, events=events)

View File

@@ -108,49 +108,6 @@ PY
image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}" image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}"
docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" . docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" .
docker push "$image_ref:gpu-latest" docker push "$image_ref:gpu-latest"
docker build -f docker/Trainer.dockerfile --target tpu -t "$image_ref:tpu-latest" .
docker push "$image_ref:tpu-latest"
;;
train-tpu-pod)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh "$TPU_NAME":/tmp/tpu_pod_run.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="WANDB_API_KEY='$WANDB_API_KEY' SWEEP_ID='$SWEEP_ID' AGENT_COUNT='${AGENT_COUNT:-0}' sh /tmp/tpu_pod_run.sh"
;;
train-tpu-vm-prepare)
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
TPU_NAME="$TPU_NAME" \
TPU_ZONE="${TPU_ZONE:-us-central2-b}" \
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" \
LOCAL_REPO_DIR="$PWD" \
REMOTE_REPO_DIR="${TPU_REPO_DIR:-/tmp/PHANTOM}" \
sh scripts/tpu_sync_repo.sh
gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh "$TPU_NAME":/tmp/tpu_vm_train.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
;;
train-tpu-vm-run)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var LOCAL_TRAIN_ARGS "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="REPO_DIR='${TPU_REPO_DIR:-/tmp/PHANTOM}' TRAIN_ARGS='${LOCAL_TRAIN_ARGS}' WANDB_API_KEY='${WANDB_API_KEY:-}' sh /tmp/tpu_vm_train.sh"
;;
train-tpu-vm-sweep)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
args=(
--sweep-id "$SWEEP_ID"
--tpu-name "$TPU_NAME"
--tpu-zone "${TPU_ZONE:-us-central2-b}"
--tpu-project "${TPU_PROJECT:-phantom-trc}"
--tpu-repo-dir "${TPU_REPO_DIR:-/tmp/PHANTOM}"
)
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
args+=(--count "$AGENT_COUNT")
fi
WANDB_API_KEY="$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py "${args[@]}"
;; ;;
*) *)
printf '%s\n' "Unknown research command: $cmd" >&2 printf '%s\n' "Unknown research command: $cmd" >&2

View File

@@ -1,32 +0,0 @@
#!/usr/bin/env sh
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
# Authenticates with Artifact Registry using the VM's service account metadata token,
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
# Required env vars: WANDB_API_KEY, SWEEP_ID
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
set -eu
IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
AGENT_COUNT="${AGENT_COUNT:-1}"
# use VM service account — no manual key needed on the pod
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
| python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
--password-stdin https://us-central1-docker.pkg.dev
sudo docker pull "$IMAGE"
# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
# --network host lets JAX reach the other pod workers for distributed init
sudo docker run --rm \
--privileged \
--network host \
--volume /dev:/dev \
-e WANDB_API_KEY="$WANDB_API_KEY" \
-e SWEEP_ID="$SWEEP_ID" \
-e AGENT_COUNT="$AGENT_COUNT" \
"$IMAGE"

View File

@@ -1,83 +0,0 @@
#!/usr/bin/env sh
set -eu
TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
CLEANUP_LIST=true
cleanup() {
if [ "$CLEANUP_LIST" = "true" ]; then
rm -f "$FILE_LIST"
fi
}
trap cleanup EXIT
if [ ! -d "$LOCAL_REPO_DIR" ]; then
echo "local repo directory not found: $LOCAL_REPO_DIR"
exit 1
fi
if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST"
python3 - "$FILE_LIST" <<'PY'
import sys
from pathlib import Path
file_list = Path(sys.argv[1])
skip_prefixes = (
"wandb/",
".venv/",
"venv/",
"node_modules/",
".next/",
".turbo/",
"__pycache__/",
".mypy_cache/",
".pytest_cache/",
".ruff_cache/",
"paper/build/",
"tests/e2e/test-results/",
)
rows = file_list.read_text().splitlines()
kept = [
row
for row in rows
if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes)
]
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
PY
tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
else
tar \
--exclude-vcs \
--exclude=".venv" --exclude="*/.venv" \
--exclude="venv" --exclude="*/venv" \
--exclude="node_modules" --exclude="*/node_modules" \
--exclude=".next" --exclude="*/.next" \
--exclude=".turbo" --exclude="*/.turbo" \
--exclude="__pycache__" --exclude="*/__pycache__" \
--exclude=".mypy_cache" --exclude="*/.mypy_cache" \
--exclude=".pytest_cache" --exclude="*/.pytest_cache" \
--exclude=".ruff_cache" --exclude="*/.ruff_cache" \
--exclude="wandb" --exclude="*/wandb" \
--exclude="paper/build" \
--exclude="tests/e2e/test-results" \
-czf "$ARCHIVE_PATH" \
-C "$LOCAL_REPO_DIR" .
fi
gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
--command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"
rm -f "$ARCHIVE_PATH"

View File

@@ -1,211 +0,0 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import gc
import json
import os
import re
import shlex
import shutil
import subprocess
import time
import resource
from pathlib import Path
import wandb
CLI_MAP: dict[str, str] = {
"algo": "--algo",
"total_timesteps": "--total-timesteps",
"alpha": "--alpha",
"N": "--N",
"n_products": "--n-products",
"lambda_coi": "--lambda-coi",
"info_value": "--info-value",
"robust_radius": "--robust-radius",
"robust_points": "--robust-points",
"no_robust": "--no-robust",
"learning_rate": "--learning-rate",
"gamma": "--gamma",
"gae_lambda": "--gae-lambda",
"clip_range": "--clip-range",
"ent_coef": "--ent-coef",
"revenue_weight": "--revenue-weight",
"max_steps": "--max-steps",
"margin_floor": "--margin-floor",
"margin_floor_patience": "--margin-floor-patience",
"arch": "--arch",
"activation": "--activation",
"jax_num_envs": "--jax-num-envs",
"jax_num_steps": "--jax-num-steps",
"jax_num_minibatches": "--jax-num-minibatches",
"jax_update_epochs": "--jax-update-epochs",
"jax_anneal_lr": "--jax-anneal-lr",
"checkpoint_interval": "--checkpoint-interval",
"action_levels": "--action-levels",
"action_scale_low": "--action-scale-low",
"action_scale_high": "--action-scale-high",
}
def _to_cli_args(cfg: dict) -> str:
parts: list[str] = ["--jax", "--no-wandb"]
for key, flag in CLI_MAP.items():
if key not in cfg:
continue
value = cfg[key]
if value is None:
continue
if isinstance(value, bool):
if key == "jax_anneal_lr":
parts.extend([flag, "true" if value else "false"])
elif value:
parts.append(flag)
continue
parts.extend([flag, str(value)])
return " ".join(shlex.quote(p) for p in parts)
_SENTINEL = "PHANTOM_METRICS:"
def _raise_nofile_limit(min_soft: int = 8192) -> None:
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
target = min(hard, max(soft, min_soft))
if target > soft:
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
except Exception:
return
def _extract_metrics(output: str) -> dict:
# fast path: look for the dedicated sentinel line emitted by run_local
for line in output.splitlines():
if line.startswith(_SENTINEL):
try:
return json.loads(line[len(_SENTINEL) :])
except Exception:
break
# fallback: scan for any JSON block containing eval/sweep keys;
# use greedy match to capture the largest possible block first
for block in re.findall(r"\{[^{}]*\}", output):
try:
obj = json.loads(block)
except Exception:
continue
if isinstance(obj, dict) and (
"objective/score" in obj
or "eval/reward_mean" in obj
or "sweep/score" in obj
):
return obj
return {}
def main() -> None:
_raise_nofile_limit()
p = argparse.ArgumentParser(
description="Run W&B sweep where each trial uses full TPU pod"
)
p.add_argument("--sweep-id", required=True)
p.add_argument("--tpu-name", required=True)
p.add_argument("--tpu-zone", default="us-central2-b")
p.add_argument("--tpu-project", default="phantom-trc")
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
p.add_argument("--count", type=int, default=0)
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
args = p.parse_args()
workdir = Path(args.workdir).resolve()
env = os.environ.copy()
wandb_root = workdir / ".wandb-agent"
wandb_root.mkdir(parents=True, exist_ok=True)
prepare_cmd = [
"make",
"train.tpu.vm.prepare",
f"TPU_NAME={args.tpu_name}",
f"TPU_ZONE={args.tpu_zone}",
f"TPU_PROJECT={args.tpu_project}",
f"TPU_REPO_DIR={args.tpu_repo_dir}",
]
prepare = subprocess.run(
prepare_cmd,
cwd=workdir,
env=env,
text=True,
capture_output=False,
check=False,
)
if prepare.returncode != 0:
raise RuntimeError("Failed to prepare TPU workers for sweep")
def run_trial() -> None:
run = None
trial_wandb_dir = wandb_root / f"trial-{time.time_ns()}"
trial_wandb_dir.mkdir(parents=True, exist_ok=True)
try:
run = wandb.init(dir=str(trial_wandb_dir))
cfg = dict(wandb.config)
cli_args = _to_cli_args(cfg)
env_trial = dict(env)
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
env_trial["WANDB_DIR"] = str(trial_wandb_dir)
env_trial["WANDB_CACHE_DIR"] = str(trial_wandb_dir / "cache")
env_trial["WANDB_DATA_DIR"] = str(trial_wandb_dir / "data")
cmd = [
"make",
"train.tpu.vm.run",
f"TPU_NAME={args.tpu_name}",
f"TPU_ZONE={args.tpu_zone}",
f"TPU_PROJECT={args.tpu_project}",
f"TPU_REPO_DIR={args.tpu_repo_dir}",
]
proc = subprocess.run(
cmd,
cwd=workdir,
env=env_trial,
text=True,
capture_output=True,
check=False,
)
if proc.stdout:
print(proc.stdout)
if proc.stderr:
print(proc.stderr)
if proc.returncode != 0:
if run is not None:
run.summary["runner/exit_code"] = proc.returncode
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
metrics = _extract_metrics(proc.stdout)
if metrics:
wandb.log(metrics)
for k, v in metrics.items():
run.summary[k] = v
run.summary["runner/exit_code"] = 0
except Exception:
time.sleep(2)
raise
finally:
if run is not None and wandb.run is not None:
wandb.finish()
shutil.rmtree(trial_wandb_dir, ignore_errors=True)
gc.collect()
wandb.agent(
args.sweep_id,
function=run_trial,
count=args.count if args.count > 0 else None,
)
if __name__ == "__main__":
main()

View File

@@ -1,43 +0,0 @@
#!/usr/bin/env sh
set -eu
REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"
if [ ! -d "$REPO_DIR" ]; then
echo "repo directory not found: $REPO_DIR"
exit 1
fi
cd "$REPO_DIR"
if [ -d "wandb" ]; then
rm -rf wandb
fi
# keep install idempotent and avoid re-installing jax/libtpu each run
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
$PYTHON_BIN -m pip install -r requirements.txt
fi
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
if [ -f "engine/jax/requirements.txt" ]; then
$PYTHON_BIN -m pip install -r engine/jax/requirements.txt
fi
$PYTHON_BIN -m pip install -U $EXTRA_PIP
fi
if [ -n "${WANDB_API_KEY:-}" ]; then
if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
$PYTHON_BIN -m pip install -U wandb
fi
fi
if [ -n "${WANDB_API_KEY:-}" ]; then
export WANDB_API_KEY
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
fi
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb