diff --git a/Makefile b/Makefile index d7fd956..27ce523 100644 --- a/Makefile +++ b/Makefile @@ -8,12 +8,30 @@ VENV := .venv PYTHON := $(VENV)/bin/python PIP := $(VENV)/bin/pip PYTEST := $(VENV)/bin/pytest +TPU_NAME ?= phantom-tpu +TPU_ZONE ?= us-central2-b +TPU_TYPE ?= v4-32 +TPU_RUNTIME ?= tpu-vm-v4-base +TPU_PROJECT ?= phantom-trc +TPU_NETWORK ?= default +TPU_SUBNETWORK ?= default-us-central2 +TPU_USE_SPOT ?= 0 +TPU_EXTRA_CREATE_FLAGS ?= +TPU_WORKDIR ?= ~/PHANTOM +TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env +TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000 +TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html +TPU_VENV ?= .venv-tpu +TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=offline +TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,) +TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS) .DEFAULT_GOAL := help .PHONY: help help: - @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines" + @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.*" + @echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot" $(BUILDDIR): mkdir -p paper/$(BUILDDIR) @@ -70,6 +88,97 @@ $(VENV): install: $(VENV) $(PIP) install -r requirements.txt +.PHONY: tpu.setup +tpu.setup: + @command -v gcloud >/dev/null 2>&1 || (echo "gcloud CLI not found. Install from https://cloud.google.com/sdk/docs/install" && exit 1) + @gcloud auth login --update-adc + @gcloud auth application-default login + @gcloud config set project "$(TPU_PROJECT)" + +.PHONY: tpu.check.zone +tpu.check.zone: + @case "$(TPU_ZONE)" in \ + europe-west4-a|us-central2-b|us-central1-a|us-east1-d|europe-west4-b) ;; \ + *) echo "Unsupported TPU_ZONE='$(TPU_ZONE)'. Allowed zones: europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b"; exit 1 ;; \ + esac + +.PHONY: tpu.create.v4.ondemand +tpu.create.v4.ondemand: + $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=default-us-central2 + +.PHONY: tpu.create.v4.spot +tpu.create.v4.spot: + $(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=default-us-central2 + +.PHONY: tpu.create +tpu.create: tpu.check.zone + @if gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" >/dev/null 2>&1; then \ + STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)'); \ + echo "TPU VM $(TPU_NAME) already exists in $(TPU_ZONE) with state=$$STATE, skipping create"; \ + else \ + $(TPU_CREATE_CMD); \ + fi + +.PHONY: tpu.ensure +tpu.ensure: tpu.check.zone + @set -e; \ + STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)' 2>/dev/null || true); \ + if [ -z "$$STATE" ]; then \ + echo "TPU VM $(TPU_NAME) not found in $(TPU_ZONE), creating"; \ + $(TPU_CREATE_CMD); \ + elif [ "$$STATE" = "READY" ]; then \ + echo "TPU VM $(TPU_NAME) is READY"; \ + elif [ "$$STATE" = "PREEMPTED" ] || [ "$$STATE" = "TERMINATED" ] || [ "$$STATE" = "FAILED" ]; then \ + echo "TPU VM $(TPU_NAME) is in terminal state $$STATE, recreating"; \ + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet || true; \ + $(TPU_CREATE_CMD); \ + else \ + echo "TPU VM $(TPU_NAME) is in state $$STATE; wait or recreate manually"; \ + exit 1; \ + fi + +.PHONY: tpu.status +tpu.status: + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" + +.PHONY: tpu.ssh +tpu.ssh: + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" + +.PHONY: tpu.prepare +tpu.prepare: tpu.ensure + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command "mkdir -p $(TPU_WORKDIR)" + +.PHONY: tpu.deploy +tpu.deploy: tpu.prepare + @for p in $(TPU_SYNC_PATHS); do \ + if [ ! -e "$$p" ]; then continue; fi; \ + if [ -d "$$p" ]; then \ + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp --recurse "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \ + else \ + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \ + fi; \ + done + +.PHONY: tpu.install +tpu.install: tpu.ensure + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && PYBIN=$$(command -v python3.11 || command -v python3.10 || command -v python3) && $$PYBIN -m venv $(TPU_VENV) && $(TPU_VENV)/bin/pip install --upgrade pip setuptools wheel && $(TPU_VENV)/bin/pip install -r requirements.txt && $(TPU_VENV)/bin/pip install -r engine/jax/requirements.txt && $(TPU_VENV)/bin/pip install "jax[tpu]" -f $(TPU_JAX_WHEEL_URL)' + +.PHONY: tpu.check.remote +tpu.check.remote: tpu.ensure + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'set -e; mkdir -p $(TPU_WORKDIR); cd $(TPU_WORKDIR); test -f engine/train.py || (echo "Missing code on TPU VM. Run: make tpu.deploy" && exit 2); test -x $(TPU_VENV)/bin/python || (echo "Missing TPU venv. Run: make tpu.install" && exit 3)' + +.PHONY: tpu.train +tpu.train: tpu.check.remote + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && if [ -f .env ]; then set -a && . ./.env && set +a; fi && $(TPU_TRAIN_ENV) $(TPU_VENV)/bin/python -m engine.train $(TPU_TRAIN_ARGS)' + +.PHONY: tpu.bootstrap +tpu.bootstrap: tpu.ensure tpu.deploy tpu.install + +.PHONY: tpu.delete +tpu.delete: + gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet + .PHONY: stats.lines stats.lines: @find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \ diff --git a/engine/jax/__init__.py b/engine/jax/__init__.py new file mode 100644 index 0000000..8b6f740 --- /dev/null +++ b/engine/jax/__init__.py @@ -0,0 +1,13 @@ +"""JAX-compatible training and environment modules for PHANTOM.""" + +from __future__ import annotations + +try: + import jax # noqa: F401 + import jax.numpy as jnp # noqa: F401 + + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + +__all__ = ["JAX_AVAILABLE"] diff --git a/engine/jax/checkpoint.py b/engine/jax/checkpoint.py new file mode 100644 index 0000000..c75c6bc --- /dev/null +++ b/engine/jax/checkpoint.py @@ -0,0 +1,49 @@ +"""Orbax checkpoint helpers for JAX training runs.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +try: + import orbax.checkpoint as ocp + + HAS_ORBAX = True +except ImportError: + HAS_ORBAX = False + + +def _require_orbax() -> None: + if not HAS_ORBAX: + raise ImportError( + "orbax-checkpoint is required for checkpoint support. " + "Install engine/jax/requirements.txt first." + ) + + +def create_manager(directory: str | Path, max_to_keep: int = 5): + _require_orbax() + root = Path(directory) + root.mkdir(parents=True, exist_ok=True) + options = ocp.CheckpointManagerOptions( + max_to_keep=max(1, int(max_to_keep)), create=True + ) + return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options) + + +def save(manager, *, step: int, payload: Any) -> bool: + _require_orbax() + return bool(manager.save(int(step), payload)) + + +def latest_step(manager) -> int | None: + _require_orbax() + return manager.latest_step() + + +def restore(manager, *, target: Any, step: int | None = None) -> Any: + _require_orbax() + step_to_restore = manager.latest_step() if step is None else int(step) + if step_to_restore is None: + return target + return manager.restore(step_to_restore, items=target) diff --git a/engine/jax/env.py b/engine/jax/env.py new file mode 100644 index 0000000..06542b1 --- /dev/null +++ b/engine/jax/env.py @@ -0,0 +1,287 @@ +"""JAX-native PHANTOM environment with robust contamination step.""" + +from __future__ import annotations + +from typing import NamedTuple + +try: + import jax + import jax.numpy as jnp +except ImportError as exc: # pragma: no cover + raise ImportError("engine.jax.env requires JAX") from exc + +from .primitives import ( + _sample_sessions_jax, + agent_probability_from_kl, + batch_kl, + compute_session_transitions, + load_transition_data, + purchase_flags, + reward_with_coi_penalty, + revenue_from_demand, + weighted_demand, +) + + +class EnvParams(NamedTuple): + n_products: int + n_sessions: int + max_episode_steps: int + max_session_steps: int + price_low: float + price_high: float + lambda_coi: float + info_value: float + robust_radius: float + margin_floor: float + margin_floor_patience: int + action_scales: jax.Array + alpha_nominal: float + alpha_candidates: jax.Array + human_T: jax.Array + agent_T: jax.Array + terminal_mask: jax.Array + purchase_mask: jax.Array + event_weights: jax.Array + start_idx: int + term_idx: int + + +class EnvState(NamedTuple): + prices: jax.Array + demand: jax.Array + step_count: jax.Array + low_margin_streak: jax.Array + last_agent_prob: jax.Array + last_alpha_adv: jax.Array + + +class CandidateEval(NamedTuple): + reward: jax.Array + revenue: jax.Array + demand: jax.Array + agent_prob: jax.Array + leakage: jax.Array + discount: jax.Array + n_purchases: jax.Array + n_agents: jax.Array + + +def make_env_params( + *, + n_products: int, + alpha: float, + n_sessions: int, + lambda_coi: float, + robust_radius: float, + robust_points: int, + info_value: float, + action_levels: int, + action_scale_low: float, + action_scale_high: float, + price_low: float, + price_high: float, + max_episode_steps: int, + max_session_steps: int = 40, + margin_floor: float = 0.05, + margin_floor_patience: int = 5, + prefer_behavior_data: bool = True, +) -> EnvParams: + transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax() + if robust_radius <= 0.0 or robust_points <= 1: + alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32) + else: + lo = max(0.0, float(alpha) - float(robust_radius)) + hi = min(1.0, float(alpha) + float(robust_radius)) + alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32) + + action_scales = jnp.linspace( + float(action_scale_low), + float(action_scale_high), + int(action_levels), + dtype=jnp.float32, + ) + return EnvParams( + n_products=int(n_products), + n_sessions=int(n_sessions), + max_episode_steps=int(max_episode_steps), + max_session_steps=int(max_session_steps), + price_low=float(price_low), + price_high=float(price_high), + lambda_coi=float(lambda_coi), + info_value=float(info_value), + robust_radius=float(robust_radius), + margin_floor=float(margin_floor), + margin_floor_patience=int(margin_floor_patience), + action_scales=action_scales, + alpha_nominal=float(alpha), + alpha_candidates=alpha_candidates, + human_T=jnp.asarray(transition.human_T), + agent_T=jnp.asarray(transition.agent_T), + terminal_mask=jnp.asarray(transition.terminal_mask), + purchase_mask=jnp.asarray(transition.purchase_mask), + event_weights=jnp.asarray(transition.event_weights), + start_idx=int(transition.start_idx), + term_idx=int(transition.term_idx), + ) + + +def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array: + return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)]) + + +def _decode_action( + prices: jax.Array, action: jax.Array, params: EnvParams +) -> jax.Array: + idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1) + scale = params.action_scales[idx] + next_prices = prices * scale + return jnp.clip(next_prices, params.price_low, params.price_high) + + +def _evaluate_candidate( + key: jax.Array, + alpha_candidate: jax.Array, + prices: jax.Array, + params: EnvParams, +) -> CandidateEval: + states, products, actors, lengths = _sample_sessions_jax( + key, + params.human_T, + params.agent_T, + params.terminal_mask, + params.start_idx, + params.term_idx, + alpha_candidate, + params.n_products, + params.n_sessions, + params.max_session_steps, + int(params.human_T.shape[0]), + ) + session_trans = compute_session_transitions( + states, lengths, int(params.human_T.shape[0]) + ) + delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T) + agent_probs = agent_probability_from_kl(delta_h, delta_a) + agent_prob = jnp.mean(agent_probs) + + demand = weighted_demand(states, products, params.n_products, params.event_weights) + revenue = revenue_from_demand(prices, demand) + reward, leakage, discount = reward_with_coi_penalty( + revenue, + agent_prob, + params.lambda_coi, + params.info_value, + ) + purchases = purchase_flags(states, params.purchase_mask) + return CandidateEval( + reward=reward, + revenue=revenue, + demand=demand, + agent_prob=agent_prob, + leakage=leakage, + discount=discount, + n_purchases=jnp.sum(purchases.astype(jnp.float32)), + n_agents=jnp.sum(actors.astype(jnp.float32)), + ) + + +def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]: + prices = jax.random.uniform( + key, + shape=(params.n_products,), + minval=params.price_low, + maxval=params.price_high, + ) + demand = jnp.zeros((params.n_products,), dtype=jnp.float32) + state = EnvState( + prices=prices, + demand=demand, + step_count=jnp.asarray(0, dtype=jnp.int32), + low_margin_streak=jnp.asarray(0, dtype=jnp.int32), + last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32), + last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32), + ) + return _flatten_obs(demand, prices), state + + +def step_env( + key: jax.Array, + state: EnvState, + action: jax.Array, + params: EnvParams, +) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]: + prices = _decode_action(state.prices, action, params) + n_candidates = params.alpha_candidates.shape[0] + cand_keys = jax.random.split(key, n_candidates) + evals = jax.vmap( + lambda k, a: _evaluate_candidate(k, a, prices, params), + in_axes=(0, 0), + )(cand_keys, params.alpha_candidates) + idx = jnp.argmin(evals.reward) + + demand = evals.demand[idx] + reward = evals.reward[idx] + revenue = evals.revenue[idx] + agent_prob = evals.agent_prob[idx] + leakage = evals.leakage[idx] + discount = evals.discount[idx] + n_purchases = evals.n_purchases[idx] + n_agents = evals.n_agents[idx] + alpha_adv = params.alpha_candidates[idx] + + step_count = state.step_count + 1 + avg_price = jnp.maximum(jnp.mean(prices), 1e-6) + avg_margin = (avg_price - params.price_low) / avg_price + next_streak = jnp.where( + avg_margin < params.margin_floor, state.low_margin_streak + 1, 0 + ) + + margin_collapsed = next_streak >= params.margin_floor_patience + done = (step_count >= params.max_episode_steps) | margin_collapsed + + next_state = EnvState( + prices=prices, + demand=demand, + step_count=step_count, + low_margin_streak=next_streak, + last_agent_prob=agent_prob, + last_alpha_adv=alpha_adv, + ) + obs = _flatten_obs(demand, prices) + info = { + "revenue": revenue, + "agent_prob": agent_prob, + "alpha_adv": alpha_adv, + "coi_leakage": leakage, + "coi_discount": discount, + "n_purchases": n_purchases, + "n_agents": n_agents, + "avg_margin": avg_margin, + } + return obs, next_state, reward, done, info + + +class PHANTOMJAXEnv: + def __init__(self, params: EnvParams): + self.params = params + + def reset(self, key: jax.Array, params: EnvParams | None = None): + return reset_env(key, self.params if params is None else params) + + def step( + self, + key: jax.Array, + state: EnvState, + action: jax.Array, + params: EnvParams | None = None, + ): + return step_env(key, state, action, self.params if params is None else params) + + def action_space_n(self, params: EnvParams | None = None) -> int: + p = self.params if params is None else params + return int(p.action_scales.shape[0]) + + def observation_dim(self, params: EnvParams | None = None) -> int: + p = self.params if params is None else params + return int(p.n_products * 2) diff --git a/engine/jax/primitives.py b/engine/jax/primitives.py new file mode 100644 index 0000000..8de4c2b --- /dev/null +++ b/engine/jax/primitives.py @@ -0,0 +1,493 @@ +"""JAX-compatible primitives for PHANTOM session simulation and separability.""" + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +from typing import Mapping, Sequence + +import numpy as np + +try: + import jax + import jax.numpy as jnp + + JAX_AVAILABLE = True +except ImportError: + jax = None # type: ignore[assignment] + jnp = np # type: ignore[assignment] + JAX_AVAILABLE = False + + +STATE_START_KEYS = ("session_start", "start") +TERMINAL_EVENT_TOKENS = ( + "session_end", + "end", + "purchase_complete", + "checkout_start", + "checkout", +) +PURCHASE_EVENT_TOKENS = ( + "purchase_complete", + "purchase", + "checkout_start", + "checkout", +) + +CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5} +ACTION_CATEGORIES = { + "cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"}, + "dwell": { + "hover_title", + "hover_paragraph", + "hover_link", + "hover_over_title", + "hover_over_paragraph", + "hover_over_link", + "hover_over_button", + }, + "nav": { + "page_view", + "view_item", + "view", + "learn_more", + "learn_more_about_item", + "view_item_page", + "session_start", + }, + "filter": { + "search", + "filter_date", + "filter_price", + "sort", + "filter_for_date", + "filter_for_price", + "filter_for_amenities", + "sort_change", + }, +} +DEFAULT_ACTION_WEIGHTS = { + action: CATEGORY_WEIGHTS[group] + for group, actions in ACTION_CATEGORIES.items() + for action in actions +} + + +@dataclass(frozen=True) +class TransitionData: + """Dense transition kernels and per-state metadata.""" + + human_T: np.ndarray + agent_T: np.ndarray + terminal_mask: np.ndarray + purchase_mask: np.ndarray + event_weights: np.ndarray + event_names: tuple[str, ...] + start_idx: int + term_idx: int + + def to_jax(self) -> "TransitionData": + if not JAX_AVAILABLE: + return self + return TransitionData( + human_T=jnp.asarray(self.human_T), + agent_T=jnp.asarray(self.agent_T), + terminal_mask=jnp.asarray(self.terminal_mask), + purchase_mask=jnp.asarray(self.purchase_mask), + event_weights=jnp.asarray(self.event_weights), + event_names=self.event_names, + start_idx=int(self.start_idx), + term_idx=int(self.term_idx), + ) + + +@dataclass(frozen=True) +class SessionBatch: + states: np.ndarray + products: np.ndarray + actors: np.ndarray + lengths: np.ndarray + + +def _event_weight(name: str) -> float: + if name in DEFAULT_ACTION_WEIGHTS: + return float(DEFAULT_ACTION_WEIGHTS[name]) + if name.startswith("hover"): + return float(CATEGORY_WEIGHTS["dwell"]) + if name.startswith("filter") or name in {"search", "sort", "sort_change"}: + return float(CATEGORY_WEIGHTS["filter"]) + if name.startswith("add") or name in { + "checkout", + "checkout_start", + "purchase", + "remove_item", + "purchase_complete", + }: + return float(CATEGORY_WEIGHTS["cart"]) + if any(token in name for token in TERMINAL_EVENT_TOKENS): + return 0.0 + return float(CATEGORY_WEIGHTS["nav"]) + + +def _is_terminal(name: str) -> bool: + return any(token in name for token in TERMINAL_EVENT_TOKENS) + + +def _is_purchase(name: str) -> bool: + return any(token in name for token in PURCHASE_EVENT_TOKENS) + + +def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]: + names: set[str] = set() + for trans in transitions: + for src, dsts in trans.items(): + names.add(src) + names.update(dsts.keys()) + names.discard("__terminal__") + return tuple(sorted(names)) + + +def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray: + row_sums = matrix.sum(axis=1, keepdims=True) + dead_rows = np.isclose(row_sums.squeeze(-1), 0.0) + if np.any(dead_rows): + matrix[dead_rows] = 0.0 + matrix[dead_rows, term_idx] = 1.0 + row_sums = matrix.sum(axis=1, keepdims=True) + return matrix / np.maximum(row_sums, 1e-8) + + +def _dense_from_dict( + transitions: Mapping[str, Mapping[str, float]], + event_to_idx: Mapping[str, int], + term_idx: int, +) -> np.ndarray: + n_states = len(event_to_idx) + matrix = np.zeros((n_states, n_states), dtype=np.float32) + for src, dsts in transitions.items(): + i = event_to_idx.get(src) + if i is None: + continue + for dst, prob in dsts.items(): + j = event_to_idx.get(dst) + if j is None: + continue + matrix[i, j] += float(prob) + return _normalize_rows(matrix, term_idx) + + +def compile_transition_data( + human_transitions: Mapping[str, Mapping[str, float]], + agent_transitions: Mapping[str, Mapping[str, float]], +) -> TransitionData: + event_names = _collect_events(human_transitions, agent_transitions) + if not event_names: + return fallback_transition_data() + + event_names = tuple([*event_names, "__terminal__"]) + term_idx = len(event_names) - 1 + event_to_idx = {name: i for i, name in enumerate(event_names)} + + human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx) + agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx) + + terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool) + purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool) + event_weights = np.array( + [_event_weight(name) for name in event_names], dtype=np.float32 + ) + + terminal_mask[term_idx] = True + + for idx, is_term in enumerate(terminal_mask): + if not is_term: + continue + human_T[idx] = 0.0 + agent_T[idx] = 0.0 + human_T[idx, idx] = 1.0 + agent_T[idx, idx] = 1.0 + + start_idx = 0 + for key in STATE_START_KEYS: + if key in event_to_idx: + start_idx = int(event_to_idx[key]) + break + + return TransitionData( + human_T=human_T, + agent_T=agent_T, + terminal_mask=terminal_mask, + purchase_mask=purchase_mask, + event_weights=event_weights, + event_names=event_names, + start_idx=start_idx, + term_idx=term_idx, + ) + + +def fallback_transition_data() -> TransitionData: + human = { + "session_start": { + "page_view": 0.80, + "view_item_page": 0.15, + "session_end": 0.05, + }, + "page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20}, + "view_item_page": { + "learn_more_about_item": 0.40, + "add_item_to_cart": 0.28, + "session_end": 0.32, + }, + "learn_more_about_item": { + "add_item_to_cart": 0.50, + "view_item_page": 0.30, + "session_end": 0.20, + }, + "add_item_to_cart": { + "checkout_start": 0.58, + "view_item_page": 0.24, + "session_end": 0.18, + }, + "checkout_start": {"purchase_complete": 0.70, "session_end": 0.30}, + "purchase_complete": {"session_end": 1.0}, + } + agent = { + "session_start": { + "page_view": 0.90, + "view_item_page": 0.08, + "session_end": 0.02, + }, + "page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25}, + "view_item_page": { + "learn_more_about_item": 0.55, + "add_item_to_cart": 0.15, + "session_end": 0.30, + }, + "learn_more_about_item": { + "view_item_page": 0.45, + "add_item_to_cart": 0.20, + "session_end": 0.35, + }, + "add_item_to_cart": { + "checkout_start": 0.42, + "view_item_page": 0.28, + "session_end": 0.30, + }, + "checkout_start": {"purchase_complete": 0.52, "session_end": 0.48}, + "purchase_complete": {"session_end": 1.0}, + } + return compile_transition_data(human, agent) + + +def load_transition_data(prefer_data: bool = True) -> TransitionData: + if not prefer_data: + return fallback_transition_data() + try: + from ..lib.behavior import get_transition_models + + human_trans, agent_trans = get_transition_models() + return compile_transition_data(human_trans, agent_trans) + except Exception: + return fallback_transition_data() + + +if JAX_AVAILABLE: + + @partial(jax.jit, static_argnums=(8, 9, 10)) + def _sample_sessions_jax( + key: jax.Array, + human_T: jax.Array, + agent_T: jax.Array, + terminal_mask: jax.Array, + start_idx: int, + term_idx: int, + alpha: float, + n_products: int, + n_sessions: int, + max_steps: int, + n_states: int, + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + k_actor, k_product, k_step = jax.random.split(key, 3) + actor_draw = jax.random.uniform(k_actor, (n_sessions,)) + actors = (actor_draw < alpha).astype(jnp.int32) + products = jax.random.randint( + k_product, (n_sessions,), 0, n_products, dtype=jnp.int32 + ) + + active_init = jnp.ones((n_sessions,), dtype=jnp.bool_) + state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32) + + def _scan_step(carry, _): + states, active, rng = carry + rng, k = jax.random.split(rng) + probs_h = human_T[states] + probs_a = agent_T[states] + probs = jnp.where(actors[:, None] == 0, probs_h, probs_a) + next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1) + next_state = jnp.where(active, next_state, int(term_idx)) + emitted = jnp.where(active, next_state, -1) + is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)] + next_active = active & (~is_terminal) + carry_states = jnp.where(next_active, next_state, int(term_idx)) + return (carry_states, next_active, rng), emitted + + _, state_t = jax.lax.scan( + _scan_step, (state_init, active_init, k_step), None, length=max_steps + ) + states = state_t.T + lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32) + return states, products, actors, lengths + + +def sample_sessions( + key, + transition_data: TransitionData, + alpha: float, + n_products: int, + n_sessions: int, + max_steps: int, +) -> SessionBatch: + if JAX_AVAILABLE: + td = transition_data.to_jax() + states, products, actors, lengths = _sample_sessions_jax( + key, + td.human_T, + td.agent_T, + td.terminal_mask, + int(td.start_idx), + int(td.term_idx), + float(alpha), + int(n_products), + int(n_sessions), + int(max_steps), + int(td.human_T.shape[0]), + ) + return SessionBatch( + states=states, products=products, actors=actors, lengths=lengths + ) + + rng = np.random.default_rng(int(np.asarray(key).reshape(-1)[0])) + n_states = transition_data.human_T.shape[0] + products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32) + actors = (rng.random(size=n_sessions) < alpha).astype(np.int32) + states = np.full((n_sessions, max_steps), -1, dtype=np.int32) + lengths = np.zeros((n_sessions,), dtype=np.int32) + for i in range(n_sessions): + current = int(transition_data.start_idx) + mat = transition_data.agent_T if actors[i] == 1 else transition_data.human_T + for t in range(max_steps): + nxt = int(rng.choice(n_states, p=mat[current])) + states[i, t] = nxt + if transition_data.terminal_mask[nxt]: + lengths[i] = t + 1 + break + current = nxt + if lengths[i] == 0: + lengths[i] = max_steps + return SessionBatch( + states=states, products=products, actors=actors, lengths=lengths + ) + + +if JAX_AVAILABLE: + + @partial(jax.jit, static_argnums=(2,)) + def compute_session_transitions(states, lengths, n_states: int): + src = states[:, :-1] + dst = states[:, 1:] + time_idx = jnp.arange(src.shape[1])[None, :] + valid = (src >= 0) & (dst >= 0) & (time_idx < (lengths[:, None] - 1)) + src_clip = jnp.clip(src, 0, n_states - 1) + dst_clip = jnp.clip(dst, 0, n_states - 1) + src_oh = jax.nn.one_hot(src_clip, n_states) + dst_oh = jax.nn.one_hot(dst_clip, n_states) + counts = jnp.einsum( + "nti,ntj,nt->nij", src_oh, dst_oh, valid.astype(jnp.float32) + ) + row_sums = jnp.sum(counts, axis=-1, keepdims=True) + return counts / (row_sums + 1e-10) + + +else: + + def compute_session_transitions(states, lengths, n_states: int): + trans = np.zeros((states.shape[0], n_states, n_states), dtype=np.float32) + for i in range(states.shape[0]): + for t in range(max(int(lengths[i]) - 1, 0)): + s = int(states[i, t]) + d = int(states[i, t + 1]) + if s >= 0 and d >= 0: + trans[i, s, d] += 1.0 + row_sums = trans.sum(axis=-1, keepdims=True) + return trans / (row_sums + 1e-10) + + +def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10): + p = P + eps + p = p / jnp.sum(p, axis=-1, keepdims=True) + qh = Q_human[None, ...] + eps + qa = Q_agent[None, ...] + eps + delta_h = jnp.sum(p * jnp.log(p / qh), axis=(1, 2)) + delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2)) + return delta_h, delta_a + + +if JAX_AVAILABLE: + batch_kl = jax.jit(batch_kl) + + +def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0): + t = jnp.maximum(float(temperature), 1e-6) + exp_h = jnp.exp(-delta_h / t) + exp_a = jnp.exp(-delta_a / t) + return exp_a / (exp_h + exp_a + 1e-10) + + +def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0): + logits = beta * (delta_h - delta_a) + return 1.0 / (1.0 + jnp.exp(-logits)) + + +def weighted_demand(states, products, n_products: int, event_weights): + valid = states >= 0 + state_clip = jnp.clip(states, 0, event_weights.shape[0] - 1) + weights = event_weights[state_clip] * valid + per_session = jnp.sum(weights, axis=1) + demand = jnp.zeros((n_products,), dtype=jnp.float32) + demand = demand.at[products].add(per_session) + total = jnp.sum(demand) + return jnp.where(total > 0.0, (demand / total) * 100.0, demand) + + +if JAX_AVAILABLE: + weighted_demand = jax.jit(weighted_demand, static_argnums=(2,)) + + +def purchase_flags(states, purchase_mask): + state_clip = jnp.clip(states, 0, purchase_mask.shape[0] - 1) + hits = purchase_mask[state_clip] & (states >= 0) + return jnp.any(hits, axis=1) + + +if JAX_AVAILABLE: + purchase_flags = jax.jit(purchase_flags) + + +def revenue_from_demand(prices, demand): + return jnp.dot(prices, demand) + + +if JAX_AVAILABLE: + revenue_from_demand = jax.jit(revenue_from_demand) + + +def reward_with_coi_penalty( + revenue, agent_prob: float, lambda_coi: float, info_value: float +): + leakage = agent_prob * info_value + discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0) + return revenue * discount, leakage, discount + + +if JAX_AVAILABLE: + reward_with_coi_penalty = jax.jit(reward_with_coi_penalty) diff --git a/engine/jax/requirements.txt b/engine/jax/requirements.txt new file mode 100644 index 0000000..42ba457 --- /dev/null +++ b/engine/jax/requirements.txt @@ -0,0 +1,5 @@ +flax>=0.8.0 +optax>=0.2.0 +distrax>=0.1.5 +orbax-checkpoint>=0.5.0 +chex>=0.1.8 diff --git a/engine/jax/train.py b/engine/jax/train.py new file mode 100644 index 0000000..f2f4168 --- /dev/null +++ b/engine/jax/train.py @@ -0,0 +1,471 @@ +"""Pure JAX PPO trainer for the PHANTOM environment.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, NamedTuple + +import numpy as np + +try: + import jax + import jax.numpy as jnp + import distrax + import flax.linen as nn + import optax + from flax import serialization + from flax.linen.initializers import constant, orthogonal + from flax.training.train_state import TrainState + + HAS_JAX_STACK = True +except ImportError: + jax = None # type: ignore[assignment] + jnp = None # type: ignore[assignment] + distrax = None # type: ignore[assignment] + optax = None # type: ignore[assignment] + serialization = None # type: ignore[assignment] + + class _ModuleStub: + pass + + class _NNStub: + Module = _ModuleStub + + @staticmethod + def compact(fn): + return fn + + nn = _NNStub() # type: ignore[assignment] + + def constant(*_args, **_kwargs): # type: ignore[override] + return None + + def orthogonal(*_args, **_kwargs): # type: ignore[override] + return None + + class TrainState: # type: ignore[override] + pass + + HAS_JAX_STACK = False + +from .env import PHANTOMJAXEnv, make_env_params + + +class ActorCritic(nn.Module): + action_dim: int + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + activation_fn = nn.relu if self.activation == "relu" else nn.tanh + + actor = nn.Dense( + 64, + kernel_init=orthogonal(np.sqrt(2.0)), + bias_init=constant(0.0), + )(x) + actor = activation_fn(actor) + actor = nn.Dense( + 64, + kernel_init=orthogonal(np.sqrt(2.0)), + bias_init=constant(0.0), + )(actor) + actor = activation_fn(actor) + logits = nn.Dense( + self.action_dim, + kernel_init=orthogonal(0.01), + bias_init=constant(0.0), + )(actor) + + critic = nn.Dense( + 64, + kernel_init=orthogonal(np.sqrt(2.0)), + bias_init=constant(0.0), + )(x) + critic = activation_fn(critic) + critic = nn.Dense( + 64, + kernel_init=orthogonal(np.sqrt(2.0)), + bias_init=constant(0.0), + )(critic) + critic = activation_fn(critic) + value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1) + + +class Transition(NamedTuple): + done: jax.Array + action: jax.Array + value: jax.Array + reward: jax.Array + log_prob: jax.Array + obs: jax.Array + info: dict[str, jax.Array] + + +def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]: + out = { + "algo": str(cfg.get("algo", "ppo")).lower(), + "seed": int(cfg.get("seed", 42)), + "learning_rate": float(cfg.get("learning_rate", 3e-4)), + "gamma": float(cfg.get("gamma", 0.99)), + "gae_lambda": float(cfg.get("gae_lambda", 0.95)), + "clip_range": float(cfg.get("clip_range", 0.2)), + "ent_coef": float(cfg.get("ent_coef", 0.01)), + "vf_coef": float(cfg.get("vf_coef", 0.5)), + "max_grad_norm": float(cfg.get("max_grad_norm", 0.5)), + "activation": str(cfg.get("activation", "relu")), + "total_timesteps": int(cfg.get("total_timesteps", 50_000)), + "eval_episodes": int(cfg.get("eval_episodes", 5)), + "model_dir": str(cfg.get("model_dir", "engine/models")), + "n_products": int(cfg.get("n_products", 10)), + "N": int(cfg.get("N", 100)), + "alpha": float(cfg.get("alpha", 0.3)), + "lambda_coi": float(cfg.get("lambda_coi", 0.2)), + "robust_radius": float(cfg.get("robust_radius", 0.15)), + "robust_points": int(cfg.get("robust_points", 5)), + "info_value": float(cfg.get("info_value", 1.0)), + "price_low": float(cfg.get("price_low", 10.0)), + "price_high": float(cfg.get("price_high", 150.0)), + "action_levels": int(cfg.get("action_levels", 9)), + "action_scale_low": float(cfg.get("action_scale_low", 0.8)), + "action_scale_high": float(cfg.get("action_scale_high", 1.2)), + "max_episode_steps": int(cfg.get("max_steps", 100)), + "max_session_steps": int(cfg.get("max_session_steps", 40)), + "margin_floor": float(cfg.get("margin_floor", 0.05)), + "margin_floor_patience": int(cfg.get("margin_floor_patience", 5)), + "prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)), + "num_envs": int(cfg.get("jax_num_envs", 16)), + "num_steps": int(cfg.get("jax_num_steps", 128)), + "num_minibatches": int(cfg.get("jax_num_minibatches", 4)), + "update_epochs": int(cfg.get("jax_update_epochs", 4)), + "anneal_lr": bool(cfg.get("jax_anneal_lr", True)), + } + rollout = out["num_envs"] * out["num_steps"] + out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1)) + out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1)) + return out + + +def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array: + mask = done + while mask.ndim < keep.ndim: + mask = mask[..., None] + return jnp.where(mask, reset, keep) + + +def make_train(config: dict[str, Any]): + cfg = _jax_cfg(config) + env_params = make_env_params( + n_products=cfg["n_products"], + alpha=cfg["alpha"], + n_sessions=cfg["N"], + lambda_coi=cfg["lambda_coi"], + robust_radius=cfg["robust_radius"], + robust_points=cfg["robust_points"], + info_value=cfg["info_value"], + action_levels=cfg["action_levels"], + action_scale_low=cfg["action_scale_low"], + action_scale_high=cfg["action_scale_high"], + price_low=cfg["price_low"], + price_high=cfg["price_high"], + max_episode_steps=cfg["max_episode_steps"], + max_session_steps=cfg["max_session_steps"], + margin_floor=cfg["margin_floor"], + margin_floor_patience=cfg["margin_floor_patience"], + prefer_behavior_data=cfg["prefer_behavior_data"], + ) + env = PHANTOMJAXEnv(env_params) + network = ActorCritic(env.action_space_n(), activation=cfg["activation"]) + + def linear_schedule(count: jax.Array) -> jax.Array: + updates_done = count // (cfg["num_minibatches"] * cfg["update_epochs"]) + frac = 1.0 - updates_done / max(cfg["num_updates"], 1) + return cfg["learning_rate"] * frac + + def train(rng: jax.Array): + rng, init_key = jax.random.split(rng) + init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32) + params = network.init(init_key, init_obs) + + if cfg["anneal_lr"]: + tx = optax.chain( + optax.clip_by_global_norm(cfg["max_grad_norm"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(cfg["max_grad_norm"]), + optax.adam(cfg["learning_rate"], eps=1e-5), + ) + train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx) + + rng, reset_key = jax.random.split(rng) + reset_keys = jax.random.split(reset_key, cfg["num_envs"]) + obs, env_state = jax.vmap(env.reset)(reset_keys) + + def _update_step(runner_state, _): + def _env_step(runner_state, _): + train_state, env_state, last_obs, rng = runner_state + rng, action_key = jax.random.split(rng) + policy, value = network.apply(train_state.params, last_obs) + action = policy.sample(seed=action_key) + log_prob = policy.log_prob(action) + + rng, step_key = jax.random.split(rng) + step_keys = jax.random.split(step_key, cfg["num_envs"]) + nxt_obs, nxt_state, reward, done, info = jax.vmap( + env.step, + in_axes=(0, 0, 0), + )(step_keys, env_state, action) + + rng, reset_key = jax.random.split(rng) + reset_keys = jax.random.split(reset_key, cfg["num_envs"]) + rst_obs, rst_state = jax.vmap(env.reset)(reset_keys) + obs_next = jnp.where(done[:, None], rst_obs, nxt_obs) + env_next = jax.tree_util.tree_map( + lambda keep, reset: _select_env_state(done, keep, reset), + nxt_state, + rst_state, + ) + transition = Transition( + done=done, + action=action, + value=value, + reward=reward, + log_prob=log_prob, + obs=last_obs, + info=info, + ) + return (train_state, env_next, obs_next, rng), transition + + runner_state, traj_batch = jax.lax.scan( + _env_step, + runner_state, + None, + length=cfg["num_steps"], + ) + + train_state, env_state, last_obs, rng = runner_state + _, last_value = network.apply(train_state.params, last_obs) + + def _compute_gae(traj_batch, last_value): + def _gae_step(carry, transition): + gae, next_value = carry + delta = ( + transition.reward + + cfg["gamma"] * next_value * (1.0 - transition.done) + - transition.value + ) + gae = ( + delta + + cfg["gamma"] + * cfg["gae_lambda"] + * (1.0 - transition.done) + * gae + ) + return (gae, transition.value), gae + + _, advantages = jax.lax.scan( + _gae_step, + (jnp.zeros_like(last_value), last_value), + traj_batch, + reverse=True, + unroll=16, + ) + targets = advantages + traj_batch.value + return advantages, targets + + advantages, targets = _compute_gae(traj_batch, last_value) + + def _update_epoch(update_state, _): + def _update_minibatch(train_state, batch_info): + traj_b, adv_b, tgt_b = batch_info + + def _loss_fn(params, traj_b, adv_b, tgt_b): + policy, value = network.apply(params, traj_b.obs) + log_prob = policy.log_prob(traj_b.action) + + value_clipped = traj_b.value + (value - traj_b.value).clip( + -cfg["clip_range"], cfg["clip_range"] + ) + value_loss = ( + 0.5 + * jnp.maximum( + jnp.square(value - tgt_b), + jnp.square(value_clipped - tgt_b), + ).mean() + ) + + adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8) + ratio = jnp.exp(log_prob - traj_b.log_prob) + loss_actor = -jnp.minimum( + ratio * adv_norm, + jnp.clip( + ratio, + 1.0 - cfg["clip_range"], + 1.0 + cfg["clip_range"], + ) + * adv_norm, + ).mean() + entropy = policy.entropy().mean() + total_loss = ( + loss_actor + + cfg["vf_coef"] * value_loss + - cfg["ent_coef"] * entropy + ) + return total_loss, (value_loss, loss_actor, entropy) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b) + train_state = train_state.apply_gradients(grads=grads) + return train_state, jnp.asarray(0.0, dtype=jnp.float32) + + train_state, traj_batch, advantages, targets, rng = update_state + rng, perm_key = jax.random.split(rng) + batch_size = cfg["num_envs"] * cfg["num_steps"] + permutation = jax.random.permutation(perm_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), + batch, + ) + shuffled = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), + batch, + ) + minibatches = jax.tree_util.tree_map( + lambda x: x.reshape( + (cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:] + ), + shuffled, + ) + train_state, _ = jax.lax.scan( + _update_minibatch, train_state, minibatches + ) + return (train_state, traj_batch, advantages, targets, rng), None + + update_state = (train_state, traj_batch, advantages, targets, rng) + update_state, _ = jax.lax.scan( + _update_epoch, + update_state, + None, + length=cfg["update_epochs"], + ) + train_state = update_state[0] + rng = update_state[-1] + + metric = { + "reward": jnp.mean(traj_batch.reward), + "revenue": jnp.mean(traj_batch.info["revenue"]), + "agent_prob": jnp.mean(traj_batch.info["agent_prob"]), + "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]), + "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]), + } + runner_state = (train_state, env_state, last_obs, rng) + return runner_state, metric + + runner_state = (train_state, env_state, obs, rng) + runner_state, metric = jax.lax.scan( + _update_step, + runner_state, + None, + length=cfg["num_updates"], + ) + return { + "runner_state": runner_state, + "metrics": metric, + } + + return train, network, env, cfg + + +def evaluate_policy( + *, + network: ActorCritic, + params: Any, + env: PHANTOMJAXEnv, + episodes: int, + seed: int, +) -> dict[str, float]: + rewards: list[float] = [] + revenues: list[float] = [] + key = jax.random.PRNGKey(seed) + + for _ in range(int(episodes)): + key, reset_key = jax.random.split(key) + obs, state = env.reset(reset_key) + ep_reward = 0.0 + ep_revenue = 0.0 + done = False + steps = 0 + + while not done and steps < int(env.params.max_episode_steps): + policy, _ = network.apply(params, obs) + action = jnp.argmax(policy.logits) + key, step_key = jax.random.split(key) + obs, state, reward, done_flag, info = env.step(step_key, state, action) + ep_reward += float(np.asarray(reward)) + ep_revenue += float(np.asarray(info["revenue"])) + done = bool(np.asarray(done_flag)) + steps += 1 + + rewards.append(ep_reward) + revenues.append(ep_revenue) + + return { + "eval/reward": float(np.mean(rewards)), + "eval/revenue": float(np.mean(revenues)), + "eval/reward_std": float(np.std(rewards)), + "eval/revenue_std": float(np.std(revenues)), + } + + +def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]: + if not HAS_JAX_STACK: + raise ImportError( + "JAX PPO path requires jax, flax, optax, and distrax. " + "Install engine/jax/requirements.txt on this machine first." + ) + + run_cfg = _jax_cfg(cfg) + if run_cfg["algo"] != "ppo": + raise ValueError( + f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'" + ) + + train_fn, network, env, run_cfg = make_train(run_cfg) + train_jit = jax.jit(train_fn) + rng = jax.random.PRNGKey(run_cfg["seed"]) + out = train_jit(rng) + + train_state = out["runner_state"][0] + metric = out["metrics"] + metrics = { + "train/reward": float(np.mean(np.asarray(metric["reward"]))), + "train/revenue": float(np.mean(np.asarray(metric["revenue"]))), + "train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))), + "train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))), + "train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))), + "train/global_step": int( + run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"] + ), + } + + eval_metrics = evaluate_policy( + network=network, + params=train_state.params, + env=env, + episodes=run_cfg["eval_episodes"], + seed=run_cfg["seed"] + 7, + ) + metrics.update(eval_metrics) + + model_dir = Path(run_cfg["model_dir"]) + model_dir.mkdir(parents=True, exist_ok=True) + model_path = model_dir / "phantom_ppo_jax.msgpack" + model_path.write_bytes(serialization.to_bytes(train_state.params)) + metrics["model/path"] = str(model_path) + return {"params": train_state.params}, metrics diff --git a/engine/lib/callbacks.py b/engine/lib/callbacks.py new file mode 100644 index 0000000..9e16d4b --- /dev/null +++ b/engine/lib/callbacks.py @@ -0,0 +1,119 @@ +"""Training callbacks for W&B/TensorBoard logging - reads from info dict.""" + +from stable_baselines3.common.callbacks import BaseCallback, EvalCallback +import numpy as np + +try: + import wandb + + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + + +class MetricsCallback(BaseCallback): + """Training metrics logger - reads info['economics'], logs to W&B.""" + + def __init__( + self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0 + ): + super().__init__(verbose) + self.log_histograms = log_histograms + self.log_freq = log_freq + self._episode_revenues: list[float] = [] + + def _on_step(self) -> bool: + if not HAS_WANDB or wandb.run is None: + return True + + for info in self.locals.get("infos", []): + if "economics" not in info: + continue + + econ = info["economics"] + t = self.num_timesteps + + payload = { + "economics/revenue": econ["revenue"], + "economics/margin": econ["margin"], + "coi/level": econ["coi_level"], + "economics/regret": econ["regret"], + } + if "coi_mix" in econ: + payload["coi/mix"] = econ["coi_mix"] + if "coi_base" in econ: + payload["coi/base"] = econ["coi_base"] + if "coi_leakage" in econ: + payload["coi/leakage"] = econ["coi_leakage"] + if "coi_penalty" in econ: + payload["coi/penalty"] = econ["coi_penalty"] + wandb.log(payload, step=t) + + self._episode_revenues.append(econ["revenue"]) + + # histograms at log_freq intervals + if self.log_histograms and self.num_timesteps % self.log_freq == 0: + for info in self.locals.get("infos", []): + if "prices" in info: + wandb.log( + {"distributions/prices": wandb.Histogram(info["prices"])}, + step=self.num_timesteps, + ) + if "demand" in info: + wandb.log( + {"distributions/demand": wandb.Histogram(info["demand"])}, + step=self.num_timesteps, + ) + + return True + + def _on_rollout_end(self) -> None: + if not HAS_WANDB or wandb.run is None or not self._episode_revenues: + return + wandb.log( + { + "episode/mean_revenue": np.mean(self._episode_revenues), + "episode/total_revenue": np.sum(self._episode_revenues), + }, + step=self.num_timesteps, + ) + self._episode_revenues = [] + + +class EvalMetricsCallback(EvalCallback): + """Deterministic evaluation - true performance without exploration noise.""" + + def __init__( + self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs + ): + super().__init__( + eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs + ) + self._eval_revenues: list[float] = [] + + def _on_step(self) -> bool: + result = super()._on_step() + + if not HAS_WANDB or wandb.run is None: + return result + + # log eval metrics after evaluation runs + if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"): + wandb.log( + { + "eval/mean_reward": self.last_mean_reward, + "eval/mean_revenue": np.mean(self._eval_revenues) + if self._eval_revenues + else 0, + }, + step=self.num_timesteps, + ) + self._eval_revenues = [] + + return result + + def _log_success_callback(self, locals_: dict, globals_: dict) -> None: + # called after each eval episode + info = locals_.get("info", {}) + if "economics" in info: + self._eval_revenues.append(info["economics"]["revenue"]) diff --git a/engine/lib/coi.py b/engine/lib/coi.py new file mode 100644 index 0000000..33267b5 --- /dev/null +++ b/engine/lib/coi.py @@ -0,0 +1,76 @@ +import numpy as np +from typing import Dict + + +def compute_agent_probability( + trajectory: list, human_transitions: Dict, agent_transitions: Dict +) -> float: + """estimate agent probability via KL divergence between trajectory transitions and reference models + + compares empirical trajectory transition distribution to human/agent prototypes + + args: + trajectory: list of state/event strings from session + human_transitions: reference transition dict from human MDP (event->event->prob) + agent_transitions: reference transition dict from agent MDP (event->event->prob) + + returns: + agent probability in [0, 1] via softmax over KL divergences + """ + if len(trajectory) < 2: + return 0.0 # insufficient data, assume human + + # build empirical transition distribution from trajectory + trans_counts = {} + for s, s_next in zip(trajectory[:-1], trajectory[1:]): + if s not in trans_counts: + trans_counts[s] = {} + trans_counts[s][s_next] = trans_counts[s].get(s_next, 0) + 1 + + # normalize to probabilities + empirical = {} + for s, nxt in trans_counts.items(): + total = sum(nxt.values()) + empirical[s] = {s_n: cnt / total for s_n, cnt in nxt.items()} + + # compute KL divergence to each prototype + def kl_div(p_dist: Dict, q_dist: Dict) -> float: + eps = 1e-10 + # aggregate over all source states in empirical dist + kl = 0.0 + for s in p_dist: + if s not in q_dist: + continue # skip states not in reference + p_trans, q_trans = p_dist[s], q_dist[s] + for k in p_trans: + p_val = p_trans[k] + eps + q_val = q_trans.get(k, 0.0) + eps + kl += p_val * np.log(p_val / q_val) + return kl + + kl_human = kl_div(empirical, human_transitions) + kl_agent = kl_div(empirical, agent_transitions) + + # convert to probability via softmax (lower KL = higher prob) + # agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent)) + exp_h = np.exp(-kl_human) + exp_a = np.exp(-kl_agent) + return float(exp_a / (exp_h + exp_a + 1e-10)) + + +def extract_purchases(trajectories: list) -> Dict[int, int]: + purchases: Dict[int, int] = {} + for traj in trajectories: + if traj and "checkout" in traj[-1] and "_product" in traj[-1]: + prod_id = int(traj[-1].rsplit("_product", 1)[1]) + purchases[prod_id] = purchases.get(prod_id, 0) + 1 + return purchases + + +def compute_uplift_coi( + prices: np.ndarray, purchases: Dict[int, int], baseline_prices: np.ndarray +) -> float: + # TODO: consider view-weighted fractional purchase for denser signal + return float( + sum(max(0.0, prices[k] - baseline_prices[k]) * n for k, n in purchases.items()) + ) diff --git a/engine/lib/discrete.py b/engine/lib/discrete.py new file mode 100644 index 0000000..9cee3ad --- /dev/null +++ b/engine/lib/discrete.py @@ -0,0 +1,70 @@ +from collections import defaultdict +import gymnasium as gym +from gymnasium import spaces +import numpy as np + + +class DiscretePriceActionWrapper(gym.ActionWrapper): + def __init__( + self, + env: gym.Env, + n_levels: int = 9, + min_scale: float = 0.8, + max_scale: float = 1.2, + ): + super().__init__(env) + self.scales = np.linspace(min_scale, max_scale, n_levels, dtype=np.float32) + self.action_space = spaces.Discrete(n_levels) + + def action(self, action: int): + scale = float(self.scales[int(action)]) + cur = np.asarray(self.env.unwrapped._prices, dtype=np.float32) + lo, hi = self.env.unwrapped.price_bounds + return np.clip(cur * scale, lo, hi).astype(np.float32) + + +class EventQTable: + def __init__( + self, + n_actions: int, + n_products: int, + price_bounds: tuple, + lr: float = 0.1, + gamma: float = 0.99, + n_bins: int = 6, + ): + self.n_actions = int(n_actions) + self.n_products = int(n_products) + self.lr = float(lr) + self.gamma = float(gamma) + self.q = defaultdict(lambda: np.zeros(self.n_actions, dtype=np.float32)) + lo, hi = price_bounds + self.demand_bins = np.linspace(0.0, 100.0, n_bins + 1)[1:-1] + self.price_bins = np.linspace(lo, hi, n_bins + 1)[1:-1] + + def encode(self, obs: np.ndarray) -> tuple: + obs = np.asarray(obs, dtype=np.float32) + d = obs[: self.n_products] + p = obs[self.n_products : 2 * self.n_products] + d_mean = float(np.mean(d)) if d.size else 0.0 + d_std = float(np.std(d)) if d.size else 0.0 + p_mean = float(np.mean(p)) if p.size else 0.0 + return ( + int(np.digitize(d_mean, self.demand_bins)), + int(np.digitize(d_std, self.demand_bins)), + int(np.digitize(p_mean, self.price_bins)), + ) + + def act(self, obs: np.ndarray, eps: float = 0.0) -> tuple[int, tuple]: + s = self.encode(obs) + if np.random.random() < eps: + return int(np.random.randint(self.n_actions)), s + return int(np.argmax(self.q[s])), s + + def update(self, s: tuple, a: int, r: float, s2: tuple, done: bool): + target = r + (0.0 if done else self.gamma * float(np.max(self.q[s2]))) + self.q[s][a] += self.lr * (target - self.q[s][a]) + + def predict(self, obs: np.ndarray, deterministic: bool = True): + a, _ = self.act(obs, 0.0 if deterministic else 0.05) + return a, None diff --git a/engine/lib/providers.py b/engine/lib/providers.py new file mode 100644 index 0000000..19d2788 --- /dev/null +++ b/engine/lib/providers.py @@ -0,0 +1,182 @@ +"""Provider benchmarking - compare pricing strategies across contamination levels.""" + +from dataclasses import dataclass, field +from typing import Callable, Any +import numpy as np +import pandas as pd + +try: + import wandb + + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + + +class RandomBaseline: + """uniform random action selection as a lower-bound baseline""" + + def __init__(self, n_actions: int): + self.n = n_actions + + def __call__(self, obs): + return int(np.random.randint(self.n)) + + def predict(self, obs, **kw): + return self(obs), None + + +class SurgeBaseline: + """heuristic surge pricing: boost price when demand is above threshold, discount when below. + matches the naive pricing rule from thesis Section 3.3.2""" + + def __init__( + self, n_actions: int, high_threshold: float = 60.0, low_threshold: float = 30.0 + ): + self.n = n_actions + self.mid = n_actions // 2 # identity action (scale ~1.0) + self.high_t = high_threshold + self.low_t = low_threshold + + def __call__(self, obs): + obs = np.asarray(obs, dtype=np.float32) + n_prod = len(obs) // 2 + demand_mean = float(np.mean(obs[:n_prod])) if n_prod > 0 else 0.0 + if demand_mean >= self.high_t: + return min(self.mid + 2, self.n - 1) # surge: two levels above identity + if demand_mean <= self.low_t: + return max(self.mid - 2, 0) # discount: two levels below identity + return self.mid # hold + + def predict(self, obs, **kw): + return self(obs), None + + +@dataclass +class ProviderResult: + """Single benchmark result for one provider at one alpha level.""" + + name: str + alpha: float + total_revenue: float + mean_revenue: float + coi_level: float + coi_preserved_pct: float # vs alpha=0 baseline + margin_integrity: float + regret: float + episodes: int + + +@dataclass +class BenchmarkConfig: + """Configuration for provider benchmark runs.""" + + n_episodes: int = 100 + alpha_range: list[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5]) + baseline_name: str = "fixed" + + +class ProviderBenchmark: + """Compare pricing providers to prove margin preservation across contamination levels. + + Usage: + def env_factory(alpha): + return EconomicMetricsWrapper(PHANTOM(alpha=alpha)) + + providers = { + "fixed": lambda obs: np.ones(10) * 50, + "learned": model.predict, + } + + benchmark = ProviderBenchmark(env_factory, providers) + results = benchmark.run() + print(benchmark.summary_table()) + """ + + def __init__( + self, + env_factory: Callable[[float], Any], + providers: dict[str, Callable], + config: BenchmarkConfig | None = None, + ): + self.env_factory = env_factory # fn(alpha) -> wrapped env + self.providers = providers # {name: fn(obs) -> action} + self.config = config or BenchmarkConfig() + self.results: list[ProviderResult] = [] + + def run(self) -> list[ProviderResult]: + """Run benchmark across all providers and alpha levels.""" + baseline_coi: dict[str, float] = {} # {provider: coi at alpha=0} + + for alpha in self.config.alpha_range: + env = self.env_factory(alpha) + + for name, policy_fn in self.providers.items(): + revenues, coi_levels, margins = [], [], [] + + for _ in range(self.config.n_episodes): + obs, _ = env.reset() + episode_revenue = 0.0 + done = False + + while not done: + action = policy_fn(obs) + # handle sb3 model.predict returning tuple + if isinstance(action, tuple): + action = action[0] + obs, reward, term, trunc, info = env.step(action) + done = term or trunc + + econ = info.get("economics", {}) + episode_revenue += econ.get("revenue", 0) + coi_levels.append(econ.get("coi_level", 0)) + margins.append(econ.get("margin", 0)) + + revenues.append(episode_revenue) + + mean_coi = np.mean(coi_levels) if coi_levels else 0.0 + if alpha == 0.0: + baseline_coi[name] = mean_coi + + base = baseline_coi.get(name, mean_coi) + coi_preserved = mean_coi / base if base > 0 else 1.0 + + result = ProviderResult( + name=name, + alpha=alpha, + total_revenue=float(np.sum(revenues)), + mean_revenue=float(np.mean(revenues)), + coi_level=mean_coi, + coi_preserved_pct=coi_preserved * 100, + margin_integrity=float(np.mean(margins)) if margins else 0.0, + regret=0.0, # compute vs optimal if known + episodes=self.config.n_episodes, + ) + self.results.append(result) + + # log to wandb if available + if HAS_WANDB and wandb.run is not None: + wandb.log( + { + f"benchmark/{name}/revenue": result.mean_revenue, + f"benchmark/{name}/coi_preserved": result.coi_preserved_pct, + f"benchmark/{name}/margin": result.margin_integrity, + "benchmark/alpha": alpha, + } + ) + + return self.results + + def to_dataframe(self) -> pd.DataFrame: + """Convert results to pandas DataFrame.""" + return pd.DataFrame([r.__dict__ for r in self.results]) + + def summary_table(self) -> pd.DataFrame: + """Pivot table: providers x alpha with revenue/COI metrics.""" + df = self.to_dataframe() + return df.pivot_table( + index="name", + columns="alpha", + values=["mean_revenue", "coi_preserved_pct", "margin_integrity"], + aggfunc="mean", + ) diff --git a/engine/lib/wrappers.py b/engine/lib/wrappers.py new file mode 100644 index 0000000..3d74b79 --- /dev/null +++ b/engine/lib/wrappers.py @@ -0,0 +1,77 @@ +"""Economic metrics wrapper - calculates thesis-aligned KPIs and injects into info dict.""" + +import gymnasium as gym +import numpy as np + + +class EconomicMetricsWrapper(gym.Wrapper): + """Calculates thesis-aligned economic metrics per step, injects into info. + + Metrics follow thesis definitions: + - COI level: E[P] - p_min (Definition 1) + - Margin: (avg_price - p_min) / avg_price + - Regret: 1 - (revenue / baseline_revenue) + """ + + def __init__( + self, env: gym.Env, p_min: float = 10.0, baseline_revenue: float | None = None + ): + super().__init__(env) + self.p_min = p_min + self.baseline_revenue = baseline_revenue + self._price_history: list[np.ndarray] = [] + self._revenue_history: list[float] = [] + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + self._price_history = [] + self._revenue_history = [] + return obs, info + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + + # extract from unwrapped env + prices = self.env.unwrapped._prices + demand_dict = self.env.unwrapped._demand + demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))]) + alpha = self.env.unwrapped.alpha + + # core calculations + revenue = float(np.sum(prices * demand)) + avg_price = float(np.mean(prices)) + margin = (avg_price - self.p_min) / max(avg_price, 1e-6) + coi_level = avg_price - self.p_min # E[P] - p_min per thesis Def 1 + + self._price_history.append(prices.copy()) + self._revenue_history.append(revenue) + + # regret vs baseline (golden path) + regret = 0.0 + if self.baseline_revenue and self.baseline_revenue > 0: + regret = 1.0 - (revenue / self.baseline_revenue) + + # inject structured metrics into info + info["economics"] = { + "revenue": revenue, + "margin": margin, + "coi_level": coi_level, + "regret": regret, + } + for key in ("coi_mix", "coi_base", "coi_leakage", "coi_penalty"): + if key in info: + info["economics"][key] = info[key] + info["prices"] = prices.copy() + info["demand"] = demand.copy() + + return obs, reward, terminated, truncated, info + + @property + def episode_revenue(self) -> float: + return sum(self._revenue_history) + + @property + def episode_mean_price(self) -> float: + if not self._price_history: + return 0.0 + return float(np.mean([np.mean(p) for p in self._price_history])) diff --git a/engine/sweeps/model_mix.yaml b/engine/sweeps/model_mix.yaml new file mode 100644 index 0000000..28a7f38 --- /dev/null +++ b/engine/sweeps/model_mix.yaml @@ -0,0 +1,84 @@ +method: random +metric: + name: sweep/score + goal: maximize +command: + - ${env} + - python + - -m + - engine.train +parameters: + algo: + values: [ppo, a2c, dqn, qtable] + total_timesteps: + values: [30000, 50000, 80000] + seed: + values: [13, 42, 77] + n_products: + values: [8, 10, 12] + alpha: + distribution: uniform + min: 0.1 + max: 0.6 + lambda_coi: + distribution: uniform + min: 0.05 + max: 0.6 + robust_radius: + distribution: uniform + min: 0.0 + max: 0.3 + robust_points: + values: [3, 5, 7] + info_value: + distribution: uniform + min: 0.5 + max: 2.0 + revenue_weight: + values: [0.005, 0.01, 0.02] + learning_rate: + distribution: log_uniform_values + min: 1.0e-5 + max: 1.0e-3 + gamma: + values: [0.97, 0.99, 0.995] + buffer_size: + values: [20000, 50000, 100000] + batch_size: + values: [128, 256, 512] + tau: + values: [0.002, 0.005, 0.01] + train_freq: + values: [1, 4, 8] + learning_starts: + values: [500, 1000, 3000] + n_steps: + values: [512, 1024, 2048] + n_epochs: + values: [5, 10, 20] + gae_lambda: + values: [0.9, 0.95, 0.98] + clip_range: + values: [0.1, 0.2, 0.3] + ent_coef: + values: [0.0, 0.005, 0.01] + target_update_interval: + values: [500, 1000, 2000] + exploration_fraction: + values: [0.1, 0.2, 0.3] + exploration_final_eps: + values: [0.01, 0.03, 0.05] + action_levels: + values: [7, 9, 11] + action_scale_low: + values: [0.75, 0.8, 0.85] + action_scale_high: + values: [1.15, 1.2, 1.25] + q_lr: + values: [0.03, 0.05, 0.1, 0.2] + eps_start: + value: 1.0 + eps_end: + values: [0.02, 0.05, 0.1] + eps_decay: + values: [0.999, 0.9995, 0.9999] diff --git a/engine/sweeps/models_only.yaml b/engine/sweeps/models_only.yaml new file mode 100644 index 0000000..e0bd708 --- /dev/null +++ b/engine/sweeps/models_only.yaml @@ -0,0 +1,85 @@ +method: grid +metric: + name: sweep/score + goal: maximize +run_cap: 4 +command: + - ${env} + - python + - -m + - engine.train +parameters: + algo: + values: [ppo, a2c, dqn, qtable] + seed: + value: 42 + total_timesteps: + value: 12000 + eval_episodes: + value: 3 + eval_freq: + value: 500 + log_freq: + value: 100 + revenue_weight: + value: 0.01 + n_products: + value: 8 + N: + value: 80 + alpha: + value: 0.3 + lambda_coi: + value: 0.2 + robust_radius: + value: 0.0 + robust_points: + value: 1 + info_value: + value: 1.0 + learning_rate: + value: 0.0003 + gamma: + value: 0.99 + buffer_size: + value: 20000 + batch_size: + value: 128 + tau: + value: 0.005 + train_freq: + value: 1 + learning_starts: + value: 500 + n_steps: + value: 512 + n_epochs: + value: 10 + gae_lambda: + value: 0.95 + clip_range: + value: 0.2 + ent_coef: + value: 0.0 + target_update_interval: + value: 500 + exploration_fraction: + value: 0.2 + exploration_final_eps: + value: 0.05 + action_levels: + value: 7 + action_scale_low: + value: 0.9 + action_scale_high: + value: 1.1 + q_lr: + value: 0.1 + q_bins: + value: 6 + eps_start: + value: 1.0 + eps_end: + value: 0.05 + eps_decay: + value: 0.9995 diff --git a/engine/sweeps/sac_tune.yaml b/engine/sweeps/sac_tune.yaml new file mode 100644 index 0000000..97558cf --- /dev/null +++ b/engine/sweeps/sac_tune.yaml @@ -0,0 +1,54 @@ +method: bayes +metric: + name: sweep/score + goal: maximize +command: + - ${env} + - python + - -m + - engine.train +parameters: + algo: + value: sac + total_timesteps: + values: [50000, 80000, 120000] + seed: + values: [13, 42, 77] + alpha: + distribution: uniform + min: 0.15 + max: 0.55 + n_products: + values: [8, 10, 12] + lambda_coi: + distribution: uniform + min: 0.05 + max: 0.5 + robust_radius: + distribution: uniform + min: 0.05 + max: 0.3 + robust_points: + values: [3, 5, 7] + info_value: + distribution: uniform + min: 0.5 + max: 2.0 + revenue_weight: + values: [0.005, 0.01, 0.02] + learning_rate: + distribution: log_uniform_values + min: 3.0e-5 + max: 1.0e-3 + gamma: + values: [0.98, 0.99, 0.995] + buffer_size: + values: [50000, 100000, 200000] + batch_size: + values: [128, 256, 512] + tau: + values: [0.002, 0.005, 0.01] + train_freq: + values: [1, 4, 8] + learning_starts: + values: [1000, 3000, 5000] diff --git a/engine/sweeps/small_arch_compare.yaml b/engine/sweeps/small_arch_compare.yaml new file mode 100644 index 0000000..2eae9a0 --- /dev/null +++ b/engine/sweeps/small_arch_compare.yaml @@ -0,0 +1,86 @@ +method: random +metric: + name: sweep/score + goal: maximize +command: + - ${env} + - python + - -m + - engine.train +parameters: + algo: + values: [ppo, a2c, dqn, qtable] + arch: + values: [tiny, small, medium] + activation: + values: [relu, tanh] + total_timesteps: + values: [8000, 12000, 20000] + seed: + values: [13, 42, 77] + n_products: + values: [6, 8, 10] + alpha: + distribution: uniform + min: 0.1 + max: 0.5 + lambda_coi: + distribution: uniform + min: 0.05 + max: 0.4 + robust_radius: + values: [0.0, 0.1, 0.2] + robust_points: + values: [3, 5] + info_value: + values: [0.75, 1.0, 1.5] + revenue_weight: + values: [0.005, 0.01, 0.02] + learning_rate: + distribution: log_uniform_values + min: 1.0e-5 + max: 5.0e-4 + gamma: + values: [0.98, 0.99] + buffer_size: + values: [10000, 30000, 50000] + batch_size: + values: [64, 128, 256] + tau: + values: [0.002, 0.005, 0.01] + train_freq: + values: [1, 4] + learning_starts: + values: [500, 1000, 2000] + n_steps: + values: [256, 512, 1024] + n_epochs: + values: [5, 10] + gae_lambda: + values: [0.9, 0.95] + clip_range: + values: [0.1, 0.2] + ent_coef: + values: [0.0, 0.005] + target_update_interval: + values: [500, 1000] + exploration_fraction: + values: [0.1, 0.2] + exploration_final_eps: + values: [0.02, 0.05] + action_levels: + values: [5, 7, 9] + action_scale_low: + values: [0.85, 0.9] + action_scale_high: + values: [1.1, 1.15] + q_lr: + values: [0.05, 0.1, 0.2] + q_bins: + values: [4, 6, 8] + eps_start: + value: 1.0 + eps_end: + values: [0.02, 0.05] + eps_decay: + values: [0.999, 0.9995] diff --git a/engine/train.py b/engine/train.py index e059593..8e4eb07 100644 --- a/engine/train.py +++ b/engine/train.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import argparse import json +import os from pathlib import Path import numpy as np -from gymnasium.wrappers import FlattenObservation try: import wandb @@ -20,9 +22,7 @@ try: except ImportError: HAS_SB3 = False -from .wrapper import PHANTOM -from .lib import EconomicMetricsWrapper, MetricsCallback -from .lib.discrete import EventQTable +from .jax import JAX_AVAILABLE DEFAULT_CFG = { @@ -69,14 +69,34 @@ DEFAULT_CFG = { "arch": "small", "activation": "relu", "q_bins": 6, + "max_steps": 100, + "margin_floor": 0.05, + "margin_floor_patience": 5, + "use_jax": False, + "jax_num_envs": 16, + "jax_num_steps": 128, + "jax_num_minibatches": 4, + "jax_update_epochs": 4, + "jax_anneal_lr": True, } +def _truthy(value: str | bool | None) -> bool: + if isinstance(value, bool): + return value + if value is None: + return False + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + def _cfg(raw: dict | None = None) -> dict: cfg = dict(DEFAULT_CFG) if raw: cfg.update({k: v for k, v in raw.items() if v is not None}) cfg["algo"] = str(cfg["algo"]).lower() + cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy( + os.environ.get("PHANTOM_USE_JAX") + ) return cfg @@ -89,6 +109,11 @@ def _wandb_cfg_dict() -> dict: def make_env(cfg: dict): + from gymnasium.wrappers import FlattenObservation + + from .wrapper import PHANTOM + from .lib.wrappers import EconomicMetricsWrapper + env = PHANTOM( n_products=int(cfg["n_products"]), alpha=float(cfg["alpha"]), @@ -101,6 +126,9 @@ def make_env(cfg: dict): action_levels=int(cfg["action_levels"]), action_scale_low=float(cfg["action_scale_low"]), action_scale_high=float(cfg["action_scale_high"]), + max_steps=int(cfg.get("max_steps", 100)), + margin_floor=float(cfg.get("margin_floor", 0.05)), + margin_floor_patience=int(cfg.get("margin_floor_patience", 5)), render_mode=None, ) env = EconomicMetricsWrapper(env) @@ -235,6 +263,8 @@ def build_model(cfg: dict, env): def train_qtable(cfg: dict) -> tuple[EventQTable, dict]: + from .lib.discrete import EventQTable + np.random.seed(int(cfg["seed"])) env = make_env(cfg) eval_env = make_env(cfg) @@ -275,6 +305,8 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]: def train_sb3(cfg: dict) -> tuple[object, dict]: if not HAS_SB3: raise ImportError("stable-baselines3 is required for SB3 models") + from .lib.callbacks import MetricsCallback + env = make_env(cfg) eval_env = make_env(cfg) env = Monitor(env) @@ -303,7 +335,20 @@ def train_sb3(cfg: dict) -> tuple[object, dict]: def train_once(cfg: dict) -> dict: algo = cfg["algo"] - if algo == "qtable": + if cfg.get("use_jax"): + if not JAX_AVAILABLE: + raise ImportError( + "JAX backend requested but JAX is not installed. " + "Install engine/jax/requirements.txt and jax[tpu] for TPU runs." + ) + if algo == "qtable": + raise ValueError("qtable is not supported in JAX backend") + try: + from .jax.train import train_jax + except Exception as exc: # pragma: no cover + raise ImportError(f"Failed to import JAX trainer: {exc}") from exc + _, metrics = train_jax(cfg) + elif algo == "qtable": _, metrics = train_qtable(cfg) else: _, metrics = train_sb3(cfg) @@ -357,8 +402,17 @@ def main(): p.add_argument("--learning-rate", type=float) p.add_argument("--gamma", type=float) p.add_argument("--revenue-weight", type=float) + p.add_argument("--max-steps", type=int) + p.add_argument("--margin-floor", type=float) + p.add_argument("--margin-floor-patience", type=int) p.add_argument("--arch", type=str) p.add_argument("--activation", type=str) + p.add_argument("--jax", action="store_true") + p.add_argument("--jax-num-envs", type=int) + p.add_argument("--jax-num-steps", type=int) + p.add_argument("--jax-num-minibatches", type=int) + p.add_argument("--jax-update-epochs", type=int) + p.add_argument("--jax-anneal-lr", type=str) p.add_argument("--sweep-agent", action="store_true") p.add_argument("--sweep-id", type=str) p.add_argument("--count", type=int, default=0) @@ -377,8 +431,19 @@ def main(): "learning_rate": args.learning_rate, "gamma": args.gamma, "revenue_weight": args.revenue_weight, + "max_steps": args.max_steps, + "margin_floor": args.margin_floor, + "margin_floor_patience": args.margin_floor_patience, "arch": args.arch, "activation": args.activation, + "use_jax": args.jax, + "jax_num_envs": args.jax_num_envs, + "jax_num_steps": args.jax_num_steps, + "jax_num_minibatches": args.jax_num_minibatches, + "jax_update_epochs": args.jax_update_epochs, + "jax_anneal_lr": _truthy(args.jax_anneal_lr) + if args.jax_anneal_lr is not None + else None, } overrides = {k: v for k, v in overrides.items() if v is not None}