diff --git a/Makefile b/Makefile index 22a67db..6e0db89 100644 --- a/Makefile +++ b/Makefile @@ -27,10 +27,6 @@ AGENT_LOOP ?= 1 RETRY_SECONDS ?= 20 TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer -TPU_NAME ?= -TPU_ZONE ?= us-central2-b -TPU_PROJECT ?= phantom-trc -TPU_REPO_DIR ?= /tmp/PHANTOM SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a @@ -38,7 +34,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || .PHONY: help help: - @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines" + @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | stats.lines" @echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish" @echo "" @echo "Build general public version:" @@ -137,26 +133,6 @@ wordcount: docker.train.publish: @TRAIN_IMAGE_REF="$(TRAIN_IMAGE_REF)" $(NX) run research:docker-train-publish -.PHONY: train.tpu.pod -train.tpu.pod: - @TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-tpu-pod - -.PHONY: train.tpu.vm.prepare -train.tpu.vm.prepare: - @TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" $(NX) run research:train-tpu-vm-prepare - -.PHONY: train.tpu.vm.run -train.tpu.vm.run: - @TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train-tpu-vm-run - -.PHONY: train.tpu.vm -train.tpu.vm: - @TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train-tpu-vm - -.PHONY: train.tpu.vm.sweep -train.tpu.vm.sweep: - @TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-tpu-vm-sweep - .PHONY: backend.server backend.provider backend.worker platform.up platform.down platform.logs backend.server: @$(NX) run backend-server:dev diff --git a/docker/Trainer.dockerfile b/docker/Trainer.dockerfile index df50fed..f9cc73f 100644 --- a/docker/Trainer.dockerfile +++ b/docker/Trainer.dockerfile @@ -7,36 +7,9 @@ WORKDIR /app COPY docker/trainer.requirements.txt /tmp/requirements.txt RUN pip install --no-cache-dir -r /tmp/requirements.txt -# Optional for JAX-on-GPU workflows. -ARG INSTALL_JAX_GPU=false -RUN if [ "${INSTALL_JAX_GPU}" = "true" ]; then \ - pip install --no-cache-dir "jax[cuda12]==0.4.30" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html; \ - fi - COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint COPY engine /app/engine -ENV PYTHONPATH=/app \ - XLA_PYTHON_CLIENT_PREALLOCATE=false - -ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"] - - -FROM python:3.11-slim AS tpu - -WORKDIR /app - -COPY docker/trainer.requirements.txt /tmp/requirements.txt -RUN pip install --no-cache-dir -r /tmp/requirements.txt - -RUN pip install --no-cache-dir "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - -COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint -COPY engine /app/engine - -ENV PYTHONPATH=/app \ - PHANTOM_USE_JAX=1 \ - PHANTOM_DEFAULT_AGENT_ARGS="--jax" \ - XLA_PYTHON_CLIENT_PREALLOCATE=false +ENV PYTHONPATH=/app ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"] diff --git a/docker/trainer.requirements.txt b/docker/trainer.requirements.txt index c47ed11..2768cae 100644 --- a/docker/trainer.requirements.txt +++ b/docker/trainer.requirements.txt @@ -5,9 +5,3 @@ gymnasium>=0.29.0 stable-baselines3>=2.2.0 tensorboard>=2.15.0 wandb>=0.17.0 -tensorflow-probability==0.24.0 -flax==0.10.7 -optax==0.2.7 -distrax==0.1.5 -orbax-checkpoint==0.11.32 -chex==0.1.90 diff --git a/engine/backends/__init__.py b/engine/backends/__init__.py index 014450a..95c4786 100644 --- a/engine/backends/__init__.py +++ b/engine/backends/__init__.py @@ -1 +1 @@ -__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"] +__all__ = ["evaluate", "make_env", "train_qtable", "train_sb3"] diff --git a/engine/backends/jax.py b/engine/backends/jax.py deleted file mode 100644 index 980c01f..0000000 --- a/engine/backends/jax.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from typing import Any, Mapping - -from ..jax import JAX_AVAILABLE - - -def train_jax_backend( - cfg: Mapping[str, Any], -) -> tuple[dict[str, Any], dict[str, float | int | str]]: - if not JAX_AVAILABLE: - raise ImportError( - "JAX backend requested but JAX is not installed. " - "Install engine/jax/requirements.txt and jax[tpu] for TPU runs." - ) - from ..jax.train import train_jax - - return train_jax(dict(cfg)) diff --git a/engine/backends/qtable.py b/engine/backends/qtable.py index 9a6e3fe..754cfa8 100644 --- a/engine/backends/qtable.py +++ b/engine/backends/qtable.py @@ -7,7 +7,9 @@ import numpy as np from .common import evaluate, make_env -def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]: +def train_qtable( + cfg: Mapping[str, Any], +) -> tuple[object, dict[str, Any]]: from ..lib.discrete import EventQTable np.random.seed(int(cfg["seed"])) @@ -26,8 +28,19 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int] total_revenue = 0.0 steps = 0 epsilon = float(cfg["eps_start"]) + log_freq = max(1, int(cfg.get("log_freq", 100))) obs, _ = env.reset(seed=int(cfg["seed"])) + interval_sums = { + "reward": 0.0, + "revenue": 0.0, + "agent_prob": 0.0, + "alpha_adv": 0.0, + "coi_leakage": 0.0, + } + interval_count = 0 + train_events: list[dict[str, float | int]] = [] + for _ in range(int(cfg["total_timesteps"])): action, state = agent.act(obs, epsilon) nxt, reward, term, trunc, info = env.step(action) @@ -35,18 +48,57 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int] agent.update(state, action, float(reward), agent.encode(nxt), done) total_reward += float(reward) - total_revenue += float(info.get("economics", {}).get("revenue", 0.0)) + revenue = float(info.get("economics", {}).get("revenue", 0.0)) + total_revenue += revenue steps += 1 + interval_sums["reward"] += float(reward) + interval_sums["revenue"] += revenue + interval_sums["agent_prob"] += float(info.get("agent_prob", 0.0)) + interval_sums["alpha_adv"] += float(info.get("alpha_adv", 0.0)) + interval_sums["coi_leakage"] += float(info.get("coi_leakage", 0.0)) + interval_count += 1 + + if steps % log_freq == 0 and interval_count > 0: + denom = float(interval_count) + train_events.append( + { + "train/reward_mean": interval_sums["reward"] / denom, + "train/revenue_mean": interval_sums["revenue"] / denom, + "train/agent_prob": interval_sums["agent_prob"] / denom, + "train/alpha_adv": interval_sums["alpha_adv"] / denom, + "train/coi_leakage": interval_sums["coi_leakage"] / denom, + "train/epsilon": float(epsilon), + "train/global_step": int(steps), + } + ) + interval_sums = {key: 0.0 for key in interval_sums} + interval_count = 0 + epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"])) obs = env.reset()[0] if done else nxt - metrics: dict[str, float | int] = { + if interval_count > 0: + denom = float(interval_count) + train_events.append( + { + "train/reward_mean": interval_sums["reward"] / denom, + "train/revenue_mean": interval_sums["revenue"] / denom, + "train/agent_prob": interval_sums["agent_prob"] / denom, + "train/alpha_adv": interval_sums["alpha_adv"] / denom, + "train/coi_leakage": interval_sums["coi_leakage"] / denom, + "train/epsilon": float(epsilon), + "train/global_step": int(steps), + } + ) + + metrics: dict[str, Any] = { "train/reward_mean": total_reward / max(steps, 1), "train/revenue_mean": total_revenue / max(steps, 1), "train/epsilon": float(epsilon), "train/global_step": int(cfg["total_timesteps"]), } metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"]))) + metrics["_train_events"] = train_events env.close() eval_env.close() diff --git a/engine/backends/sb3.py b/engine/backends/sb3.py index ad17e0b..52dbd87 100644 --- a/engine/backends/sb3.py +++ b/engine/backends/sb3.py @@ -4,9 +4,7 @@ import json from pathlib import Path from typing import Any, Mapping -from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback -from ..telemetry.wandb import get_wandb_module -from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint +from ..lib.callbacks import MetricsCallback from .common import evaluate, make_env @@ -52,21 +50,6 @@ def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]: return kwargs -def _sb3_model_cls(algo: str): - try: - from stable_baselines3 import A2C, DQN, PPO - except ImportError as exc: - raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc - - if algo == "ppo": - return PPO - if algo == "a2c": - return A2C - if algo == "dqn": - return DQN - raise ValueError(f"unsupported algo '{algo}'") - - def build_model(cfg: Mapping[str, Any], env: Any): try: from stable_baselines3 import A2C, DQN, PPO @@ -132,29 +115,7 @@ def build_model(cfg: Mapping[str, Any], env: Any): raise ValueError(f"unsupported algo '{algo}'") -def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any): - wandb = get_wandb_module() - if wandb is None or wandb.run is None: - return model - - sweep_id = getattr(wandb.run, "sweep_id", None) - artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id) - checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip" - restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file) - if restored is None: - return model - - checkpoint_path, metadata = restored - resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load( - checkpoint_path.as_posix(), - env=env, - ) - resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0))) - resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step) - return resumed - - -def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]: +def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]: try: from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.monitor import Monitor @@ -182,15 +143,10 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s except Exception: pass - model = _maybe_resume_model(cfg, env, model) - - callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))] - callbacks.append( - CheckpointArtifactCallback( - dict(cfg), - interval=int(cfg.get("checkpoint_interval", 10_000)), - ) + metrics_callback = MetricsCallback( + log_histograms=False, log_freq=int(cfg["log_freq"]) ) + callbacks = [metrics_callback] callbacks.append( EvalCallback( eval_env, @@ -215,13 +171,14 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s model_path = model_dir / f"phantom_{cfg['algo']}" model.save(str(model_path)) - metrics: dict[str, float | int | str] = evaluate( + metrics: dict[str, Any] = evaluate( model, eval_env, int(cfg["eval_episodes"]), ) metrics["train/global_step"] = int(model.num_timesteps) metrics["model/path"] = str(model_path.with_suffix(".zip")) + metrics["_train_events"] = list(metrics_callback.events) env.close() eval_env.close() diff --git a/engine/jax/__init__.py b/engine/jax/__init__.py deleted file mode 100644 index 8b6f740..0000000 --- a/engine/jax/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""JAX-compatible training and environment modules for PHANTOM.""" - -from __future__ import annotations - -try: - import jax # noqa: F401 - import jax.numpy as jnp # noqa: F401 - - JAX_AVAILABLE = True -except ImportError: - JAX_AVAILABLE = False - -__all__ = ["JAX_AVAILABLE"] diff --git a/engine/jax/checkpoint.py b/engine/jax/checkpoint.py deleted file mode 100644 index c75c6bc..0000000 --- a/engine/jax/checkpoint.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Orbax checkpoint helpers for JAX training runs.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any - -try: - import orbax.checkpoint as ocp - - HAS_ORBAX = True -except ImportError: - HAS_ORBAX = False - - -def _require_orbax() -> None: - if not HAS_ORBAX: - raise ImportError( - "orbax-checkpoint is required for checkpoint support. " - "Install engine/jax/requirements.txt first." - ) - - -def create_manager(directory: str | Path, max_to_keep: int = 5): - _require_orbax() - root = Path(directory) - root.mkdir(parents=True, exist_ok=True) - options = ocp.CheckpointManagerOptions( - max_to_keep=max(1, int(max_to_keep)), create=True - ) - return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options) - - -def save(manager, *, step: int, payload: Any) -> bool: - _require_orbax() - return bool(manager.save(int(step), payload)) - - -def latest_step(manager) -> int | None: - _require_orbax() - return manager.latest_step() - - -def restore(manager, *, target: Any, step: int | None = None) -> Any: - _require_orbax() - step_to_restore = manager.latest_step() if step is None else int(step) - if step_to_restore is None: - return target - return manager.restore(step_to_restore, items=target) diff --git a/engine/jax/env.py b/engine/jax/env.py deleted file mode 100644 index 8ecafd1..0000000 --- a/engine/jax/env.py +++ /dev/null @@ -1,304 +0,0 @@ -"""JAX-native PHANTOM environment with robust contamination step.""" - -from __future__ import annotations - -from typing import NamedTuple - -try: - import jax - import jax.numpy as jnp -except ImportError as exc: # pragma: no cover - raise ImportError("engine.jax.env requires JAX") from exc - -from .primitives import ( - _sample_sessions_jax, - agent_probability_from_kl, - batch_kl, - compute_session_transitions, - load_transition_data, - purchase_flags, - reward_with_coi_penalty, - revenue_from_demand, - weighted_demand, -) - - -class EnvParams(NamedTuple): - n_products: int - n_sessions: int - max_episode_steps: int - max_session_steps: int - price_low: float - price_high: float - lambda_coi: float - info_value: float - eta_ux: float - robust_radius: float - margin_floor: float - margin_floor_patience: int - action_scales: jax.Array - alpha_nominal: float - alpha_candidates: jax.Array - human_T: jax.Array - agent_T: jax.Array - terminal_mask: jax.Array - purchase_mask: jax.Array - event_weights: jax.Array - start_idx: int - term_idx: int - - -class EnvState(NamedTuple): - prices: jax.Array - demand: jax.Array - step_count: jax.Array - low_margin_streak: jax.Array - last_agent_prob: jax.Array - last_alpha_adv: jax.Array - - -class CandidateEval(NamedTuple): - reward: jax.Array - revenue: jax.Array - demand: jax.Array - agent_prob: jax.Array - leakage: jax.Array - discount: jax.Array - ux_penalty: jax.Array - n_purchases: jax.Array - n_agents: jax.Array - - -def make_env_params( - *, - n_products: int, - alpha: float, - n_sessions: int, - lambda_coi: float, - robust_radius: float, - robust_points: int, - info_value: float, - eta_ux: float = 0.5, - action_levels: int, - action_scale_low: float, - action_scale_high: float, - price_low: float, - price_high: float, - max_episode_steps: int, - max_session_steps: int = 40, - margin_floor: float = 0.05, - margin_floor_patience: int = 5, - prefer_behavior_data: bool = True, -) -> EnvParams: - transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax() - if robust_radius <= 0.0 or robust_points <= 1: - alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32) - else: - lo = max(0.0, float(alpha) - float(robust_radius)) - hi = min(1.0, float(alpha) + float(robust_radius)) - alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32) - - action_scales = jnp.linspace( - float(action_scale_low), - float(action_scale_high), - int(action_levels), - dtype=jnp.float32, - ) - return EnvParams( - n_products=int(n_products), - n_sessions=int(n_sessions), - max_episode_steps=int(max_episode_steps), - max_session_steps=int(max_session_steps), - price_low=float(price_low), - price_high=float(price_high), - lambda_coi=float(lambda_coi), - info_value=float(info_value), - eta_ux=float(eta_ux), - robust_radius=float(robust_radius), - margin_floor=float(margin_floor), - margin_floor_patience=int(margin_floor_patience), - action_scales=action_scales, - alpha_nominal=float(alpha), - alpha_candidates=alpha_candidates, - human_T=jnp.asarray(transition.human_T), - agent_T=jnp.asarray(transition.agent_T), - terminal_mask=jnp.asarray(transition.terminal_mask), - purchase_mask=jnp.asarray(transition.purchase_mask), - event_weights=jnp.asarray(transition.event_weights), - start_idx=int(transition.start_idx), - term_idx=int(transition.term_idx), - ) - - -def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array: - return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)]) - - -def _decode_action( - prices: jax.Array, action: jax.Array, params: EnvParams -) -> jax.Array: - idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1) - scale = params.action_scales[idx] - next_prices = prices * scale - return jnp.clip(next_prices, params.price_low, params.price_high) - - -def _evaluate_candidate( - key: jax.Array, - alpha_candidate: jax.Array, - prices: jax.Array, - ux_volatility: jax.Array, - params: EnvParams, -) -> CandidateEval: - states, products, actors, lengths = _sample_sessions_jax( - key, - params.human_T, - params.agent_T, - params.terminal_mask, - params.start_idx, - params.term_idx, - alpha_candidate, - params.n_products, - params.n_sessions, - params.max_session_steps, - int(params.human_T.shape[0]), - ) - session_trans = compute_session_transitions( - states, lengths, int(params.human_T.shape[0]) - ) - delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T) - agent_probs = agent_probability_from_kl(delta_h, delta_a) - agent_prob = jnp.mean(agent_probs) - - demand = weighted_demand(states, products, params.n_products, params.event_weights) - revenue = revenue_from_demand(prices, demand) - reward, leakage, discount, ux_penalty = reward_with_coi_penalty( - revenue, - agent_prob, - params.lambda_coi, - params.info_value, - params.eta_ux, - ux_volatility, - ) - purchases = purchase_flags(states, params.purchase_mask) - return CandidateEval( - reward=reward, - revenue=revenue, - demand=demand, - agent_prob=agent_prob, - leakage=leakage, - discount=discount, - ux_penalty=ux_penalty, - n_purchases=jnp.sum(purchases.astype(jnp.float32)), - n_agents=jnp.sum(actors.astype(jnp.float32)), - ) - - -def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]: - prices = jax.random.uniform( - key, - shape=(params.n_products,), - minval=params.price_low, - maxval=params.price_high, - ) - demand = jnp.zeros((params.n_products,), dtype=jnp.float32) - state = EnvState( - prices=prices, - demand=demand, - step_count=jnp.asarray(0, dtype=jnp.int32), - low_margin_streak=jnp.asarray(0, dtype=jnp.int32), - last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32), - last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32), - ) - return _flatten_obs(demand, prices), state - - -def step_env( - key: jax.Array, - state: EnvState, - action: jax.Array, - params: EnvParams, -) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]: - prices = _decode_action(state.prices, action, params) - - baseline = jnp.maximum(state.prices, 1.0) - ux_volatility = jnp.where( - state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0 - ) - - n_candidates = params.alpha_candidates.shape[0] - cand_keys = jax.random.split(key, n_candidates) - evals = jax.vmap( - lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params), - in_axes=(0, 0), - )(cand_keys, params.alpha_candidates) - idx = jnp.argmin(evals.reward) - - demand = evals.demand[idx] - reward = evals.reward[idx] - revenue = evals.revenue[idx] - agent_prob = evals.agent_prob[idx] - leakage = evals.leakage[idx] - discount = evals.discount[idx] - ux_penalty = evals.ux_penalty[idx] - n_purchases = evals.n_purchases[idx] - n_agents = evals.n_agents[idx] - alpha_adv = params.alpha_candidates[idx] - - step_count = state.step_count + 1 - avg_price = jnp.maximum(jnp.mean(prices), 1e-6) - avg_margin = (avg_price - params.price_low) / avg_price - next_streak = jnp.where( - avg_margin < params.margin_floor, state.low_margin_streak + 1, 0 - ) - - margin_collapsed = next_streak >= params.margin_floor_patience - done = (step_count >= params.max_episode_steps) | margin_collapsed - - next_state = EnvState( - prices=prices, - demand=demand, - step_count=step_count, - low_margin_streak=next_streak, - last_agent_prob=agent_prob, - last_alpha_adv=alpha_adv, - ) - obs = _flatten_obs(demand, prices) - info = { - "revenue": revenue, - "agent_prob": agent_prob, - "alpha_adv": alpha_adv, - "coi_leakage": leakage, - "coi_discount": discount, - "ux_penalty": ux_penalty, - "volatility": ux_volatility, - "n_purchases": n_purchases, - "n_agents": n_agents, - "avg_margin": avg_margin, - } - return obs, next_state, reward, done, info - - -class PHANTOMJAXEnv: - def __init__(self, params: EnvParams): - self.params = params - - def reset(self, key: jax.Array, params: EnvParams | None = None): - return reset_env(key, self.params if params is None else params) - - def step( - self, - key: jax.Array, - state: EnvState, - action: jax.Array, - params: EnvParams | None = None, - ): - return step_env(key, state, action, self.params if params is None else params) - - def action_space_n(self, params: EnvParams | None = None) -> int: - p = self.params if params is None else params - return int(p.action_scales.shape[0]) - - def observation_dim(self, params: EnvParams | None = None) -> int: - p = self.params if params is None else params - return int(p.n_products * 2) diff --git a/engine/jax/primitives.py b/engine/jax/primitives.py deleted file mode 100644 index e638b32..0000000 --- a/engine/jax/primitives.py +++ /dev/null @@ -1,501 +0,0 @@ -"""JAX-compatible primitives for PHANTOM session simulation and separability.""" - -from __future__ import annotations - -from dataclasses import dataclass -from functools import partial -from typing import Mapping - -import numpy as np - -try: - import jax - import jax.numpy as jnp - - JAX_AVAILABLE = True -except ImportError: - jax = None # type: ignore[assignment] - jnp = np # type: ignore[assignment] - JAX_AVAILABLE = False - - -STATE_START_KEYS = ("session_start", "start") -TERMINAL_EVENT_TOKENS = ( - "session_end", - "end", - "purchase_complete", - "checkout_start", - "checkout", -) -PURCHASE_EVENT_TOKENS = ( - "purchase_complete", - "purchase", - "checkout_start", - "checkout", -) - -CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5} -ACTION_CATEGORIES = { - "cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"}, - "dwell": { - "hover_title", - "hover_paragraph", - "hover_link", - "hover_over_title", - "hover_over_paragraph", - "hover_over_link", - "hover_over_button", - }, - "nav": { - "page_view", - "view_item", - "view", - "learn_more", - "learn_more_about_item", - "view_item_page", - "session_start", - }, - "filter": { - "search", - "filter_date", - "filter_price", - "sort", - "filter_for_date", - "filter_for_price", - "filter_for_amenities", - "sort_change", - }, -} -DEFAULT_ACTION_WEIGHTS = { - action: CATEGORY_WEIGHTS[group] - for group, actions in ACTION_CATEGORIES.items() - for action in actions -} - - -@dataclass(frozen=True) -class TransitionData: - """Dense transition kernels and per-state metadata.""" - - human_T: np.ndarray - agent_T: np.ndarray - terminal_mask: np.ndarray - purchase_mask: np.ndarray - event_weights: np.ndarray - event_names: tuple[str, ...] - start_idx: int - term_idx: int - - def to_jax(self) -> "TransitionData": - if not JAX_AVAILABLE: - return self - return TransitionData( - human_T=jnp.asarray(self.human_T), - agent_T=jnp.asarray(self.agent_T), - terminal_mask=jnp.asarray(self.terminal_mask), - purchase_mask=jnp.asarray(self.purchase_mask), - event_weights=jnp.asarray(self.event_weights), - event_names=self.event_names, - start_idx=int(self.start_idx), - term_idx=int(self.term_idx), - ) - - -@dataclass(frozen=True) -class SessionBatch: - states: np.ndarray - products: np.ndarray - actors: np.ndarray - lengths: np.ndarray - - -def _event_weight(name: str) -> float: - if name in DEFAULT_ACTION_WEIGHTS: - return float(DEFAULT_ACTION_WEIGHTS[name]) - if name.startswith("hover"): - return float(CATEGORY_WEIGHTS["dwell"]) - if name.startswith("filter") or name in {"search", "sort", "sort_change"}: - return float(CATEGORY_WEIGHTS["filter"]) - if name.startswith("add") or name in { - "checkout", - "checkout_start", - "purchase", - "remove_item", - "purchase_complete", - }: - return float(CATEGORY_WEIGHTS["cart"]) - if any(token in name for token in TERMINAL_EVENT_TOKENS): - return 0.0 - return float(CATEGORY_WEIGHTS["nav"]) - - -def _is_terminal(name: str) -> bool: - return any(token in name for token in TERMINAL_EVENT_TOKENS) - - -def _is_purchase(name: str) -> bool: - return any(token in name for token in PURCHASE_EVENT_TOKENS) - - -def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]: - names: set[str] = set() - for trans in transitions: - for src, dsts in trans.items(): - names.add(src) - names.update(dsts.keys()) - names.discard("__terminal__") - return tuple(sorted(names)) - - -def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray: - row_sums = matrix.sum(axis=1, keepdims=True) - dead_rows = np.isclose(row_sums.squeeze(-1), 0.0) - if np.any(dead_rows): - matrix[dead_rows] = 0.0 - matrix[dead_rows, term_idx] = 1.0 - row_sums = matrix.sum(axis=1, keepdims=True) - return matrix / np.maximum(row_sums, 1e-8) - - -def _dense_from_dict( - transitions: Mapping[str, Mapping[str, float]], - event_to_idx: Mapping[str, int], - term_idx: int, -) -> np.ndarray: - n_states = len(event_to_idx) - matrix = np.zeros((n_states, n_states), dtype=np.float32) - for src, dsts in transitions.items(): - i = event_to_idx.get(src) - if i is None: - continue - for dst, prob in dsts.items(): - j = event_to_idx.get(dst) - if j is None: - continue - matrix[i, j] += float(prob) - return _normalize_rows(matrix, term_idx) - - -def compile_transition_data( - human_transitions: Mapping[str, Mapping[str, float]], - agent_transitions: Mapping[str, Mapping[str, float]], -) -> TransitionData: - event_names = _collect_events(human_transitions, agent_transitions) - if not event_names: - return fallback_transition_data() - - event_names = tuple([*event_names, "__terminal__"]) - term_idx = len(event_names) - 1 - event_to_idx = {name: i for i, name in enumerate(event_names)} - - human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx) - agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx) - - terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool) - purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool) - event_weights = np.array( - [_event_weight(name) for name in event_names], dtype=np.float32 - ) - - terminal_mask[term_idx] = True - - for idx, is_term in enumerate(terminal_mask): - if not is_term: - continue - human_T[idx] = 0.0 - agent_T[idx] = 0.0 - human_T[idx, idx] = 1.0 - agent_T[idx, idx] = 1.0 - - start_idx = 0 - for key in STATE_START_KEYS: - if key in event_to_idx: - start_idx = int(event_to_idx[key]) - break - - return TransitionData( - human_T=human_T, - agent_T=agent_T, - terminal_mask=terminal_mask, - purchase_mask=purchase_mask, - event_weights=event_weights, - event_names=event_names, - start_idx=start_idx, - term_idx=term_idx, - ) - - -def fallback_transition_data() -> TransitionData: - human = { - "session_start": { - "page_view": 0.80, - "view_item_page": 0.15, - "session_end": 0.05, - }, - "page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20}, - "view_item_page": { - "learn_more_about_item": 0.40, - "add_item_to_cart": 0.28, - "session_end": 0.32, - }, - "learn_more_about_item": { - "add_item_to_cart": 0.50, - "view_item_page": 0.30, - "session_end": 0.20, - }, - "add_item_to_cart": { - "checkout_start": 0.58, - "view_item_page": 0.24, - "session_end": 0.18, - }, - "checkout_start": {"purchase_complete": 0.70, "session_end": 0.30}, - "purchase_complete": {"session_end": 1.0}, - } - agent = { - "session_start": { - "page_view": 0.90, - "view_item_page": 0.08, - "session_end": 0.02, - }, - "page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25}, - "view_item_page": { - "learn_more_about_item": 0.55, - "add_item_to_cart": 0.15, - "session_end": 0.30, - }, - "learn_more_about_item": { - "view_item_page": 0.45, - "add_item_to_cart": 0.20, - "session_end": 0.35, - }, - "add_item_to_cart": { - "checkout_start": 0.42, - "view_item_page": 0.28, - "session_end": 0.30, - }, - "checkout_start": {"purchase_complete": 0.52, "session_end": 0.48}, - "purchase_complete": {"session_end": 1.0}, - } - return compile_transition_data(human, agent) - - -def load_transition_data(prefer_data: bool = True) -> TransitionData: - if not prefer_data: - return fallback_transition_data() - try: - from ..lib.behavior import get_transition_models - - human_trans, agent_trans = get_transition_models() - return compile_transition_data(human_trans, agent_trans) - except Exception: - return fallback_transition_data() - - -if JAX_AVAILABLE: - - @partial(jax.jit, static_argnums=(8, 9, 10)) - def _sample_sessions_jax( - key: jax.Array, - human_T: jax.Array, - agent_T: jax.Array, - terminal_mask: jax.Array, - start_idx: int, - term_idx: int, - alpha: float, - n_products: int, - n_sessions: int, - max_steps: int, - n_states: int, - ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: - k_actor, k_product, k_step = jax.random.split(key, 3) - start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32) - term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32) - actor_draw = jax.random.uniform(k_actor, (n_sessions,)) - actors = (actor_draw < alpha).astype(jnp.int32) - products = jax.random.randint( - k_product, (n_sessions,), 0, n_products, dtype=jnp.int32 - ) - - active_init = jnp.ones((n_sessions,), dtype=jnp.bool_) - state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32) - - def _scan_step(carry, _): - states, active, rng = carry - rng, k = jax.random.split(rng) - probs_h = human_T[states] - probs_a = agent_T[states] - probs = jnp.where(actors[:, None] == 0, probs_h, probs_a) - next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1) - next_state = jnp.where(active, next_state, term_idx_i32) - emitted = jnp.where(active, next_state, -1) - is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)] - next_active = active & (~is_terminal) - carry_states = jnp.where(next_active, next_state, term_idx_i32) - return (carry_states, next_active, rng), emitted - - _, state_t = jax.lax.scan( - _scan_step, (state_init, active_init, k_step), None, length=max_steps - ) - states = state_t.T - lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32) - return states, products, actors, lengths - - -def sample_sessions( - key, - transition_data: TransitionData, - alpha: float, - n_products: int, - n_sessions: int, - max_steps: int, -) -> SessionBatch: - if JAX_AVAILABLE: - td = transition_data.to_jax() - states, products, actors, lengths = _sample_sessions_jax( - key, - td.human_T, - td.agent_T, - td.terminal_mask, - int(td.start_idx), - int(td.term_idx), - float(alpha), - int(n_products), - int(n_sessions), - int(max_steps), - int(td.human_T.shape[0]), - ) - return SessionBatch( - states=states, products=products, actors=actors, lengths=lengths - ) - - rng = np.random.default_rng(int(np.asarray(key).reshape(-1)[0])) - n_states = transition_data.human_T.shape[0] - products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32) - actors = (rng.random(size=n_sessions) < alpha).astype(np.int32) - states = np.full((n_sessions, max_steps), -1, dtype=np.int32) - lengths = np.zeros((n_sessions,), dtype=np.int32) - for i in range(n_sessions): - current = int(transition_data.start_idx) - mat = transition_data.agent_T if actors[i] == 1 else transition_data.human_T - for t in range(max_steps): - nxt = int(rng.choice(n_states, p=mat[current])) - states[i, t] = nxt - if transition_data.terminal_mask[nxt]: - lengths[i] = t + 1 - break - current = nxt - if lengths[i] == 0: - lengths[i] = max_steps - return SessionBatch( - states=states, products=products, actors=actors, lengths=lengths - ) - - -if JAX_AVAILABLE: - - @partial(jax.jit, static_argnums=(2,)) - def compute_session_transitions(states, lengths, n_states: int): - src = states[:, :-1] - dst = states[:, 1:] - time_idx = jnp.arange(src.shape[1])[None, :] - valid = (src >= 0) & (dst >= 0) & (time_idx < (lengths[:, None] - 1)) - src_clip = jnp.clip(src, 0, n_states - 1) - dst_clip = jnp.clip(dst, 0, n_states - 1) - src_oh = jax.nn.one_hot(src_clip, n_states) - dst_oh = jax.nn.one_hot(dst_clip, n_states) - counts = jnp.einsum( - "nti,ntj,nt->nij", src_oh, dst_oh, valid.astype(jnp.float32) - ) - row_sums = jnp.sum(counts, axis=-1, keepdims=True) - return counts / (row_sums + 1e-10) - - -else: - - def compute_session_transitions(states, lengths, n_states: int): - trans = np.zeros((states.shape[0], n_states, n_states), dtype=np.float32) - for i in range(states.shape[0]): - for t in range(max(int(lengths[i]) - 1, 0)): - s = int(states[i, t]) - d = int(states[i, t + 1]) - if s >= 0 and d >= 0: - trans[i, s, d] += 1.0 - row_sums = trans.sum(axis=-1, keepdims=True) - return trans / (row_sums + 1e-10) - - -def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10): - p = P + eps - p = p / jnp.sum(p, axis=-1, keepdims=True) - qh = Q_human[None, ...] + eps - qa = Q_agent[None, ...] + eps - delta_h = jnp.sum(p * jnp.log(p / qh), axis=(1, 2)) - delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2)) - return delta_h, delta_a - - -if JAX_AVAILABLE: - batch_kl = jax.jit(batch_kl) - - -def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0): - t = jnp.maximum(float(temperature), 1e-6) - exp_h = jnp.exp(-delta_h / t) - exp_a = jnp.exp(-delta_a / t) - return exp_a / (exp_h + exp_a + 1e-10) - - -def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0): - logits = beta * (delta_h - delta_a) - return 1.0 / (1.0 + jnp.exp(-logits)) - - -def weighted_demand(states, products, n_products: int, event_weights): - valid = states >= 0 - state_clip = jnp.clip(states, 0, event_weights.shape[0] - 1) - weights = event_weights[state_clip] * valid - per_session = jnp.sum(weights, axis=1) - demand = jnp.zeros((n_products,), dtype=jnp.float32) - demand = demand.at[products].add(per_session) - total = jnp.sum(demand) - return jnp.where(total > 0.0, (demand / total) * 100.0, demand) - - -if JAX_AVAILABLE: - weighted_demand = jax.jit(weighted_demand, static_argnums=(2,)) - - -def purchase_flags(states, purchase_mask): - state_clip = jnp.clip(states, 0, purchase_mask.shape[0] - 1) - hits = purchase_mask[state_clip] & (states >= 0) - return jnp.any(hits, axis=1) - - -if JAX_AVAILABLE: - purchase_flags = jax.jit(purchase_flags) - - -def revenue_from_demand(prices, demand): - return jnp.dot(prices, demand) - - -if JAX_AVAILABLE: - revenue_from_demand = jax.jit(revenue_from_demand) - - -def reward_with_coi_penalty( - revenue, - agent_prob: float, - lambda_coi: float, - info_value: float, - eta_ux: float = 0.0, - ux_volatility: float = 0.0, -): - leakage = agent_prob * info_value - discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0) - ux_penalty = eta_ux * revenue * ux_volatility - return revenue * discount - ux_penalty, leakage, discount, ux_penalty - - -if JAX_AVAILABLE: - reward_with_coi_penalty = jax.jit(reward_with_coi_penalty) diff --git a/engine/jax/requirements.txt b/engine/jax/requirements.txt deleted file mode 100644 index 7bde61c..0000000 --- a/engine/jax/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -flax==0.10.7 -optax==0.2.7 -distrax==0.1.5 -orbax-checkpoint==0.11.32 -chex==0.1.90 diff --git a/engine/jax/train.py b/engine/jax/train.py deleted file mode 100644 index 5ec637c..0000000 --- a/engine/jax/train.py +++ /dev/null @@ -1,1319 +0,0 @@ -"""Pure JAX trainers for PHANTOM environment.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, NamedTuple - -import signal -import threading - -import numpy as np - -_stop_requested = threading.Event() -_jax_dist_initialized = False - - -def _init_jax_distributed() -> None: - """Initialize JAX distributed if running on a multi-host TPU pod. - Safe to call multiple times; no-op after first successful init or when JAX unavailable.""" - global _jax_dist_initialized - if _jax_dist_initialized: - return - _jax_dist_initialized = True - try: - import jax as _jax - - _jax.distributed.initialize() - except Exception: - pass - - -try: - import wandb - - HAS_WANDB = True -except ImportError: - HAS_WANDB = False - -from ..wandb_checkpoint import ( # noqa: E402 - checkpoint_artifact_name, - download_latest_checkpoint, - log_checkpoint_bytes, -) - -try: - import jax - import jax.numpy as jnp - import distrax - import flax.linen as nn - import optax - from flax import serialization - from flax.linen.initializers import constant, orthogonal - from flax.training.train_state import TrainState - - HAS_JAX_STACK = True -except ImportError: - jax = None # type: ignore[assignment] - jnp = None # type: ignore[assignment] - distrax = None # type: ignore[assignment] - optax = None # type: ignore[assignment] - serialization = None # type: ignore[assignment] - - class _ModuleStub: - pass - - class _NNStub: - Module = _ModuleStub - - @staticmethod - def compact(fn): - return fn - - nn = _NNStub() # type: ignore[assignment] - - def constant(*_args, **_kwargs): # type: ignore[override] - return None - - def orthogonal(*_args, **_kwargs): # type: ignore[override] - return None - - class TrainState: # type: ignore[override] - pass - - HAS_JAX_STACK = False - -from .env import PHANTOMJAXEnv, make_env_params # noqa: E402 - - -class ActorCritic(nn.Module): - action_dim: int - activation: str = "tanh" - - @nn.compact - def __call__(self, x): - activation_fn = nn.relu if self.activation == "relu" else nn.tanh - - actor = nn.Dense( - 64, - kernel_init=orthogonal(np.sqrt(2.0)), - bias_init=constant(0.0), - )(x) - actor = activation_fn(actor) - actor = nn.Dense( - 64, - kernel_init=orthogonal(np.sqrt(2.0)), - bias_init=constant(0.0), - )(actor) - actor = activation_fn(actor) - logits = nn.Dense( - self.action_dim, - kernel_init=orthogonal(0.01), - bias_init=constant(0.0), - )(actor) - - critic = nn.Dense( - 64, - kernel_init=orthogonal(np.sqrt(2.0)), - bias_init=constant(0.0), - )(x) - critic = activation_fn(critic) - critic = nn.Dense( - 64, - kernel_init=orthogonal(np.sqrt(2.0)), - bias_init=constant(0.0), - )(critic) - critic = activation_fn(critic) - value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( - critic - ) - return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1) - - -class QNetwork(nn.Module): - action_dim: int - activation: str = "relu" - - @nn.compact - def __call__(self, x): - activation_fn = nn.relu if self.activation == "relu" else nn.tanh - x = nn.Dense( - 128, - kernel_init=orthogonal(np.sqrt(2.0)), - bias_init=constant(0.0), - )(x) - x = activation_fn(x) - x = nn.Dense( - 128, - kernel_init=orthogonal(np.sqrt(2.0)), - bias_init=constant(0.0), - )(x) - x = activation_fn(x) - q_values = nn.Dense( - self.action_dim, - kernel_init=orthogonal(1.0), - bias_init=constant(0.0), - )(x) - return q_values - - -class Transition(NamedTuple): - done: jax.Array - action: jax.Array - value: jax.Array - reward: jax.Array - log_prob: jax.Array - obs: jax.Array - info: dict[str, jax.Array] - - -class ReplayBatch(NamedTuple): - obs: jax.Array - actions: jax.Array - rewards: jax.Array - next_obs: jax.Array - dones: jax.Array - - -class ReplayBuffer(NamedTuple): - obs: jax.Array - actions: jax.Array - rewards: jax.Array - next_obs: jax.Array - dones: jax.Array - ptr: jax.Array - size: jax.Array - - -def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]: - out = { - "algo": str(cfg.get("algo", "ppo")).lower(), - "seed": int(cfg.get("seed", 42)), - "learning_rate": float(cfg.get("learning_rate", 3e-4)), - "gamma": float(cfg.get("gamma", 0.99)), - "gae_lambda": float(cfg.get("gae_lambda", 0.95)), - "clip_range": float(cfg.get("clip_range", 0.2)), - "ent_coef": float(cfg.get("ent_coef", 0.01)), - "vf_coef": float(cfg.get("vf_coef", 0.5)), - "max_grad_norm": float(cfg.get("max_grad_norm", 0.5)), - "activation": str(cfg.get("activation", "relu")), - "total_timesteps": int(cfg.get("total_timesteps", 50_000)), - "eval_episodes": int(cfg.get("eval_episodes", 5)), - "model_dir": str(cfg.get("model_dir", "engine/models")), - "log_freq": int(cfg.get("log_freq", 100)), - "n_products": int(cfg.get("n_products", 10)), - "N": int(cfg.get("N", 100)), - "alpha": float(cfg.get("alpha", 0.3)), - "lambda_coi": float(cfg.get("lambda_coi", 0.2)), - "robust_radius": float(cfg.get("robust_radius", 0.15)), - "robust_points": int(cfg.get("robust_points", 5)), - "info_value": float(cfg.get("info_value", 1.0)), - "price_low": float(cfg.get("price_low", 10.0)), - "price_high": float(cfg.get("price_high", 150.0)), - "action_levels": int(cfg.get("action_levels", 9)), - "action_scale_low": float(cfg.get("action_scale_low", 0.8)), - "action_scale_high": float(cfg.get("action_scale_high", 1.2)), - "max_episode_steps": int(cfg.get("max_steps", 100)), - "max_session_steps": int(cfg.get("max_session_steps", 40)), - "margin_floor": float(cfg.get("margin_floor", 0.05)), - "margin_floor_patience": int(cfg.get("margin_floor_patience", 5)), - "prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)), - "num_envs": int(cfg.get("jax_num_envs", 16)), - "num_steps": int(cfg.get("jax_num_steps", 128)), - "num_minibatches": int(cfg.get("jax_num_minibatches", 4)), - "update_epochs": int(cfg.get("jax_update_epochs", 4)), - "anneal_lr": bool(cfg.get("jax_anneal_lr", True)), - "checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)), - "buffer_size": int(cfg.get("buffer_size", 50_000)), - "batch_size": int(cfg.get("batch_size", 256)), - "train_freq": int(cfg.get("train_freq", 1)), - "learning_starts": int(cfg.get("learning_starts", 1_000)), - "target_update_interval": int(cfg.get("target_update_interval", 1_000)), - "exploration_fraction": float(cfg.get("exploration_fraction", 0.2)), - "exploration_final_eps": float(cfg.get("exploration_final_eps", 0.05)), - "eps_start": float(cfg.get("eps_start", 1.0)), - "eps_end": float(cfg.get("eps_end", 0.05)), - "eps_decay": float(cfg.get("eps_decay", 0.9995)), - "q_lr": float(cfg.get("q_lr", 0.1)), - "q_bins": int(cfg.get("q_bins", 6)), - } - rollout = out["num_envs"] * out["num_steps"] - out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1)) - out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1)) - return out - - -def _scalar(value: Any) -> float: - return float(np.asarray(value)) - - -def _scalar_int(value: Any) -> int: - return int(np.asarray(value)) - - -def _make_env(cfg: dict[str, Any]) -> PHANTOMJAXEnv: - env_params = make_env_params( - n_products=cfg["n_products"], - alpha=cfg["alpha"], - n_sessions=cfg["N"], - lambda_coi=cfg["lambda_coi"], - robust_radius=cfg["robust_radius"], - robust_points=cfg["robust_points"], - info_value=cfg["info_value"], - action_levels=cfg["action_levels"], - action_scale_low=cfg["action_scale_low"], - action_scale_high=cfg["action_scale_high"], - price_low=cfg["price_low"], - price_high=cfg["price_high"], - max_episode_steps=cfg["max_episode_steps"], - max_session_steps=cfg["max_session_steps"], - margin_floor=cfg["margin_floor"], - margin_floor_patience=cfg["margin_floor_patience"], - prefer_behavior_data=cfg["prefer_behavior_data"], - ) - return PHANTOMJAXEnv(env_params) - - -def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array: - mask = done - while mask.ndim < keep.ndim: - mask = mask[..., None] - return jnp.where(mask, reset, keep) - - -def _epsilon_by_fraction(step: int, cfg: dict[str, Any]) -> float: - start = float(cfg["eps_start"]) - end = float(cfg["exploration_final_eps"]) - frac = float(cfg["exploration_fraction"]) - total = max(1, int(cfg["total_timesteps"])) - decay_steps = max(1, int(total * frac)) - if step >= decay_steps: - return end - slope = (end - start) / decay_steps - return float(start + slope * step) - - -def _digitize_scalar(value: jax.Array, bins: jax.Array) -> jax.Array: - return jnp.sum(value > bins).astype(jnp.int32) - - -def _encode_qtable_state( - obs: jax.Array, - *, - n_products: int, - demand_bins: jax.Array, - price_bins: jax.Array, -) -> tuple[jax.Array, jax.Array, jax.Array]: - demand = obs[:n_products] - prices = obs[n_products : 2 * n_products] - d_mean = jnp.mean(demand) - d_std = jnp.std(demand) - p_mean = jnp.mean(prices) - return ( - _digitize_scalar(d_mean, demand_bins), - _digitize_scalar(d_std, demand_bins), - _digitize_scalar(p_mean, price_bins), - ) - - -def _init_replay_buffer(capacity: int, obs_dim: int) -> ReplayBuffer: - cap = max(1, int(capacity)) - return ReplayBuffer( - obs=jnp.zeros((cap, obs_dim), dtype=jnp.float32), - actions=jnp.zeros((cap,), dtype=jnp.int32), - rewards=jnp.zeros((cap,), dtype=jnp.float32), - next_obs=jnp.zeros((cap, obs_dim), dtype=jnp.float32), - dones=jnp.zeros((cap,), dtype=jnp.float32), - ptr=jnp.asarray(0, dtype=jnp.int32), - size=jnp.asarray(0, dtype=jnp.int32), - ) - - -def _replay_size(buffer: ReplayBuffer) -> int: - return _scalar_int(buffer.size) - - -def _replay_add( - buffer: ReplayBuffer, - obs: jax.Array, - action: jax.Array, - reward: jax.Array, - next_obs: jax.Array, - done: jax.Array, -) -> ReplayBuffer: - capacity = int(buffer.obs.shape[0]) - idx = buffer.ptr % capacity - return ReplayBuffer( - obs=buffer.obs.at[idx].set(obs.astype(jnp.float32)), - actions=buffer.actions.at[idx].set(action.astype(jnp.int32)), - rewards=buffer.rewards.at[idx].set(reward.astype(jnp.float32)), - next_obs=buffer.next_obs.at[idx].set(next_obs.astype(jnp.float32)), - dones=buffer.dones.at[idx].set(done.astype(jnp.float32)), - ptr=buffer.ptr + 1, - size=jnp.minimum(buffer.size + 1, jnp.asarray(capacity, dtype=jnp.int32)), - ) - - -def _replay_sample( - buffer: ReplayBuffer, key: jax.Array, batch_size: int -) -> ReplayBatch: - size = jnp.maximum(buffer.size, 1) - idx = jax.random.randint(key, shape=(batch_size,), minval=0, maxval=size) - return ReplayBatch( - obs=buffer.obs[idx], - actions=buffer.actions[idx], - rewards=buffer.rewards[idx], - next_obs=buffer.next_obs[idx], - dones=buffer.dones[idx], - ) - - -def _make_actor_critic_train( - config: dict[str, Any], *, algo: str, use_pmap: bool = False -): - cfg = dict(config) - cfg["algo"] = algo - env = _make_env(cfg) - network = ActorCritic(env.action_space_n(), activation=cfg["activation"]) - - def linear_schedule(count: jax.Array) -> jax.Array: - updates_done = count // (cfg["num_minibatches"] * cfg["update_epochs"]) - frac = 1.0 - updates_done / max(cfg["num_updates"], 1) - return cfg["learning_rate"] * frac - - if cfg["anneal_lr"]: - tx = optax.chain( - optax.clip_by_global_norm(cfg["max_grad_norm"]), - optax.adam(learning_rate=linear_schedule, eps=1e-5), - ) - else: - tx = optax.chain( - optax.clip_by_global_norm(cfg["max_grad_norm"]), - optax.adam(cfg["learning_rate"], eps=1e-5), - ) - - def init_runner_state(rng: jax.Array): - rng, init_key = jax.random.split(rng) - init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32) - params = network.init(init_key, init_obs) - train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx) - - rng, reset_key = jax.random.split(rng) - reset_keys = jax.random.split(reset_key, cfg["num_envs"]) - obs, env_state = jax.vmap(env.reset)(reset_keys) - return train_state, env_state, obs, rng - - def _update_step(runner_state, _): - def _env_step(runner_state, _): - train_state, env_state, last_obs, rng = runner_state - rng, action_key = jax.random.split(rng) - policy, value = network.apply(train_state.params, last_obs) - action = policy.sample(seed=action_key) - log_prob = policy.log_prob(action) - - rng, step_key = jax.random.split(rng) - step_keys = jax.random.split(step_key, cfg["num_envs"]) - nxt_obs, nxt_state, reward, done, info = jax.vmap( - env.step, - in_axes=(0, 0, 0), - )(step_keys, env_state, action) - - rng, reset_key = jax.random.split(rng) - reset_keys = jax.random.split(reset_key, cfg["num_envs"]) - rst_obs, rst_state = jax.vmap(env.reset)(reset_keys) - obs_next = jnp.where(done[:, None], rst_obs, nxt_obs) - env_next = jax.tree_util.tree_map( - lambda keep, reset: _select_env_state(done, keep, reset), - nxt_state, - rst_state, - ) - transition = Transition( - done=done, - action=action, - value=value, - reward=reward, - log_prob=log_prob, - obs=last_obs, - info=info, - ) - return (train_state, env_next, obs_next, rng), transition - - runner_state, traj_batch = jax.lax.scan( - _env_step, - runner_state, - None, - length=cfg["num_steps"], - ) - - train_state, env_state, last_obs, rng = runner_state - _, last_value = network.apply(train_state.params, last_obs) - - def _compute_gae(traj_batch, last_value): - def _gae_step(carry, transition): - gae, next_value = carry - delta = ( - transition.reward - + cfg["gamma"] * next_value * (1.0 - transition.done) - - transition.value - ) - gae = ( - delta - + cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae - ) - return (gae, transition.value), gae - - _, advantages = jax.lax.scan( - _gae_step, - (jnp.zeros_like(last_value), last_value), - traj_batch, - reverse=True, - unroll=16, - ) - targets = advantages + traj_batch.value - return advantages, targets - - advantages, targets = _compute_gae(traj_batch, last_value) - - def _update_epoch(update_state, _): - def _update_minibatch(train_state, batch_info): - traj_b, adv_b, tgt_b = batch_info - - def _loss_fn(params, traj_b, adv_b, tgt_b): - policy, value = network.apply(params, traj_b.obs) - log_prob = policy.log_prob(traj_b.action) - adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8) - - if algo == "ppo": - value_clipped = traj_b.value + (value - traj_b.value).clip( - -cfg["clip_range"], cfg["clip_range"] - ) - value_loss = ( - 0.5 - * jnp.maximum( - jnp.square(value - tgt_b), - jnp.square(value_clipped - tgt_b), - ).mean() - ) - ratio = jnp.exp(log_prob - traj_b.log_prob) - policy_loss = -jnp.minimum( - ratio * adv_norm, - jnp.clip( - ratio, - 1.0 - cfg["clip_range"], - 1.0 + cfg["clip_range"], - ) - * adv_norm, - ).mean() - else: - value_loss = 0.5 * jnp.mean(jnp.square(value - tgt_b)) - policy_loss = -(log_prob * adv_norm).mean() - - entropy = policy.entropy().mean() - total_loss = ( - policy_loss - + cfg["vf_coef"] * value_loss - - cfg["ent_coef"] * entropy - ) - return total_loss, (value_loss, policy_loss, entropy) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b) - if use_pmap: - grads = jax.lax.pmean(grads, axis_name="devices") - train_state = train_state.apply_gradients(grads=grads) - return train_state, jnp.asarray(0.0, dtype=jnp.float32) - - train_state, traj_batch, advantages, targets, rng = update_state - rng, perm_key = jax.random.split(rng) - batch_size = cfg["num_envs"] * cfg["num_steps"] - permutation = jax.random.permutation(perm_key, batch_size) - - batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), - batch, - ) - shuffled = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=0), - batch, - ) - minibatches = jax.tree_util.tree_map( - lambda x: x.reshape( - (cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:] - ), - shuffled, - ) - train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches) - return (train_state, traj_batch, advantages, targets, rng), None - - update_state = (train_state, traj_batch, advantages, targets, rng) - update_state, _ = jax.lax.scan( - _update_epoch, - update_state, - None, - length=cfg["update_epochs"], - ) - train_state = update_state[0] - rng = update_state[-1] - - metric = { - "reward": jnp.mean(traj_batch.reward), - "revenue": jnp.mean(traj_batch.info["revenue"]), - "agent_prob": jnp.mean(traj_batch.info["agent_prob"]), - "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]), - "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]), - } - next_runner_state = (train_state, env_state, last_obs, rng) - return next_runner_state, metric - - def run_updates(runner_state, num_updates: int): - updates = max(1, int(num_updates)) - runner_state, metric = jax.lax.scan( - _update_step, - runner_state, - None, - length=updates, - ) - return { - "runner_state": runner_state, - "metrics": metric, - } - - return init_runner_state, run_updates, network, env, cfg - - -def make_train(config: dict[str, Any]): - cfg = _jax_cfg(config) - algo = cfg["algo"] - if algo not in {"ppo", "a2c"}: - raise ValueError(f"make_train supports actor-critic algos only, got '{algo}'") - return _make_actor_critic_train(cfg, algo=algo) - - -def evaluate_policy( - *, - network: ActorCritic, - params: Any, - env: PHANTOMJAXEnv, - episodes: int, - seed: int, -) -> dict[str, float]: - rewards: list[float] = [] - revenues: list[float] = [] - key = jax.random.PRNGKey(seed) - - for _ in range(int(episodes)): - key, reset_key = jax.random.split(key) - obs, state = env.reset(reset_key) - ep_reward = 0.0 - ep_revenue = 0.0 - done = False - steps = 0 - - while not done and steps < int(env.params.max_episode_steps): - policy, _ = network.apply(params, obs) - action = jnp.argmax(policy.logits) - key, step_key = jax.random.split(key) - obs, state, reward, done_flag, info = env.step(step_key, state, action) - ep_reward += _scalar(reward) - ep_revenue += _scalar(info["revenue"]) - done = bool(np.asarray(done_flag)) - steps += 1 - - rewards.append(ep_reward) - revenues.append(ep_revenue) - - return { - "eval/reward_mean": float(np.mean(rewards)), - "eval/revenue_mean": float(np.mean(revenues)), - "eval/reward_std": float(np.std(rewards)), - "eval/revenue_std": float(np.std(revenues)), - } - - -def _evaluate_q_network( - *, - network: QNetwork, - params: Any, - env: PHANTOMJAXEnv, - episodes: int, - seed: int, -) -> dict[str, float]: - rewards: list[float] = [] - revenues: list[float] = [] - key = jax.random.PRNGKey(seed) - - for _ in range(int(episodes)): - key, reset_key = jax.random.split(key) - obs, state = env.reset(reset_key) - ep_reward = 0.0 - ep_revenue = 0.0 - done = False - steps = 0 - - while not done and steps < int(env.params.max_episode_steps): - q_values = network.apply(params, obs) - action = jnp.argmax(q_values) - key, step_key = jax.random.split(key) - obs, state, reward, done_flag, info = env.step(step_key, state, action) - ep_reward += _scalar(reward) - ep_revenue += _scalar(info["revenue"]) - done = bool(np.asarray(done_flag)) - steps += 1 - - rewards.append(ep_reward) - revenues.append(ep_revenue) - - return { - "eval/reward_mean": float(np.mean(rewards)), - "eval/revenue_mean": float(np.mean(revenues)), - "eval/reward_std": float(np.std(rewards)), - "eval/revenue_std": float(np.std(revenues)), - } - - -def _evaluate_q_table( - *, - q_table: jax.Array, - env: PHANTOMJAXEnv, - episodes: int, - seed: int, - n_products: int, - demand_bins: jax.Array, - price_bins: jax.Array, -) -> dict[str, float]: - rewards: list[float] = [] - revenues: list[float] = [] - key = jax.random.PRNGKey(seed) - - for _ in range(int(episodes)): - key, reset_key = jax.random.split(key) - obs, state = env.reset(reset_key) - ep_reward = 0.0 - ep_revenue = 0.0 - done = False - steps = 0 - - while not done and steps < int(env.params.max_episode_steps): - s0, s1, s2 = _encode_qtable_state( - obs, - n_products=n_products, - demand_bins=demand_bins, - price_bins=price_bins, - ) - action = jnp.argmax(q_table[s0, s1, s2]) - key, step_key = jax.random.split(key) - obs, state, reward, done_flag, info = env.step(step_key, state, action) - ep_reward += _scalar(reward) - ep_revenue += _scalar(info["revenue"]) - done = bool(np.asarray(done_flag)) - steps += 1 - - rewards.append(ep_reward) - revenues.append(ep_revenue) - - return { - "eval/reward_mean": float(np.mean(rewards)), - "eval/revenue_mean": float(np.mean(revenues)), - "eval/reward_std": float(np.std(rewards)), - "eval/revenue_std": float(np.std(revenues)), - } - - -def _train_actor_critic( - cfg: dict[str, Any], - *, - algo: str, -) -> tuple[dict[str, Any], dict[str, float]]: - num_devices = jax.local_device_count() - use_pmap = num_devices > 1 - global_devices = max(1, int(jax.device_count())) - process_idx = int(jax.process_index()) - - init_runner_state, run_updates_raw, network, env, run_cfg = ( - _make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap) - ) - - if use_pmap: - run_fn = jax.pmap( - run_updates_raw, - axis_name="devices", - static_broadcasted_argnums=(1,), - devices=jax.local_devices(), - ) - else: - run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",)) - - rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"]) - rollout_steps_global = rollout_steps * (global_devices if use_pmap else 1) - total_updates = int(run_cfg["num_updates"]) - checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000))) - segment_updates = max(1, checkpoint_interval // max(rollout_steps_global, 1)) - - base_rng = jax.random.PRNGKey(run_cfg["seed"]) - base_rng = jax.random.fold_in(base_rng, process_idx) - if use_pmap: - init_keys = jax.random.split(base_rng, num_devices) - runner_state = jax.vmap(init_runner_state)(init_keys) - single_runner_state = jax.tree_util.tree_map(lambda x: x[0], runner_state) - else: - single_runner_state = init_runner_state(base_rng) - runner_state = single_runner_state - updates_done = 0 - restored_train_state = None - - is_primary = process_idx == 0 - artifact_name = None - if HAS_WANDB and wandb.run is not None: - sweep_id = getattr(wandb.run, "sweep_id", None) - artifact_name = checkpoint_artifact_name( - run_cfg, - backend="jax", - sweep_id=sweep_id, - ) - restored = download_latest_checkpoint( - artifact_name, - file_name=f"jax_{algo}_runner_state.msgpack", - ) - if restored is not None: - checkpoint_path, metadata = restored - template = {"runner_state": single_runner_state, "updates_done": 0} - payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) - single_runner_state = payload["runner_state"] - restored_train_state = payload["runner_state"][0] - updates_done = int(payload.get("updates_done", 0)) - if updates_done <= 0: - updates_done = int(metadata.get("updates_done", 0)) - updates_done = max(0, min(updates_done, total_updates)) - - if use_pmap and restored_train_state is not None: - runner_state = ( - jax.device_put_replicated(restored_train_state, jax.local_devices()), - runner_state[1], - runner_state[2], - runner_state[3], - ) - elif not use_pmap: - runner_state = single_runner_state - - metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"] - metric_sums = {k: 0.0 for k in metric_keys} - metric_count = 0 - - while updates_done < total_updates: - updates_this_segment = min(segment_updates, total_updates - updates_done) - if use_pmap: - out = run_fn(runner_state, updates_this_segment) - else: - out = run_fn(runner_state, updates_this_segment) - runner_state = out["runner_state"] - metric = out["metrics"] - - if use_pmap: - segment_values = { - key: np.asarray(metric[key], dtype=np.float64).reshape(-1) - for key in metric_keys - } - else: - segment_values = { - key: np.asarray(metric[key], dtype=np.float64).reshape(-1) - for key in metric_keys - } - - segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0 - metric_count += segment_count - for key in metric_keys: - metric_sums[key] += float(segment_values[key].sum()) - - updates_done += int(updates_this_segment) - global_step = int(updates_done * rollout_steps_global) - - if is_primary and HAS_WANDB and wandb.run is not None: - wandb.log( - { - "train/reward_mean": float(segment_values["reward"].mean()), - "train/revenue_mean": float(segment_values["revenue"].mean()), - "train/agent_prob": float(segment_values["agent_prob"].mean()), - "train/alpha_adv": float(segment_values["alpha_adv"].mean()), - "train/coi_leakage": float(segment_values["coi_leakage"].mean()), - "train/global_step": global_step, - }, - step=global_step, - ) - if artifact_name is not None: - # extract device-0 state for checkpoint portability - state_to_save = ( - jax.tree_util.tree_map(lambda x: x[0], runner_state) - if use_pmap - else runner_state - ) - checkpoint_payload = serialization.to_bytes( - {"runner_state": state_to_save, "updates_done": updates_done} - ) - log_checkpoint_bytes( - artifact_name, - file_name=f"jax_{algo}_runner_state.msgpack", - payload=checkpoint_payload, - metadata={ - "step": global_step, - "updates_done": updates_done, - "rollout_steps": rollout_steps_global, - "algo": algo, - }, - ) - if _stop_requested.is_set(): - break - - # extract device-0 params for eval and save - final_runner = ( - jax.tree_util.tree_map(lambda x: x[0], runner_state) - if use_pmap - else runner_state - ) - train_state = final_runner[0] - denom = float(metric_count) if metric_count > 0 else 1.0 - metrics = { - "train/reward_mean": float(metric_sums["reward"] / denom), - "train/revenue_mean": float(metric_sums["revenue"] / denom), - "train/agent_prob": float(metric_sums["agent_prob"] / denom), - "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), - "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), - "train/global_step": int(updates_done * rollout_steps_global), - } - - eval_metrics = evaluate_policy( - network=network, - params=train_state.params, - env=env, - episodes=run_cfg["eval_episodes"], - seed=run_cfg["seed"] + 7, - ) - metrics.update(eval_metrics) - - if is_primary: - model_dir = Path(run_cfg["model_dir"]) - model_dir.mkdir(parents=True, exist_ok=True) - model_path = model_dir / f"phantom_{algo}_jax.msgpack" - model_path.write_bytes(serialization.to_bytes(train_state.params)) - metrics["model/path"] = str(model_path) - - return {"params": train_state.params}, metrics - - -def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: - run_cfg = dict(cfg) - env = _make_env(run_cfg) - action_dim = env.action_space_n() - obs_dim = env.observation_dim() - q_net = QNetwork(action_dim=action_dim, activation=run_cfg["activation"]) - - init_obs = jnp.zeros((obs_dim,), dtype=jnp.float32) - rng = jax.random.PRNGKey(run_cfg["seed"]) - rng, init_key = jax.random.split(rng) - params = q_net.init(init_key, init_obs) - tx = optax.adam(run_cfg["learning_rate"]) - train_state = TrainState.create(apply_fn=q_net.apply, params=params, tx=tx) - target_params = train_state.params - - buffer = _init_replay_buffer(run_cfg["buffer_size"], obs_dim) - - rng, reset_key = jax.random.split(rng) - obs, env_state = env.reset(reset_key) - - start_step = 0 - epsilon_value = float(run_cfg["eps_start"]) - artifact_name = None - - if HAS_WANDB and wandb.run is not None: - sweep_id = getattr(wandb.run, "sweep_id", None) - artifact_name = checkpoint_artifact_name( - run_cfg, - backend="jax", - sweep_id=sweep_id, - ) - restored = download_latest_checkpoint( - artifact_name, - file_name="jax_dqn_state.msgpack", - ) - if restored is not None: - checkpoint_path, metadata = restored - template = { - "params": train_state.params, - "target_params": target_params, - "opt_state": train_state.opt_state, - "global_step": 0, - "epsilon": epsilon_value, - } - payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) - train_state = train_state.replace( - params=payload["params"], - opt_state=payload["opt_state"], - ) - target_params = payload["target_params"] - start_step = int(payload.get("global_step", metadata.get("step", 0))) - start_step = max(0, min(start_step, int(run_cfg["total_timesteps"]))) - epsilon_value = float(payload.get("epsilon", epsilon_value)) - - @jax.jit - def dqn_update( - state: TrainState, - target: Any, - batch: ReplayBatch, - ) -> tuple[TrainState, jax.Array]: - def loss_fn(model_params): - q_values = q_net.apply(model_params, batch.obs) - chosen = jnp.take_along_axis( - q_values, - batch.actions[:, None], - axis=1, - ).squeeze(-1) - next_q = q_net.apply(target, batch.next_obs) - next_max = jnp.max(next_q, axis=1) - td_target = ( - batch.rewards + run_cfg["gamma"] * (1.0 - batch.dones) * next_max - ) - td_error = chosen - jax.lax.stop_gradient(td_target) - return jnp.mean(jnp.square(td_error)) - - loss, grads = jax.value_and_grad(loss_fn)(state.params) - next_state = state.apply_gradients(grads=grads) - return next_state, loss - - metric_sums = { - "reward": 0.0, - "revenue": 0.0, - "agent_prob": 0.0, - "alpha_adv": 0.0, - "coi_leakage": 0.0, - "loss": 0.0, - } - metric_count = 0 - loss_count = 0 - - total_steps = int(run_cfg["total_timesteps"]) - checkpoint_interval = max(1, int(run_cfg["checkpoint_interval"])) - batch_size = max(1, int(run_cfg["batch_size"])) - - for global_step in range(start_step + 1, total_steps + 1): - epsilon_value = _epsilon_by_fraction(global_step - 1, run_cfg) - - rng, eps_key, action_key, step_key, reset_key, sample_key = jax.random.split( - rng, 6 - ) - do_explore = bool(np.asarray(jax.random.uniform(eps_key) < epsilon_value)) - if do_explore: - action = jax.random.randint( - action_key, shape=(), minval=0, maxval=action_dim - ) - else: - q_values = q_net.apply(train_state.params, obs) - action = jnp.argmax(q_values) - - next_obs, next_state, reward, done, info = env.step(step_key, env_state, action) - buffer = _replay_add( - buffer, - obs, - action, - reward, - next_obs, - done.astype(jnp.float32), - ) - - metric_count += 1 - metric_sums["reward"] += _scalar(reward) - metric_sums["revenue"] += _scalar(info["revenue"]) - metric_sums["agent_prob"] += _scalar(info["agent_prob"]) - metric_sums["alpha_adv"] += _scalar(info["alpha_adv"]) - metric_sums["coi_leakage"] += _scalar(info["coi_leakage"]) - - if bool(np.asarray(done)): - obs, env_state = env.reset(reset_key) - else: - obs, env_state = next_obs, next_state - - ready = ( - global_step >= int(run_cfg["learning_starts"]) - and global_step % int(run_cfg["train_freq"]) == 0 - and _replay_size(buffer) >= batch_size - ) - if ready: - batch = _replay_sample(buffer, sample_key, batch_size) - train_state, loss = dqn_update(train_state, target_params, batch) - metric_sums["loss"] += _scalar(loss) - loss_count += 1 - - if global_step % int(run_cfg["target_update_interval"]) == 0: - target_params = train_state.params - - if ( - HAS_WANDB - and wandb.run is not None - and global_step % int(run_cfg["log_freq"]) == 0 - ): - wandb.log( - { - "train/reward_mean": metric_sums["reward"] / max(metric_count, 1), - "train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1), - "train/agent_prob": metric_sums["agent_prob"] - / max(metric_count, 1), - "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), - "train/coi_leakage": metric_sums["coi_leakage"] - / max(metric_count, 1), - "train/loss": metric_sums["loss"] / max(loss_count, 1), - "train/epsilon": epsilon_value, - "train/global_step": global_step, - }, - step=global_step, - ) - - if artifact_name is not None and global_step % checkpoint_interval == 0: - payload = serialization.to_bytes( - { - "params": train_state.params, - "target_params": target_params, - "opt_state": train_state.opt_state, - "global_step": global_step, - "epsilon": epsilon_value, - } - ) - log_checkpoint_bytes( - artifact_name, - file_name="jax_dqn_state.msgpack", - payload=payload, - metadata={ - "step": global_step, - "algo": "dqn", - }, - ) - if _stop_requested.is_set(): - break - - denom = float(metric_count) if metric_count > 0 else 1.0 - metrics = { - "train/reward_mean": float(metric_sums["reward"] / denom), - "train/revenue_mean": float(metric_sums["revenue"] / denom), - "train/agent_prob": float(metric_sums["agent_prob"] / denom), - "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), - "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), - "train/loss": float(metric_sums["loss"] / max(loss_count, 1)), - "train/global_step": total_steps, - } - - eval_metrics = _evaluate_q_network( - network=q_net, - params=train_state.params, - env=env, - episodes=run_cfg["eval_episodes"], - seed=run_cfg["seed"] + 7, - ) - metrics.update(eval_metrics) - - model_dir = Path(run_cfg["model_dir"]) - model_dir.mkdir(parents=True, exist_ok=True) - model_path = model_dir / "phantom_dqn_jax.msgpack" - model_path.write_bytes(serialization.to_bytes(train_state.params)) - metrics["model/path"] = str(model_path) - return { - "params": train_state.params, - "target_params": target_params, - }, metrics - - -def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: - run_cfg = dict(cfg) - env = _make_env(run_cfg) - action_dim = env.action_space_n() - n_bins = max(2, int(run_cfg["q_bins"])) - n_products = int(run_cfg["n_products"]) - - q_table = jnp.zeros((n_bins, n_bins, n_bins, action_dim), dtype=jnp.float32) - demand_bins = jnp.linspace(0.0, 100.0, n_bins + 1, dtype=jnp.float32)[1:-1] - price_bins = jnp.linspace( - float(run_cfg["price_low"]), - float(run_cfg["price_high"]), - n_bins + 1, - dtype=jnp.float32, - )[1:-1] - - rng = jax.random.PRNGKey(run_cfg["seed"]) - rng, reset_key = jax.random.split(rng) - obs, env_state = env.reset(reset_key) - - epsilon_value = float(run_cfg["eps_start"]) - start_step = 0 - artifact_name = None - - if HAS_WANDB and wandb.run is not None: - sweep_id = getattr(wandb.run, "sweep_id", None) - artifact_name = checkpoint_artifact_name( - run_cfg, - backend="jax", - sweep_id=sweep_id, - ) - restored = download_latest_checkpoint( - artifact_name, - file_name="jax_qtable_state.msgpack", - ) - if restored is not None: - checkpoint_path, metadata = restored - template = { - "q_table": q_table, - "global_step": 0, - "epsilon": epsilon_value, - } - payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) - q_table = payload["q_table"] - start_step = int(payload.get("global_step", metadata.get("step", 0))) - start_step = max(0, min(start_step, int(run_cfg["total_timesteps"]))) - epsilon_value = float(payload.get("epsilon", epsilon_value)) - - metric_sums = { - "reward": 0.0, - "revenue": 0.0, - "agent_prob": 0.0, - "alpha_adv": 0.0, - "coi_leakage": 0.0, - } - metric_count = 0 - - total_steps = int(run_cfg["total_timesteps"]) - checkpoint_interval = max(1, int(run_cfg["checkpoint_interval"])) - - for global_step in range(start_step + 1, total_steps + 1): - s0, s1, s2 = _encode_qtable_state( - obs, - n_products=n_products, - demand_bins=demand_bins, - price_bins=price_bins, - ) - state_q = q_table[s0, s1, s2] - - rng, eps_key, action_key, step_key, reset_key = jax.random.split(rng, 5) - do_explore = bool(np.asarray(jax.random.uniform(eps_key) < epsilon_value)) - if do_explore: - action = jax.random.randint( - action_key, shape=(), minval=0, maxval=action_dim - ) - else: - action = jnp.argmax(state_q) - - next_obs, next_state, reward, done, info = env.step(step_key, env_state, action) - ns0, ns1, ns2 = _encode_qtable_state( - next_obs, - n_products=n_products, - demand_bins=demand_bins, - price_bins=price_bins, - ) - - best_next = jnp.max(q_table[ns0, ns1, ns2]) - done_f = done.astype(jnp.float32) - td_target = reward + run_cfg["gamma"] * (1.0 - done_f) * best_next - old_value = q_table[s0, s1, s2, action] - new_value = old_value + run_cfg["q_lr"] * (td_target - old_value) - q_table = q_table.at[s0, s1, s2, action].set(new_value) - - epsilon_value = max( - float(run_cfg["eps_end"]), - epsilon_value * float(run_cfg["eps_decay"]), - ) - - metric_count += 1 - metric_sums["reward"] += _scalar(reward) - metric_sums["revenue"] += _scalar(info["revenue"]) - metric_sums["agent_prob"] += _scalar(info["agent_prob"]) - metric_sums["alpha_adv"] += _scalar(info["alpha_adv"]) - metric_sums["coi_leakage"] += _scalar(info["coi_leakage"]) - - if bool(np.asarray(done)): - obs, env_state = env.reset(reset_key) - else: - obs, env_state = next_obs, next_state - - if ( - HAS_WANDB - and wandb.run is not None - and global_step % int(run_cfg["log_freq"]) == 0 - ): - wandb.log( - { - "train/reward_mean": metric_sums["reward"] / max(metric_count, 1), - "train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1), - "train/agent_prob": metric_sums["agent_prob"] - / max(metric_count, 1), - "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1), - "train/coi_leakage": metric_sums["coi_leakage"] - / max(metric_count, 1), - "train/epsilon": epsilon_value, - "train/global_step": global_step, - }, - step=global_step, - ) - - if artifact_name is not None and global_step % checkpoint_interval == 0: - payload = serialization.to_bytes( - { - "q_table": q_table, - "global_step": global_step, - "epsilon": epsilon_value, - } - ) - log_checkpoint_bytes( - artifact_name, - file_name="jax_qtable_state.msgpack", - payload=payload, - metadata={ - "step": global_step, - "algo": "qtable", - }, - ) - - denom = float(metric_count) if metric_count > 0 else 1.0 - metrics = { - "train/reward_mean": float(metric_sums["reward"] / denom), - "train/revenue_mean": float(metric_sums["revenue"] / denom), - "train/agent_prob": float(metric_sums["agent_prob"] / denom), - "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), - "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), - "train/global_step": total_steps, - } - - eval_metrics = _evaluate_q_table( - q_table=q_table, - env=env, - episodes=run_cfg["eval_episodes"], - seed=run_cfg["seed"] + 7, - n_products=n_products, - demand_bins=demand_bins, - price_bins=price_bins, - ) - metrics.update(eval_metrics) - - model_dir = Path(run_cfg["model_dir"]) - model_dir.mkdir(parents=True, exist_ok=True) - model_path = model_dir / "phantom_qtable_jax.msgpack" - model_path.write_bytes(serialization.to_bytes(q_table)) - metrics["model/path"] = str(model_path) - return {"q_table": q_table}, metrics - - -def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: - if not HAS_JAX_STACK: - raise ImportError( - "JAX path requires jax, flax, optax, and distrax. " - "Install engine/jax/requirements.txt on this machine first." - ) - - _init_jax_distributed() - _stop_requested.clear() - run_cfg = _jax_cfg(cfg) - algo = run_cfg["algo"] - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGTERM, lambda *_: _stop_requested.set()) - - if algo in {"ppo", "a2c"}: - return _train_actor_critic(run_cfg, algo=algo) - if algo == "dqn": - return _train_dqn(run_cfg) - if algo == "qtable": - return _train_qtable(run_cfg) - raise ValueError(f"Unsupported JAX algo '{algo}'") diff --git a/engine/lib/__init__.py b/engine/lib/__init__.py index 823c572..31330cc 100644 --- a/engine/lib/__init__.py +++ b/engine/lib/__init__.py @@ -14,7 +14,6 @@ _EXPORTS: dict[str, tuple[str, str]] = { "EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"), "MetricsCallback": (".callbacks", "MetricsCallback"), "EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"), - "CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"), "ProviderBenchmark": (".providers", "ProviderBenchmark"), "ProviderResult": (".providers", "ProviderResult"), "BenchmarkConfig": (".providers", "BenchmarkConfig"), diff --git a/engine/lib/callbacks.py b/engine/lib/callbacks.py index a21fdfe..8377d80 100644 --- a/engine/lib/callbacks.py +++ b/engine/lib/callbacks.py @@ -1,150 +1,96 @@ -"""Training callbacks for W&B/TensorBoard logging - reads from info dict.""" +"""Training callbacks with algorithm-agnostic metric extraction.""" -from pathlib import Path +from typing import Any from stable_baselines3.common.callbacks import BaseCallback, EvalCallback import numpy as np -from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file - -try: - import wandb - - HAS_WANDB = True -except ImportError: - HAS_WANDB = False - class MetricsCallback(BaseCallback): - """Training metrics logger - reads info['economics'], logs to W&B.""" + """Collects interval train metrics from env info dictionaries.""" def __init__( - self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0 + self, + log_histograms: bool = False, + log_freq: int = 100, + verbose: int = 0, ): super().__init__(verbose) self.log_histograms = log_histograms - self.log_freq = log_freq - self._episode_revenues: list[float] = [] - - def _on_step(self) -> bool: - if not HAS_WANDB or wandb.run is None: - return True - - for info in self.locals.get("infos", []): - if "economics" not in info: - continue - - econ = info["economics"] - t = self.num_timesteps - - payload = { - "train/revenue_step": econ["revenue"], - "train/margin_step": econ["margin"], - "train/coi_level": econ["coi_level"], - "train/regret_step": econ["regret"], - } - if "coi_mix" in econ: - payload["train/coi_mix"] = econ["coi_mix"] - if "coi_base" in econ: - payload["train/coi_base"] = econ["coi_base"] - if "coi_leakage" in econ: - payload["train/coi_leakage"] = econ["coi_leakage"] - if "coi_penalty" in econ: - payload["train/coi_penalty"] = econ["coi_penalty"] - wandb.log(payload, step=t) - - self._episode_revenues.append(econ["revenue"]) - - # histograms at log_freq intervals - if self.log_histograms and self.num_timesteps % self.log_freq == 0: - for info in self.locals.get("infos", []): - if "prices" in info: - wandb.log( - {"distributions/prices": wandb.Histogram(info["prices"])}, - step=self.num_timesteps, - ) - if "demand" in info: - wandb.log( - {"distributions/demand": wandb.Histogram(info["demand"])}, - step=self.num_timesteps, - ) - - return True - - def _on_rollout_end(self) -> None: - if not HAS_WANDB or wandb.run is None or not self._episode_revenues: - return - wandb.log( - { - "train/revenue_rollout_mean": np.mean(self._episode_revenues), - "train/revenue_rollout_total": np.sum(self._episode_revenues), - }, - step=self.num_timesteps, - ) - self._episode_revenues = [] - - -class CheckpointArtifactCallback(BaseCallback): - """Periodic SB3 checkpoint uploader backed by W&B artifacts.""" - - def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0): - super().__init__(verbose) - self.cfg = dict(cfg) - self.interval = max(1, int(interval)) - self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models"))) - self.model_dir.mkdir(parents=True, exist_ok=True) - self._next_checkpoint = self.interval - self._last_saved_step = -1 - - def _artifact_name(self) -> str: - sweep_id = ( - getattr(wandb.run, "sweep_id", None) - if HAS_WANDB and wandb.run is not None - else None - ) - return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id) - - def _checkpoint_file(self) -> Path: - algo = str(self.cfg.get("algo", "model")) - base = self.model_dir / f"phantom_{algo}_checkpoint" - self.model.save(str(base)) - return base.with_suffix(".zip") - - def _save_checkpoint(self) -> None: - if not HAS_WANDB or wandb.run is None: - return - step = int(self.num_timesteps) - if step <= self._last_saved_step: - return - checkpoint_path = self._checkpoint_file() - metadata = { - "step": step, - "algo": str(self.cfg.get("algo", "unknown")), - "sweep_id": getattr(wandb.run, "sweep_id", None), + self.log_freq = max(1, int(log_freq)) + self._window_sums = { + "train/revenue_mean": 0.0, + "train/margin_mean": 0.0, + "train/coi_level_mean": 0.0, + "train/regret_mean": 0.0, + "train/coi_mix": 0.0, + "train/coi_base": 0.0, + "train/coi_leakage": 0.0, + "train/coi_penalty": 0.0, } - saved = log_checkpoint_file( - self._artifact_name(), - file_path=checkpoint_path, - artifact_file_name=checkpoint_path.name, - metadata=metadata, - ) - if saved: - self._last_saved_step = step + self._window_count = 0 + self.events: list[dict[str, Any]] = [] + + def _accumulate(self, info: dict[str, Any]) -> None: + econ = info.get("economics") + if not isinstance(econ, dict): + return + self._window_sums["train/revenue_mean"] += float(econ.get("revenue", 0.0)) + self._window_sums["train/margin_mean"] += float(econ.get("margin", 0.0)) + self._window_sums["train/coi_level_mean"] += float(econ.get("coi_level", 0.0)) + self._window_sums["train/regret_mean"] += float(econ.get("regret", 0.0)) + if "coi_mix" in econ: + self._window_sums["train/coi_mix"] += float(econ.get("coi_mix", 0.0)) + if "coi_base" in econ: + self._window_sums["train/coi_base"] += float(econ.get("coi_base", 0.0)) + if "coi_leakage" in econ: + self._window_sums["train/coi_leakage"] += float( + econ.get("coi_leakage", 0.0) + ) + if "coi_penalty" in econ: + self._window_sums["train/coi_penalty"] += float( + econ.get("coi_penalty", 0.0) + ) + self._window_count += 1 + + def _flush(self, step: int) -> None: + if self._window_count <= 0: + return + denom = float(self._window_count) + payload = { + key: (value / denom) + for key, value in self._window_sums.items() + if value != 0.0 + or key + in { + "train/revenue_mean", + "train/margin_mean", + "train/coi_level_mean", + "train/regret_mean", + } + } + payload["train/global_step"] = int(step) + self.events.append(payload) + for key in self._window_sums: + self._window_sums[key] = 0.0 + self._window_count = 0 def _on_step(self) -> bool: - if self.num_timesteps < self._next_checkpoint: - return True - self._save_checkpoint() - while self._next_checkpoint <= self.num_timesteps: - self._next_checkpoint += self.interval + for info in self.locals.get("infos", []): + if isinstance(info, dict): + self._accumulate(info) + + if self.num_timesteps % self.log_freq == 0: + self._flush(step=self.num_timesteps) + return True def _on_training_end(self) -> None: - self._save_checkpoint() + self._flush(step=self.num_timesteps) class EvalMetricsCallback(EvalCallback): - """Deterministic evaluation - true performance without exploration noise.""" + """Deterministic evaluation collector detached from logging backends.""" def __init__( self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs @@ -153,23 +99,19 @@ class EvalMetricsCallback(EvalCallback): eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs ) self._eval_revenues: list[float] = [] + self.events: list[dict[str, float | int]] = [] def _on_step(self) -> bool: result = super()._on_step() - - if not HAS_WANDB or wandb.run is None: - return result - - # log eval metrics after evaluation runs if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): - wandb.log( + self.events.append( { - "eval/reward_mean": self.last_mean_reward, - "eval/revenue_mean": np.mean(self._eval_revenues) + "eval/reward_mean": float(self.last_mean_reward), + "eval/revenue_mean": float(np.mean(self._eval_revenues)) if self._eval_revenues - else 0, - }, - step=self.num_timesteps, + else 0.0, + "train/global_step": int(self.num_timesteps), + } ) self._eval_revenues = [] diff --git a/engine/orchestrators/train.py b/engine/orchestrators/train.py index 6b0f539..81ebdb5 100644 --- a/engine/orchestrators/train.py +++ b/engine/orchestrators/train.py @@ -31,26 +31,20 @@ def _print_local_metrics(metrics: dict[str, Any]) -> None: print("PHANTOM_METRICS:" + json.dumps(metrics)) -def _should_print_local(spec: TrainSpec) -> bool: - if not spec.runtime.use_jax: - return True - try: - import jax - - return int(jax.process_index()) == 0 - except Exception: - return True - - -def _is_non_primary_jax_worker(spec: TrainSpec) -> bool: - if not spec.runtime.use_jax: - return False - try: - import jax - - return int(jax.process_count()) > 1 and int(jax.process_index()) != 0 - except Exception: - return False +def _log_train_events(events: list[dict[str, Any]], log_freq: int) -> None: + if not events: + return + period = max(1, int(log_freq)) + last_logged_step = -period + for event in sorted( + [evt for evt in events if isinstance(evt, dict)], + key=lambda evt: int(evt.get("train/global_step", 0)), + ): + step = int(event.get("train/global_step", 0)) + if step <= 0 or (step - last_logged_step) < period: + continue + log_metrics(event, step=step) + last_logged_step = step def run_train_once( @@ -65,10 +59,9 @@ def run_train_once( extra_tags: Sequence[str], ) -> dict[str, Any]: wandb = get_wandb_module() - if no_wandb or wandb is None or _is_non_primary_jax_worker(spec): + if no_wandb or wandb is None: result = run_train(spec) - if _should_print_local(spec): - _print_local_metrics(result.metrics) + _print_local_metrics(result.metrics) return result.metrics mode = "offline" if offline else "online" @@ -95,6 +88,7 @@ def run_train_once( try: result = run_train(spec) + _log_train_events(result.events, spec.runtime.log_freq) metrics = result.metrics step = int(metrics.get("train/global_step", spec.runtime.total_timesteps)) log_metrics(metrics, step=step) @@ -122,6 +116,7 @@ def run_with_active_sweep_run( ) update_run_config({**spec.to_flat_dict(), **metadata}) result = run_train(spec) + _log_train_events(result.events, spec.runtime.log_freq) metrics = result.metrics step = int(metrics.get("train/global_step", spec.runtime.total_timesteps)) log_metrics(metrics, step=step) diff --git a/engine/project.json b/engine/project.json index 3cf3571..2a78f1d 100644 --- a/engine/project.json +++ b/engine/project.json @@ -81,44 +81,6 @@ "command": "bash scripts/nx_research.sh docker-train-publish", "cwd": "." } - }, - "train-tpu-pod": { - "executor": "nx:run-commands", - "options": { - "command": "bash scripts/nx_research.sh train-tpu-pod", - "cwd": "." - } - }, - "train-tpu-vm-prepare": { - "executor": "nx:run-commands", - "options": { - "command": "bash scripts/nx_research.sh train-tpu-vm-prepare", - "cwd": "." - } - }, - "train-tpu-vm-run": { - "executor": "nx:run-commands", - "options": { - "command": "bash scripts/nx_research.sh train-tpu-vm-run", - "cwd": "." - } - }, - "train-tpu-vm": { - "executor": "nx:run-commands", - "dependsOn": [ - "train-tpu-vm-prepare" - ], - "options": { - "command": "bash scripts/nx_research.sh train-tpu-vm-run", - "cwd": "." - } - }, - "train-tpu-vm-sweep": { - "executor": "nx:run-commands", - "options": { - "command": "bash scripts/nx_research.sh train-tpu-vm-sweep", - "cwd": "." - } } }, "tags": [ diff --git a/engine/spec.py b/engine/spec.py index f72fdd0..7c9f059 100644 --- a/engine/spec.py +++ b/engine/spec.py @@ -106,11 +106,6 @@ class OptimizerSpec: eps_decay: float = 0.9995 arch: str = "small" activation: str = "relu" - jax_num_envs: int = 16 - jax_num_steps: int = 128 - jax_num_minibatches: int = 4 - jax_update_epochs: int = 4 - jax_anneal_lr: bool = True vf_coef: float = 0.5 max_grad_norm: float = 0.5 @@ -125,7 +120,6 @@ class RuntimeSpec: checkpoint_interval: int = 200_000 model_dir: str = "engine/models" log_freq: int = 100 - use_jax: bool = False @dataclass(frozen=True) @@ -156,7 +150,6 @@ class TrainSpec: "model_dir": self.runtime.model_dir, "backend": self.runtime.backend, "device": self.runtime.device, - "use_jax": self.runtime.use_jax, "checkpoint_interval": self.runtime.checkpoint_interval, "n_products": self.env.n_products, "N": self.env.n_sessions, @@ -197,11 +190,6 @@ class TrainSpec: "eps_decay": self.optimizer.eps_decay, "arch": self.optimizer.arch, "activation": self.optimizer.activation, - "jax_num_envs": self.optimizer.jax_num_envs, - "jax_num_steps": self.optimizer.jax_num_steps, - "jax_num_minibatches": self.optimizer.jax_num_minibatches, - "jax_update_epochs": self.optimizer.jax_update_epochs, - "jax_anneal_lr": self.optimizer.jax_anneal_lr, "vf_coef": self.optimizer.vf_coef, "max_grad_norm": self.optimizer.max_grad_norm, "robust_eval_enabled": self.eval.robust_eval_enabled, @@ -223,14 +211,11 @@ class TrainSpec: base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto")) ) - requested_jax = _truthy(base.get("use_jax")) or _truthy( - runtime_env.get("PHANTOM_USE_JAX") - ) - backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower() + backend = str(base.get("backend", "sb3")).lower() if backend == "auto": - backend = "jax" if requested_jax else "sb3" - if backend == "jax": - requested_jax = True + backend = "sb3" + if backend != "sb3": + backend = "sb3" no_robust = _truthy(base.get("no_robust")) if no_robust: @@ -284,11 +269,6 @@ class TrainSpec: eps_decay=float(base["eps_decay"]), arch=str(base["arch"]), activation=str(base["activation"]), - jax_num_envs=int(base["jax_num_envs"]), - jax_num_steps=int(base["jax_num_steps"]), - jax_num_minibatches=int(base["jax_num_minibatches"]), - jax_update_epochs=int(base["jax_update_epochs"]), - jax_anneal_lr=_truthy(base.get("jax_anneal_lr")), vf_coef=float(base["vf_coef"]), max_grad_norm=float(base["max_grad_norm"]), ), @@ -301,7 +281,6 @@ class TrainSpec: checkpoint_interval=int(base["checkpoint_interval"]), model_dir=str(base["model_dir"]), log_freq=int(base["log_freq"]), - use_jax=requested_jax, ), eval=EvalSpec( eval_freq=int(base["eval_freq"]), diff --git a/engine/sweeps/tpu_jax.yaml b/engine/sweeps/tpu_jax.yaml deleted file mode 100644 index 2e5de08..0000000 --- a/engine/sweeps/tpu_jax.yaml +++ /dev/null @@ -1,93 +0,0 @@ -method: bayes -metric: - name: objective/score - goal: maximize -command: - - ${env} - - python - - -m - - engine.train -parameters: - # fixed: always use JAX backend so TPU chips are actually exercised - use_jax: - value: true - # all four algos have JAX implementations - algo: - values: [ppo, a2c, dqn, qtable] - total_timesteps: - values: [50000, 80000, 120000] - checkpoint_interval: - value: 200000 - seed: - values: [13, 42, 77] - n_products: - values: [8, 10, 12] - # COI framework parameters -- primary research variables - alpha: - distribution: uniform - min: 0.1 - max: 0.6 - lambda_coi: - distribution: uniform - min: 0.05 - max: 0.6 - robust_radius: - distribution: uniform - min: 0.0 - max: 0.3 - robust_points: - values: [3, 5, 7] - info_value: - distribution: uniform - min: 0.5 - max: 2.0 - revenue_weight: - values: [0.005, 0.01, 0.02] - # shared hyperparameters - learning_rate: - distribution: log_uniform_values - min: 1.0e-5 - max: 1.0e-3 - gamma: - values: [0.97, 0.99, 0.995] - # JAX parallelism -- key lever for TPU throughput - jax_num_envs: - values: [8, 16, 32] - jax_num_steps: - values: [64, 128, 256] - jax_num_minibatches: - values: [2, 4, 8] - jax_update_epochs: - values: [2, 4, 8] - # PPO/A2C specific - gae_lambda: - values: [0.9, 0.95, 0.98] - clip_range: - values: [0.1, 0.2, 0.3] - ent_coef: - values: [0.0, 0.005, 0.01] - # DQN specific - buffer_size: - values: [20000, 50000, 100000] - batch_size: - values: [128, 256, 512] - learning_starts: - values: [500, 1000, 3000] - exploration_fraction: - values: [0.1, 0.2, 0.3] - exploration_final_eps: - values: [0.01, 0.03, 0.05] - # QTable specific - q_lr: - values: [0.03, 0.05, 0.1, 0.2] - eps_end: - values: [0.02, 0.05, 0.1] - eps_decay: - values: [0.999, 0.9995, 0.9999] - # action space - action_levels: - values: [7, 9, 11] - action_scale_low: - values: [0.75, 0.8, 0.85] - action_scale_high: - values: [1.15, 1.2, 1.25] diff --git a/engine/sweeps/tpu_pod.yaml b/engine/sweeps/tpu_pod.yaml deleted file mode 100644 index d34dfb1..0000000 --- a/engine/sweeps/tpu_pod.yaml +++ /dev/null @@ -1,64 +0,0 @@ -method: bayes -metric: - name: objective/score - goal: maximize -command: - - ${env} - - python - - -m - - engine.train -parameters: - use_jax: - value: true - # pmap requires all workers to compile the same computation graph shape, - # so structural params are fixed -- only research/scalar params are swept - algo: - values: [ppo, a2c] - jax_num_envs: - value: 32 - jax_num_steps: - value: 128 - jax_num_minibatches: - value: 4 - jax_update_epochs: - value: 4 - total_timesteps: - value: 100000 - checkpoint_interval: - value: 200000 - n_products: - value: 10 - action_levels: - value: 9 - # research parameters -- primary sweep targets - alpha: - distribution: uniform - min: 0.1 - max: 0.6 - lambda_coi: - distribution: uniform - min: 0.05 - max: 0.6 - robust_radius: - distribution: uniform - min: 0.0 - max: 0.3 - info_value: - distribution: uniform - min: 0.5 - max: 2.0 - revenue_weight: - values: [0.005, 0.01, 0.02] - # training hyperparameters - learning_rate: - distribution: log_uniform_values - min: 1.0e-5 - max: 1.0e-3 - gamma: - values: [0.97, 0.99, 0.995] - gae_lambda: - values: [0.9, 0.95, 0.98] - clip_range: - values: [0.1, 0.2, 0.3] - ent_coef: - values: [0.0, 0.005, 0.01] diff --git a/engine/train.py b/engine/train.py index 90ac991..7dd2f68 100644 --- a/engine/train.py +++ b/engine/train.py @@ -7,14 +7,6 @@ from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once from .spec import TrainSpec -def _truthy(value: str | bool | None) -> bool: - if isinstance(value, bool): - return value - if value is None: - return False - return str(value).strip().lower() in {"1", "true", "yes", "on"} - - def _parse_tags(raw: str | None) -> list[str]: if raw is None: return [] @@ -55,7 +47,7 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument("--group", type=str) parser.add_argument("--tags", type=str) - parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto") + parser.add_argument("--backend", choices=["auto", "sb3"], default="auto") parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"]) parser.add_argument("--seed", type=int) parser.add_argument("--total-timesteps", type=int) @@ -111,13 +103,6 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument("--eval-freq", type=int) parser.add_argument("--eval-episodes", type=int) - parser.add_argument("--jax", action="store_true") - parser.add_argument("--jax-num-envs", type=int) - parser.add_argument("--jax-num-steps", type=int) - parser.add_argument("--jax-num-minibatches", type=int) - parser.add_argument("--jax-update-epochs", type=int) - parser.add_argument("--jax-anneal-lr", type=str) - parser.add_argument("--sweep-agent", action="store_true") parser.add_argument("--sweep-id", type=str) parser.add_argument("--count", type=int, default=0) @@ -127,9 +112,6 @@ def _build_parser() -> argparse.ArgumentParser: def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]: - jax_anneal_lr = ( - _truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None - ) backend = None if args.backend == "auto" else args.backend overrides = { @@ -185,12 +167,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]: "max_grad_norm": args.max_grad_norm, "eval_freq": args.eval_freq, "eval_episodes": args.eval_episodes, - "use_jax": args.jax or None, - "jax_num_envs": args.jax_num_envs, - "jax_num_steps": args.jax_num_steps, - "jax_num_minibatches": args.jax_num_minibatches, - "jax_update_epochs": args.jax_update_epochs, - "jax_anneal_lr": jax_anneal_lr, } return {key: value for key, value in overrides.items() if value is not None} diff --git a/engine/train_core.py b/engine/train_core.py index 8b29f45..6245030 100644 --- a/engine/train_core.py +++ b/engine/train_core.py @@ -12,17 +12,14 @@ class TrainResult: spec: TrainSpec metrics: dict[str, Any] artifacts: dict[str, str] + events: list[dict[str, Any]] def run_train(spec: TrainSpec) -> TrainResult: cfg = spec.to_flat_dict() algo = spec.algorithm.name - if spec.runtime.use_jax or spec.runtime.backend == "jax": - from .backends.jax import train_jax_backend - - _, raw_metrics = train_jax_backend(cfg) - elif algo == "qtable": + if algo == "qtable": from .backends.qtable import train_qtable _, raw_metrics = train_qtable(cfg) @@ -31,10 +28,13 @@ def run_train(spec: TrainSpec) -> TrainResult: _, raw_metrics = train_sb3(cfg) + events_raw = raw_metrics.pop("_train_events", []) + events = [evt for evt in events_raw if isinstance(evt, dict)] + metrics = canonicalize_metrics(raw_metrics, spec) artifacts: dict[str, str] = {} model_path = raw_metrics.get("model/path") if isinstance(model_path, str): artifacts["model/path"] = model_path - return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts) + return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts, events=events) diff --git a/scripts/nx_research.sh b/scripts/nx_research.sh index 5e72e3f..78117e3 100644 --- a/scripts/nx_research.sh +++ b/scripts/nx_research.sh @@ -108,49 +108,6 @@ PY image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}" docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" . docker push "$image_ref:gpu-latest" - docker build -f docker/Trainer.dockerfile --target tpu -t "$image_ref:tpu-latest" . - docker push "$image_ref:tpu-latest" - ;; - train-tpu-pod) - load_sweep_env - require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong" - require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" - require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" - gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh "$TPU_NAME":/tmp/tpu_pod_run.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all - gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="WANDB_API_KEY='$WANDB_API_KEY' SWEEP_ID='$SWEEP_ID' AGENT_COUNT='${AGENT_COUNT:-0}' sh /tmp/tpu_pod_run.sh" - ;; - train-tpu-vm-prepare) - require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong" - TPU_NAME="$TPU_NAME" \ - TPU_ZONE="${TPU_ZONE:-us-central2-b}" \ - TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" \ - LOCAL_REPO_DIR="$PWD" \ - REMOTE_REPO_DIR="${TPU_REPO_DIR:-/tmp/PHANTOM}" \ - sh scripts/tpu_sync_repo.sh - gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh "$TPU_NAME":/tmp/tpu_vm_train.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all - ;; - train-tpu-vm-run) - load_sweep_env - require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong" - require_var LOCAL_TRAIN_ARGS "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000" - gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="REPO_DIR='${TPU_REPO_DIR:-/tmp/PHANTOM}' TRAIN_ARGS='${LOCAL_TRAIN_ARGS}' WANDB_API_KEY='${WANDB_API_KEY:-}' sh /tmp/tpu_vm_train.sh" - ;; - train-tpu-vm-sweep) - load_sweep_env - require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong" - require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123" - require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file" - args=( - --sweep-id "$SWEEP_ID" - --tpu-name "$TPU_NAME" - --tpu-zone "${TPU_ZONE:-us-central2-b}" - --tpu-project "${TPU_PROJECT:-phantom-trc}" - --tpu-repo-dir "${TPU_REPO_DIR:-/tmp/PHANTOM}" - ) - if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then - args+=(--count "$AGENT_COUNT") - fi - WANDB_API_KEY="$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py "${args[@]}" ;; *) printf '%s\n' "Unknown research command: $cmd" >&2 diff --git a/scripts/tpu_pod_run.sh b/scripts/tpu_pod_run.sh deleted file mode 100755 index 8e1d722..0000000 --- a/scripts/tpu_pod_run.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env sh -# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`. -# Authenticates with Artifact Registry using the VM's service account metadata token, -# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker. -# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount. -# Required env vars: WANDB_API_KEY, SWEEP_ID -# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends) -set -eu - -IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest" -AGENT_COUNT="${AGENT_COUNT:-1}" - -# use VM service account — no manual key needed on the pod -TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \ - "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \ - | python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])') - -echo "$TOKEN" | sudo docker login -u oauth2accesstoken \ - --password-stdin https://us-central1-docker.pkg.dev - -sudo docker pull "$IMAGE" - -# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips) -# --network host lets JAX reach the other pod workers for distributed init -sudo docker run --rm \ - --privileged \ - --network host \ - --volume /dev:/dev \ - -e WANDB_API_KEY="$WANDB_API_KEY" \ - -e SWEEP_ID="$SWEEP_ID" \ - -e AGENT_COUNT="$AGENT_COUNT" \ - "$IMAGE" diff --git a/scripts/tpu_sync_repo.sh b/scripts/tpu_sync_repo.sh deleted file mode 100644 index a26e241..0000000 --- a/scripts/tpu_sync_repo.sh +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env sh -set -eu - -TPU_NAME="${TPU_NAME:?TPU_NAME is required}" -TPU_ZONE="${TPU_ZONE:-us-central2-b}" -TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" -LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}" -REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}" -ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}" - -FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)" -CLEANUP_LIST=true - -cleanup() { - if [ "$CLEANUP_LIST" = "true" ]; then - rm -f "$FILE_LIST" - fi -} -trap cleanup EXIT - -if [ ! -d "$LOCAL_REPO_DIR" ]; then - echo "local repo directory not found: $LOCAL_REPO_DIR" - exit 1 -fi - -if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then - git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST" - python3 - "$FILE_LIST" <<'PY' -import sys -from pathlib import Path - -file_list = Path(sys.argv[1]) -skip_prefixes = ( - "wandb/", - ".venv/", - "venv/", - "node_modules/", - ".next/", - ".turbo/", - "__pycache__/", - ".mypy_cache/", - ".pytest_cache/", - ".ruff_cache/", - "paper/build/", - "tests/e2e/test-results/", -) - -rows = file_list.read_text().splitlines() -kept = [ - row - for row in rows - if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes) -] -file_list.write_text("\n".join(kept) + ("\n" if kept else "")) -PY - tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST" -else - tar \ - --exclude-vcs \ - --exclude=".venv" --exclude="*/.venv" \ - --exclude="venv" --exclude="*/venv" \ - --exclude="node_modules" --exclude="*/node_modules" \ - --exclude=".next" --exclude="*/.next" \ - --exclude=".turbo" --exclude="*/.turbo" \ - --exclude="__pycache__" --exclude="*/__pycache__" \ - --exclude=".mypy_cache" --exclude="*/.mypy_cache" \ - --exclude=".pytest_cache" --exclude="*/.pytest_cache" \ - --exclude=".ruff_cache" --exclude="*/.ruff_cache" \ - --exclude="wandb" --exclude="*/wandb" \ - --exclude="paper/build" \ - --exclude="tests/e2e/test-results" \ - -czf "$ARCHIVE_PATH" \ - -C "$LOCAL_REPO_DIR" . -fi - -gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \ - --zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all - -gcloud compute tpus tpu-vm ssh "$TPU_NAME" \ - --zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \ - --command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz" - -rm -f "$ARCHIVE_PATH" diff --git a/scripts/tpu_vm_sweep_agent.py b/scripts/tpu_vm_sweep_agent.py deleted file mode 100644 index b051b86..0000000 --- a/scripts/tpu_vm_sweep_agent.py +++ /dev/null @@ -1,211 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import gc -import json -import os -import re -import shlex -import shutil -import subprocess -import time -import resource -from pathlib import Path - -import wandb - - -CLI_MAP: dict[str, str] = { - "algo": "--algo", - "total_timesteps": "--total-timesteps", - "alpha": "--alpha", - "N": "--N", - "n_products": "--n-products", - "lambda_coi": "--lambda-coi", - "info_value": "--info-value", - "robust_radius": "--robust-radius", - "robust_points": "--robust-points", - "no_robust": "--no-robust", - "learning_rate": "--learning-rate", - "gamma": "--gamma", - "gae_lambda": "--gae-lambda", - "clip_range": "--clip-range", - "ent_coef": "--ent-coef", - "revenue_weight": "--revenue-weight", - "max_steps": "--max-steps", - "margin_floor": "--margin-floor", - "margin_floor_patience": "--margin-floor-patience", - "arch": "--arch", - "activation": "--activation", - "jax_num_envs": "--jax-num-envs", - "jax_num_steps": "--jax-num-steps", - "jax_num_minibatches": "--jax-num-minibatches", - "jax_update_epochs": "--jax-update-epochs", - "jax_anneal_lr": "--jax-anneal-lr", - "checkpoint_interval": "--checkpoint-interval", - "action_levels": "--action-levels", - "action_scale_low": "--action-scale-low", - "action_scale_high": "--action-scale-high", -} - - -def _to_cli_args(cfg: dict) -> str: - parts: list[str] = ["--jax", "--no-wandb"] - for key, flag in CLI_MAP.items(): - if key not in cfg: - continue - value = cfg[key] - if value is None: - continue - if isinstance(value, bool): - if key == "jax_anneal_lr": - parts.extend([flag, "true" if value else "false"]) - elif value: - parts.append(flag) - continue - parts.extend([flag, str(value)]) - return " ".join(shlex.quote(p) for p in parts) - - -_SENTINEL = "PHANTOM_METRICS:" - - -def _raise_nofile_limit(min_soft: int = 8192) -> None: - try: - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - target = min(hard, max(soft, min_soft)) - if target > soft: - resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard)) - except Exception: - return - - -def _extract_metrics(output: str) -> dict: - # fast path: look for the dedicated sentinel line emitted by run_local - for line in output.splitlines(): - if line.startswith(_SENTINEL): - try: - return json.loads(line[len(_SENTINEL) :]) - except Exception: - break - # fallback: scan for any JSON block containing eval/sweep keys; - # use greedy match to capture the largest possible block first - for block in re.findall(r"\{[^{}]*\}", output): - try: - obj = json.loads(block) - except Exception: - continue - if isinstance(obj, dict) and ( - "objective/score" in obj - or "eval/reward_mean" in obj - or "sweep/score" in obj - ): - return obj - return {} - - -def main() -> None: - _raise_nofile_limit() - p = argparse.ArgumentParser( - description="Run W&B sweep where each trial uses full TPU pod" - ) - p.add_argument("--sweep-id", required=True) - p.add_argument("--tpu-name", required=True) - p.add_argument("--tpu-zone", default="us-central2-b") - p.add_argument("--tpu-project", default="phantom-trc") - p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM") - p.add_argument("--count", type=int, default=0) - p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1])) - args = p.parse_args() - - workdir = Path(args.workdir).resolve() - env = os.environ.copy() - wandb_root = workdir / ".wandb-agent" - wandb_root.mkdir(parents=True, exist_ok=True) - - prepare_cmd = [ - "make", - "train.tpu.vm.prepare", - f"TPU_NAME={args.tpu_name}", - f"TPU_ZONE={args.tpu_zone}", - f"TPU_PROJECT={args.tpu_project}", - f"TPU_REPO_DIR={args.tpu_repo_dir}", - ] - prepare = subprocess.run( - prepare_cmd, - cwd=workdir, - env=env, - text=True, - capture_output=False, - check=False, - ) - if prepare.returncode != 0: - raise RuntimeError("Failed to prepare TPU workers for sweep") - - def run_trial() -> None: - run = None - trial_wandb_dir = wandb_root / f"trial-{time.time_ns()}" - trial_wandb_dir.mkdir(parents=True, exist_ok=True) - try: - run = wandb.init(dir=str(trial_wandb_dir)) - cfg = dict(wandb.config) - cli_args = _to_cli_args(cfg) - env_trial = dict(env) - env_trial["LOCAL_TRAIN_ARGS"] = cli_args - env_trial["WANDB_DIR"] = str(trial_wandb_dir) - env_trial["WANDB_CACHE_DIR"] = str(trial_wandb_dir / "cache") - env_trial["WANDB_DATA_DIR"] = str(trial_wandb_dir / "data") - - cmd = [ - "make", - "train.tpu.vm.run", - f"TPU_NAME={args.tpu_name}", - f"TPU_ZONE={args.tpu_zone}", - f"TPU_PROJECT={args.tpu_project}", - f"TPU_REPO_DIR={args.tpu_repo_dir}", - ] - - proc = subprocess.run( - cmd, - cwd=workdir, - env=env_trial, - text=True, - capture_output=True, - check=False, - ) - - if proc.stdout: - print(proc.stdout) - if proc.stderr: - print(proc.stderr) - - if proc.returncode != 0: - if run is not None: - run.summary["runner/exit_code"] = proc.returncode - raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}") - - metrics = _extract_metrics(proc.stdout) - if metrics: - wandb.log(metrics) - for k, v in metrics.items(): - run.summary[k] = v - run.summary["runner/exit_code"] = 0 - except Exception: - time.sleep(2) - raise - finally: - if run is not None and wandb.run is not None: - wandb.finish() - shutil.rmtree(trial_wandb_dir, ignore_errors=True) - gc.collect() - - wandb.agent( - args.sweep_id, - function=run_trial, - count=args.count if args.count > 0 else None, - ) - - -if __name__ == "__main__": - main() diff --git a/scripts/tpu_vm_train.sh b/scripts/tpu_vm_train.sh deleted file mode 100644 index 33c798e..0000000 --- a/scripts/tpu_vm_train.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env sh -set -eu - -REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}" -PYTHON_BIN="${PYTHON_BIN:-python3}" -TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}" -EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}" -INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}" - -if [ ! -d "$REPO_DIR" ]; then - echo "repo directory not found: $REPO_DIR" - exit 1 -fi - -cd "$REPO_DIR" - -if [ -d "wandb" ]; then - rm -rf wandb -fi - -# keep install idempotent and avoid re-installing jax/libtpu each run -if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then - $PYTHON_BIN -m pip install -r requirements.txt -fi -if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then - if [ -f "engine/jax/requirements.txt" ]; then - $PYTHON_BIN -m pip install -r engine/jax/requirements.txt - fi - $PYTHON_BIN -m pip install -U $EXTRA_PIP -fi - -if [ -n "${WANDB_API_KEY:-}" ]; then - if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then - $PYTHON_BIN -m pip install -U wandb - fi -fi - -if [ -n "${WANDB_API_KEY:-}" ]; then - export WANDB_API_KEY - exec $PYTHON_BIN -m engine.train $TRAIN_ARGS -fi - -exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb