adding naive jax and libraries and make adjustments

This commit is contained in:
2026-02-17 14:48:18 +01:00
parent 66c4a0cd1d
commit 802f31b4a1
17 changed files with 2331 additions and 6 deletions

111
Makefile
View File

@@ -8,12 +8,30 @@ VENV := .venv
PYTHON := $(VENV)/bin/python PYTHON := $(VENV)/bin/python
PIP := $(VENV)/bin/pip PIP := $(VENV)/bin/pip
PYTEST := $(VENV)/bin/pytest PYTEST := $(VENV)/bin/pytest
TPU_NAME ?= phantom-tpu
TPU_ZONE ?= us-central2-b
TPU_TYPE ?= v4-32
TPU_RUNTIME ?= tpu-vm-v4-base
TPU_PROJECT ?= phantom-trc
TPU_NETWORK ?= default
TPU_SUBNETWORK ?= default-us-central2
TPU_USE_SPOT ?= 0
TPU_EXTRA_CREATE_FLAGS ?=
TPU_WORKDIR ?= ~/PHANTOM
TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html
TPU_VENV ?= .venv-tpu
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=offline
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" --zone="$(TPU_ZONE)" --accelerator-type="$(TPU_TYPE)" --version="$(TPU_RUNTIME)" --network="$(TPU_NETWORK)" --subnetwork="$(TPU_SUBNETWORK)" $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
.DEFAULT_GOAL := help .DEFAULT_GOAL := help
.PHONY: help .PHONY: help
help: help:
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines" @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.*"
@echo "TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot"
$(BUILDDIR): $(BUILDDIR):
mkdir -p paper/$(BUILDDIR) mkdir -p paper/$(BUILDDIR)
@@ -70,6 +88,97 @@ $(VENV):
install: $(VENV) install: $(VENV)
$(PIP) install -r requirements.txt $(PIP) install -r requirements.txt
.PHONY: tpu.setup
tpu.setup:
@command -v gcloud >/dev/null 2>&1 || (echo "gcloud CLI not found. Install from https://cloud.google.com/sdk/docs/install" && exit 1)
@gcloud auth login --update-adc
@gcloud auth application-default login
@gcloud config set project "$(TPU_PROJECT)"
.PHONY: tpu.check.zone
tpu.check.zone:
@case "$(TPU_ZONE)" in \
europe-west4-a|us-central2-b|us-central1-a|us-east1-d|europe-west4-b) ;; \
*) echo "Unsupported TPU_ZONE='$(TPU_ZONE)'. Allowed zones: europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b"; exit 1 ;; \
esac
.PHONY: tpu.create.v4.ondemand
tpu.create.v4.ondemand:
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=0 TPU_SUBNETWORK=default-us-central2
.PHONY: tpu.create.v4.spot
tpu.create.v4.spot:
$(MAKE) tpu.create TPU_ZONE=us-central2-b TPU_TYPE=v4-32 TPU_USE_SPOT=1 TPU_SUBNETWORK=default-us-central2
.PHONY: tpu.create
tpu.create: tpu.check.zone
@if gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" >/dev/null 2>&1; then \
STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)'); \
echo "TPU VM $(TPU_NAME) already exists in $(TPU_ZONE) with state=$$STATE, skipping create"; \
else \
$(TPU_CREATE_CMD); \
fi
.PHONY: tpu.ensure
tpu.ensure: tpu.check.zone
@set -e; \
STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)' 2>/dev/null || true); \
if [ -z "$$STATE" ]; then \
echo "TPU VM $(TPU_NAME) not found in $(TPU_ZONE), creating"; \
$(TPU_CREATE_CMD); \
elif [ "$$STATE" = "READY" ]; then \
echo "TPU VM $(TPU_NAME) is READY"; \
elif [ "$$STATE" = "PREEMPTED" ] || [ "$$STATE" = "TERMINATED" ] || [ "$$STATE" = "FAILED" ]; then \
echo "TPU VM $(TPU_NAME) is in terminal state $$STATE, recreating"; \
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet || true; \
$(TPU_CREATE_CMD); \
else \
echo "TPU VM $(TPU_NAME) is in state $$STATE; wait or recreate manually"; \
exit 1; \
fi
.PHONY: tpu.status
tpu.status:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)"
.PHONY: tpu.ssh
tpu.ssh:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)"
.PHONY: tpu.prepare
tpu.prepare: tpu.ensure
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command "mkdir -p $(TPU_WORKDIR)"
.PHONY: tpu.deploy
tpu.deploy: tpu.prepare
@for p in $(TPU_SYNC_PATHS); do \
if [ ! -e "$$p" ]; then continue; fi; \
if [ -d "$$p" ]; then \
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp --recurse "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
else \
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
fi; \
done
.PHONY: tpu.install
tpu.install: tpu.ensure
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && PYBIN=$$(command -v python3.11 || command -v python3.10 || command -v python3) && $$PYBIN -m venv $(TPU_VENV) && $(TPU_VENV)/bin/pip install --upgrade pip setuptools wheel && $(TPU_VENV)/bin/pip install -r requirements.txt && $(TPU_VENV)/bin/pip install -r engine/jax/requirements.txt && $(TPU_VENV)/bin/pip install "jax[tpu]" -f $(TPU_JAX_WHEEL_URL)'
.PHONY: tpu.check.remote
tpu.check.remote: tpu.ensure
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'set -e; mkdir -p $(TPU_WORKDIR); cd $(TPU_WORKDIR); test -f engine/train.py || (echo "Missing code on TPU VM. Run: make tpu.deploy" && exit 2); test -x $(TPU_VENV)/bin/python || (echo "Missing TPU venv. Run: make tpu.install" && exit 3)'
.PHONY: tpu.train
tpu.train: tpu.check.remote
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" --zone="$(TPU_ZONE)" --command 'cd $(TPU_WORKDIR) && if [ -f .env ]; then set -a && . ./.env && set +a; fi && $(TPU_TRAIN_ENV) $(TPU_VENV)/bin/python -m engine.train $(TPU_TRAIN_ARGS)'
.PHONY: tpu.bootstrap
tpu.bootstrap: tpu.ensure tpu.deploy tpu.install
.PHONY: tpu.delete
tpu.delete:
gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet
.PHONY: stats.lines .PHONY: stats.lines
stats.lines: stats.lines:
@find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \ @find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \

13
engine/jax/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""JAX-compatible training and environment modules for PHANTOM."""
from __future__ import annotations
try:
import jax # noqa: F401
import jax.numpy as jnp # noqa: F401
JAX_AVAILABLE = True
except ImportError:
JAX_AVAILABLE = False
__all__ = ["JAX_AVAILABLE"]

49
engine/jax/checkpoint.py Normal file
View File

@@ -0,0 +1,49 @@
"""Orbax checkpoint helpers for JAX training runs."""
from __future__ import annotations
from pathlib import Path
from typing import Any
try:
import orbax.checkpoint as ocp
HAS_ORBAX = True
except ImportError:
HAS_ORBAX = False
def _require_orbax() -> None:
if not HAS_ORBAX:
raise ImportError(
"orbax-checkpoint is required for checkpoint support. "
"Install engine/jax/requirements.txt first."
)
def create_manager(directory: str | Path, max_to_keep: int = 5):
_require_orbax()
root = Path(directory)
root.mkdir(parents=True, exist_ok=True)
options = ocp.CheckpointManagerOptions(
max_to_keep=max(1, int(max_to_keep)), create=True
)
return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options)
def save(manager, *, step: int, payload: Any) -> bool:
_require_orbax()
return bool(manager.save(int(step), payload))
def latest_step(manager) -> int | None:
_require_orbax()
return manager.latest_step()
def restore(manager, *, target: Any, step: int | None = None) -> Any:
_require_orbax()
step_to_restore = manager.latest_step() if step is None else int(step)
if step_to_restore is None:
return target
return manager.restore(step_to_restore, items=target)

287
engine/jax/env.py Normal file
View File

@@ -0,0 +1,287 @@
"""JAX-native PHANTOM environment with robust contamination step."""
from __future__ import annotations
from typing import NamedTuple
try:
import jax
import jax.numpy as jnp
except ImportError as exc: # pragma: no cover
raise ImportError("engine.jax.env requires JAX") from exc
from .primitives import (
_sample_sessions_jax,
agent_probability_from_kl,
batch_kl,
compute_session_transitions,
load_transition_data,
purchase_flags,
reward_with_coi_penalty,
revenue_from_demand,
weighted_demand,
)
class EnvParams(NamedTuple):
n_products: int
n_sessions: int
max_episode_steps: int
max_session_steps: int
price_low: float
price_high: float
lambda_coi: float
info_value: float
robust_radius: float
margin_floor: float
margin_floor_patience: int
action_scales: jax.Array
alpha_nominal: float
alpha_candidates: jax.Array
human_T: jax.Array
agent_T: jax.Array
terminal_mask: jax.Array
purchase_mask: jax.Array
event_weights: jax.Array
start_idx: int
term_idx: int
class EnvState(NamedTuple):
prices: jax.Array
demand: jax.Array
step_count: jax.Array
low_margin_streak: jax.Array
last_agent_prob: jax.Array
last_alpha_adv: jax.Array
class CandidateEval(NamedTuple):
reward: jax.Array
revenue: jax.Array
demand: jax.Array
agent_prob: jax.Array
leakage: jax.Array
discount: jax.Array
n_purchases: jax.Array
n_agents: jax.Array
def make_env_params(
*,
n_products: int,
alpha: float,
n_sessions: int,
lambda_coi: float,
robust_radius: float,
robust_points: int,
info_value: float,
action_levels: int,
action_scale_low: float,
action_scale_high: float,
price_low: float,
price_high: float,
max_episode_steps: int,
max_session_steps: int = 40,
margin_floor: float = 0.05,
margin_floor_patience: int = 5,
prefer_behavior_data: bool = True,
) -> EnvParams:
transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()
if robust_radius <= 0.0 or robust_points <= 1:
alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32)
else:
lo = max(0.0, float(alpha) - float(robust_radius))
hi = min(1.0, float(alpha) + float(robust_radius))
alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32)
action_scales = jnp.linspace(
float(action_scale_low),
float(action_scale_high),
int(action_levels),
dtype=jnp.float32,
)
return EnvParams(
n_products=int(n_products),
n_sessions=int(n_sessions),
max_episode_steps=int(max_episode_steps),
max_session_steps=int(max_session_steps),
price_low=float(price_low),
price_high=float(price_high),
lambda_coi=float(lambda_coi),
info_value=float(info_value),
robust_radius=float(robust_radius),
margin_floor=float(margin_floor),
margin_floor_patience=int(margin_floor_patience),
action_scales=action_scales,
alpha_nominal=float(alpha),
alpha_candidates=alpha_candidates,
human_T=jnp.asarray(transition.human_T),
agent_T=jnp.asarray(transition.agent_T),
terminal_mask=jnp.asarray(transition.terminal_mask),
purchase_mask=jnp.asarray(transition.purchase_mask),
event_weights=jnp.asarray(transition.event_weights),
start_idx=int(transition.start_idx),
term_idx=int(transition.term_idx),
)
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)])
def _decode_action(
prices: jax.Array, action: jax.Array, params: EnvParams
) -> jax.Array:
idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1)
scale = params.action_scales[idx]
next_prices = prices * scale
return jnp.clip(next_prices, params.price_low, params.price_high)
def _evaluate_candidate(
key: jax.Array,
alpha_candidate: jax.Array,
prices: jax.Array,
params: EnvParams,
) -> CandidateEval:
states, products, actors, lengths = _sample_sessions_jax(
key,
params.human_T,
params.agent_T,
params.terminal_mask,
params.start_idx,
params.term_idx,
alpha_candidate,
params.n_products,
params.n_sessions,
params.max_session_steps,
int(params.human_T.shape[0]),
)
session_trans = compute_session_transitions(
states, lengths, int(params.human_T.shape[0])
)
delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
agent_probs = agent_probability_from_kl(delta_h, delta_a)
agent_prob = jnp.mean(agent_probs)
demand = weighted_demand(states, products, params.n_products, params.event_weights)
revenue = revenue_from_demand(prices, demand)
reward, leakage, discount = reward_with_coi_penalty(
revenue,
agent_prob,
params.lambda_coi,
params.info_value,
)
purchases = purchase_flags(states, params.purchase_mask)
return CandidateEval(
reward=reward,
revenue=revenue,
demand=demand,
agent_prob=agent_prob,
leakage=leakage,
discount=discount,
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
n_agents=jnp.sum(actors.astype(jnp.float32)),
)
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
prices = jax.random.uniform(
key,
shape=(params.n_products,),
minval=params.price_low,
maxval=params.price_high,
)
demand = jnp.zeros((params.n_products,), dtype=jnp.float32)
state = EnvState(
prices=prices,
demand=demand,
step_count=jnp.asarray(0, dtype=jnp.int32),
low_margin_streak=jnp.asarray(0, dtype=jnp.int32),
last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
)
return _flatten_obs(demand, prices), state
def step_env(
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams,
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
prices = _decode_action(state.prices, action, params)
n_candidates = params.alpha_candidates.shape[0]
cand_keys = jax.random.split(key, n_candidates)
evals = jax.vmap(
lambda k, a: _evaluate_candidate(k, a, prices, params),
in_axes=(0, 0),
)(cand_keys, params.alpha_candidates)
idx = jnp.argmin(evals.reward)
demand = evals.demand[idx]
reward = evals.reward[idx]
revenue = evals.revenue[idx]
agent_prob = evals.agent_prob[idx]
leakage = evals.leakage[idx]
discount = evals.discount[idx]
n_purchases = evals.n_purchases[idx]
n_agents = evals.n_agents[idx]
alpha_adv = params.alpha_candidates[idx]
step_count = state.step_count + 1
avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
avg_margin = (avg_price - params.price_low) / avg_price
next_streak = jnp.where(
avg_margin < params.margin_floor, state.low_margin_streak + 1, 0
)
margin_collapsed = next_streak >= params.margin_floor_patience
done = (step_count >= params.max_episode_steps) | margin_collapsed
next_state = EnvState(
prices=prices,
demand=demand,
step_count=step_count,
low_margin_streak=next_streak,
last_agent_prob=agent_prob,
last_alpha_adv=alpha_adv,
)
obs = _flatten_obs(demand, prices)
info = {
"revenue": revenue,
"agent_prob": agent_prob,
"alpha_adv": alpha_adv,
"coi_leakage": leakage,
"coi_discount": discount,
"n_purchases": n_purchases,
"n_agents": n_agents,
"avg_margin": avg_margin,
}
return obs, next_state, reward, done, info
class PHANTOMJAXEnv:
def __init__(self, params: EnvParams):
self.params = params
def reset(self, key: jax.Array, params: EnvParams | None = None):
return reset_env(key, self.params if params is None else params)
def step(
self,
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams | None = None,
):
return step_env(key, state, action, self.params if params is None else params)
def action_space_n(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.action_scales.shape[0])
def observation_dim(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.n_products * 2)

493
engine/jax/primitives.py Normal file
View File

@@ -0,0 +1,493 @@
"""JAX-compatible primitives for PHANTOM session simulation and separability."""
from __future__ import annotations
from dataclasses import dataclass
from functools import partial
from typing import Mapping, Sequence
import numpy as np
try:
import jax
import jax.numpy as jnp
JAX_AVAILABLE = True
except ImportError:
jax = None # type: ignore[assignment]
jnp = np # type: ignore[assignment]
JAX_AVAILABLE = False
STATE_START_KEYS = ("session_start", "start")
TERMINAL_EVENT_TOKENS = (
"session_end",
"end",
"purchase_complete",
"checkout_start",
"checkout",
)
PURCHASE_EVENT_TOKENS = (
"purchase_complete",
"purchase",
"checkout_start",
"checkout",
)
CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5}
ACTION_CATEGORIES = {
"cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"},
"dwell": {
"hover_title",
"hover_paragraph",
"hover_link",
"hover_over_title",
"hover_over_paragraph",
"hover_over_link",
"hover_over_button",
},
"nav": {
"page_view",
"view_item",
"view",
"learn_more",
"learn_more_about_item",
"view_item_page",
"session_start",
},
"filter": {
"search",
"filter_date",
"filter_price",
"sort",
"filter_for_date",
"filter_for_price",
"filter_for_amenities",
"sort_change",
},
}
DEFAULT_ACTION_WEIGHTS = {
action: CATEGORY_WEIGHTS[group]
for group, actions in ACTION_CATEGORIES.items()
for action in actions
}
@dataclass(frozen=True)
class TransitionData:
"""Dense transition kernels and per-state metadata."""
human_T: np.ndarray
agent_T: np.ndarray
terminal_mask: np.ndarray
purchase_mask: np.ndarray
event_weights: np.ndarray
event_names: tuple[str, ...]
start_idx: int
term_idx: int
def to_jax(self) -> "TransitionData":
if not JAX_AVAILABLE:
return self
return TransitionData(
human_T=jnp.asarray(self.human_T),
agent_T=jnp.asarray(self.agent_T),
terminal_mask=jnp.asarray(self.terminal_mask),
purchase_mask=jnp.asarray(self.purchase_mask),
event_weights=jnp.asarray(self.event_weights),
event_names=self.event_names,
start_idx=int(self.start_idx),
term_idx=int(self.term_idx),
)
@dataclass(frozen=True)
class SessionBatch:
states: np.ndarray
products: np.ndarray
actors: np.ndarray
lengths: np.ndarray
def _event_weight(name: str) -> float:
if name in DEFAULT_ACTION_WEIGHTS:
return float(DEFAULT_ACTION_WEIGHTS[name])
if name.startswith("hover"):
return float(CATEGORY_WEIGHTS["dwell"])
if name.startswith("filter") or name in {"search", "sort", "sort_change"}:
return float(CATEGORY_WEIGHTS["filter"])
if name.startswith("add") or name in {
"checkout",
"checkout_start",
"purchase",
"remove_item",
"purchase_complete",
}:
return float(CATEGORY_WEIGHTS["cart"])
if any(token in name for token in TERMINAL_EVENT_TOKENS):
return 0.0
return float(CATEGORY_WEIGHTS["nav"])
def _is_terminal(name: str) -> bool:
return any(token in name for token in TERMINAL_EVENT_TOKENS)
def _is_purchase(name: str) -> bool:
return any(token in name for token in PURCHASE_EVENT_TOKENS)
def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]:
names: set[str] = set()
for trans in transitions:
for src, dsts in trans.items():
names.add(src)
names.update(dsts.keys())
names.discard("__terminal__")
return tuple(sorted(names))
def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray:
row_sums = matrix.sum(axis=1, keepdims=True)
dead_rows = np.isclose(row_sums.squeeze(-1), 0.0)
if np.any(dead_rows):
matrix[dead_rows] = 0.0
matrix[dead_rows, term_idx] = 1.0
row_sums = matrix.sum(axis=1, keepdims=True)
return matrix / np.maximum(row_sums, 1e-8)
def _dense_from_dict(
transitions: Mapping[str, Mapping[str, float]],
event_to_idx: Mapping[str, int],
term_idx: int,
) -> np.ndarray:
n_states = len(event_to_idx)
matrix = np.zeros((n_states, n_states), dtype=np.float32)
for src, dsts in transitions.items():
i = event_to_idx.get(src)
if i is None:
continue
for dst, prob in dsts.items():
j = event_to_idx.get(dst)
if j is None:
continue
matrix[i, j] += float(prob)
return _normalize_rows(matrix, term_idx)
def compile_transition_data(
human_transitions: Mapping[str, Mapping[str, float]],
agent_transitions: Mapping[str, Mapping[str, float]],
) -> TransitionData:
event_names = _collect_events(human_transitions, agent_transitions)
if not event_names:
return fallback_transition_data()
event_names = tuple([*event_names, "__terminal__"])
term_idx = len(event_names) - 1
event_to_idx = {name: i for i, name in enumerate(event_names)}
human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx)
agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx)
terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool)
purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool)
event_weights = np.array(
[_event_weight(name) for name in event_names], dtype=np.float32
)
terminal_mask[term_idx] = True
for idx, is_term in enumerate(terminal_mask):
if not is_term:
continue
human_T[idx] = 0.0
agent_T[idx] = 0.0
human_T[idx, idx] = 1.0
agent_T[idx, idx] = 1.0
start_idx = 0
for key in STATE_START_KEYS:
if key in event_to_idx:
start_idx = int(event_to_idx[key])
break
return TransitionData(
human_T=human_T,
agent_T=agent_T,
terminal_mask=terminal_mask,
purchase_mask=purchase_mask,
event_weights=event_weights,
event_names=event_names,
start_idx=start_idx,
term_idx=term_idx,
)
def fallback_transition_data() -> TransitionData:
human = {
"session_start": {
"page_view": 0.80,
"view_item_page": 0.15,
"session_end": 0.05,
},
"page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20},
"view_item_page": {
"learn_more_about_item": 0.40,
"add_item_to_cart": 0.28,
"session_end": 0.32,
},
"learn_more_about_item": {
"add_item_to_cart": 0.50,
"view_item_page": 0.30,
"session_end": 0.20,
},
"add_item_to_cart": {
"checkout_start": 0.58,
"view_item_page": 0.24,
"session_end": 0.18,
},
"checkout_start": {"purchase_complete": 0.70, "session_end": 0.30},
"purchase_complete": {"session_end": 1.0},
}
agent = {
"session_start": {
"page_view": 0.90,
"view_item_page": 0.08,
"session_end": 0.02,
},
"page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25},
"view_item_page": {
"learn_more_about_item": 0.55,
"add_item_to_cart": 0.15,
"session_end": 0.30,
},
"learn_more_about_item": {
"view_item_page": 0.45,
"add_item_to_cart": 0.20,
"session_end": 0.35,
},
"add_item_to_cart": {
"checkout_start": 0.42,
"view_item_page": 0.28,
"session_end": 0.30,
},
"checkout_start": {"purchase_complete": 0.52, "session_end": 0.48},
"purchase_complete": {"session_end": 1.0},
}
return compile_transition_data(human, agent)
def load_transition_data(prefer_data: bool = True) -> TransitionData:
if not prefer_data:
return fallback_transition_data()
try:
from ..lib.behavior import get_transition_models
human_trans, agent_trans = get_transition_models()
return compile_transition_data(human_trans, agent_trans)
except Exception:
return fallback_transition_data()
if JAX_AVAILABLE:
@partial(jax.jit, static_argnums=(8, 9, 10))
def _sample_sessions_jax(
key: jax.Array,
human_T: jax.Array,
agent_T: jax.Array,
terminal_mask: jax.Array,
start_idx: int,
term_idx: int,
alpha: float,
n_products: int,
n_sessions: int,
max_steps: int,
n_states: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
k_actor, k_product, k_step = jax.random.split(key, 3)
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
actors = (actor_draw < alpha).astype(jnp.int32)
products = jax.random.randint(
k_product, (n_sessions,), 0, n_products, dtype=jnp.int32
)
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
def _scan_step(carry, _):
states, active, rng = carry
rng, k = jax.random.split(rng)
probs_h = human_T[states]
probs_a = agent_T[states]
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
next_state = jnp.where(active, next_state, int(term_idx))
emitted = jnp.where(active, next_state, -1)
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
next_active = active & (~is_terminal)
carry_states = jnp.where(next_active, next_state, int(term_idx))
return (carry_states, next_active, rng), emitted
_, state_t = jax.lax.scan(
_scan_step, (state_init, active_init, k_step), None, length=max_steps
)
states = state_t.T
lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32)
return states, products, actors, lengths
def sample_sessions(
key,
transition_data: TransitionData,
alpha: float,
n_products: int,
n_sessions: int,
max_steps: int,
) -> SessionBatch:
if JAX_AVAILABLE:
td = transition_data.to_jax()
states, products, actors, lengths = _sample_sessions_jax(
key,
td.human_T,
td.agent_T,
td.terminal_mask,
int(td.start_idx),
int(td.term_idx),
float(alpha),
int(n_products),
int(n_sessions),
int(max_steps),
int(td.human_T.shape[0]),
)
return SessionBatch(
states=states, products=products, actors=actors, lengths=lengths
)
rng = np.random.default_rng(int(np.asarray(key).reshape(-1)[0]))
n_states = transition_data.human_T.shape[0]
products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32)
actors = (rng.random(size=n_sessions) < alpha).astype(np.int32)
states = np.full((n_sessions, max_steps), -1, dtype=np.int32)
lengths = np.zeros((n_sessions,), dtype=np.int32)
for i in range(n_sessions):
current = int(transition_data.start_idx)
mat = transition_data.agent_T if actors[i] == 1 else transition_data.human_T
for t in range(max_steps):
nxt = int(rng.choice(n_states, p=mat[current]))
states[i, t] = nxt
if transition_data.terminal_mask[nxt]:
lengths[i] = t + 1
break
current = nxt
if lengths[i] == 0:
lengths[i] = max_steps
return SessionBatch(
states=states, products=products, actors=actors, lengths=lengths
)
if JAX_AVAILABLE:
@partial(jax.jit, static_argnums=(2,))
def compute_session_transitions(states, lengths, n_states: int):
src = states[:, :-1]
dst = states[:, 1:]
time_idx = jnp.arange(src.shape[1])[None, :]
valid = (src >= 0) & (dst >= 0) & (time_idx < (lengths[:, None] - 1))
src_clip = jnp.clip(src, 0, n_states - 1)
dst_clip = jnp.clip(dst, 0, n_states - 1)
src_oh = jax.nn.one_hot(src_clip, n_states)
dst_oh = jax.nn.one_hot(dst_clip, n_states)
counts = jnp.einsum(
"nti,ntj,nt->nij", src_oh, dst_oh, valid.astype(jnp.float32)
)
row_sums = jnp.sum(counts, axis=-1, keepdims=True)
return counts / (row_sums + 1e-10)
else:
def compute_session_transitions(states, lengths, n_states: int):
trans = np.zeros((states.shape[0], n_states, n_states), dtype=np.float32)
for i in range(states.shape[0]):
for t in range(max(int(lengths[i]) - 1, 0)):
s = int(states[i, t])
d = int(states[i, t + 1])
if s >= 0 and d >= 0:
trans[i, s, d] += 1.0
row_sums = trans.sum(axis=-1, keepdims=True)
return trans / (row_sums + 1e-10)
def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10):
p = P + eps
p = p / jnp.sum(p, axis=-1, keepdims=True)
qh = Q_human[None, ...] + eps
qa = Q_agent[None, ...] + eps
delta_h = jnp.sum(p * jnp.log(p / qh), axis=(1, 2))
delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2))
return delta_h, delta_a
if JAX_AVAILABLE:
batch_kl = jax.jit(batch_kl)
def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0):
t = jnp.maximum(float(temperature), 1e-6)
exp_h = jnp.exp(-delta_h / t)
exp_a = jnp.exp(-delta_a / t)
return exp_a / (exp_h + exp_a + 1e-10)
def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0):
logits = beta * (delta_h - delta_a)
return 1.0 / (1.0 + jnp.exp(-logits))
def weighted_demand(states, products, n_products: int, event_weights):
valid = states >= 0
state_clip = jnp.clip(states, 0, event_weights.shape[0] - 1)
weights = event_weights[state_clip] * valid
per_session = jnp.sum(weights, axis=1)
demand = jnp.zeros((n_products,), dtype=jnp.float32)
demand = demand.at[products].add(per_session)
total = jnp.sum(demand)
return jnp.where(total > 0.0, (demand / total) * 100.0, demand)
if JAX_AVAILABLE:
weighted_demand = jax.jit(weighted_demand, static_argnums=(2,))
def purchase_flags(states, purchase_mask):
state_clip = jnp.clip(states, 0, purchase_mask.shape[0] - 1)
hits = purchase_mask[state_clip] & (states >= 0)
return jnp.any(hits, axis=1)
if JAX_AVAILABLE:
purchase_flags = jax.jit(purchase_flags)
def revenue_from_demand(prices, demand):
return jnp.dot(prices, demand)
if JAX_AVAILABLE:
revenue_from_demand = jax.jit(revenue_from_demand)
def reward_with_coi_penalty(
revenue, agent_prob: float, lambda_coi: float, info_value: float
):
leakage = agent_prob * info_value
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
return revenue * discount, leakage, discount
if JAX_AVAILABLE:
reward_with_coi_penalty = jax.jit(reward_with_coi_penalty)

View File

@@ -0,0 +1,5 @@
flax>=0.8.0
optax>=0.2.0
distrax>=0.1.5
orbax-checkpoint>=0.5.0
chex>=0.1.8

471
engine/jax/train.py Normal file
View File

@@ -0,0 +1,471 @@
"""Pure JAX PPO trainer for the PHANTOM environment."""
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
try:
import jax
import jax.numpy as jnp
import distrax
import flax.linen as nn
import optax
from flax import serialization
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
HAS_JAX_STACK = True
except ImportError:
jax = None # type: ignore[assignment]
jnp = None # type: ignore[assignment]
distrax = None # type: ignore[assignment]
optax = None # type: ignore[assignment]
serialization = None # type: ignore[assignment]
class _ModuleStub:
pass
class _NNStub:
Module = _ModuleStub
@staticmethod
def compact(fn):
return fn
nn = _NNStub() # type: ignore[assignment]
def constant(*_args, **_kwargs): # type: ignore[override]
return None
def orthogonal(*_args, **_kwargs): # type: ignore[override]
return None
class TrainState: # type: ignore[override]
pass
HAS_JAX_STACK = False
from .env import PHANTOMJAXEnv, make_env_params
class ActorCritic(nn.Module):
action_dim: int
activation: str = "tanh"
@nn.compact
def __call__(self, x):
activation_fn = nn.relu if self.activation == "relu" else nn.tanh
actor = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(x)
actor = activation_fn(actor)
actor = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(actor)
actor = activation_fn(actor)
logits = nn.Dense(
self.action_dim,
kernel_init=orthogonal(0.01),
bias_init=constant(0.0),
)(actor)
critic = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(x)
critic = activation_fn(critic)
critic = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(critic)
critic = activation_fn(critic)
value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
critic
)
return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1)
class Transition(NamedTuple):
done: jax.Array
action: jax.Array
value: jax.Array
reward: jax.Array
log_prob: jax.Array
obs: jax.Array
info: dict[str, jax.Array]
def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
out = {
"algo": str(cfg.get("algo", "ppo")).lower(),
"seed": int(cfg.get("seed", 42)),
"learning_rate": float(cfg.get("learning_rate", 3e-4)),
"gamma": float(cfg.get("gamma", 0.99)),
"gae_lambda": float(cfg.get("gae_lambda", 0.95)),
"clip_range": float(cfg.get("clip_range", 0.2)),
"ent_coef": float(cfg.get("ent_coef", 0.01)),
"vf_coef": float(cfg.get("vf_coef", 0.5)),
"max_grad_norm": float(cfg.get("max_grad_norm", 0.5)),
"activation": str(cfg.get("activation", "relu")),
"total_timesteps": int(cfg.get("total_timesteps", 50_000)),
"eval_episodes": int(cfg.get("eval_episodes", 5)),
"model_dir": str(cfg.get("model_dir", "engine/models")),
"n_products": int(cfg.get("n_products", 10)),
"N": int(cfg.get("N", 100)),
"alpha": float(cfg.get("alpha", 0.3)),
"lambda_coi": float(cfg.get("lambda_coi", 0.2)),
"robust_radius": float(cfg.get("robust_radius", 0.15)),
"robust_points": int(cfg.get("robust_points", 5)),
"info_value": float(cfg.get("info_value", 1.0)),
"price_low": float(cfg.get("price_low", 10.0)),
"price_high": float(cfg.get("price_high", 150.0)),
"action_levels": int(cfg.get("action_levels", 9)),
"action_scale_low": float(cfg.get("action_scale_low", 0.8)),
"action_scale_high": float(cfg.get("action_scale_high", 1.2)),
"max_episode_steps": int(cfg.get("max_steps", 100)),
"max_session_steps": int(cfg.get("max_session_steps", 40)),
"margin_floor": float(cfg.get("margin_floor", 0.05)),
"margin_floor_patience": int(cfg.get("margin_floor_patience", 5)),
"prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)),
"num_envs": int(cfg.get("jax_num_envs", 16)),
"num_steps": int(cfg.get("jax_num_steps", 128)),
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
}
rollout = out["num_envs"] * out["num_steps"]
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1))
return out
def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array:
mask = done
while mask.ndim < keep.ndim:
mask = mask[..., None]
return jnp.where(mask, reset, keep)
def make_train(config: dict[str, Any]):
cfg = _jax_cfg(config)
env_params = make_env_params(
n_products=cfg["n_products"],
alpha=cfg["alpha"],
n_sessions=cfg["N"],
lambda_coi=cfg["lambda_coi"],
robust_radius=cfg["robust_radius"],
robust_points=cfg["robust_points"],
info_value=cfg["info_value"],
action_levels=cfg["action_levels"],
action_scale_low=cfg["action_scale_low"],
action_scale_high=cfg["action_scale_high"],
price_low=cfg["price_low"],
price_high=cfg["price_high"],
max_episode_steps=cfg["max_episode_steps"],
max_session_steps=cfg["max_session_steps"],
margin_floor=cfg["margin_floor"],
margin_floor_patience=cfg["margin_floor_patience"],
prefer_behavior_data=cfg["prefer_behavior_data"],
)
env = PHANTOMJAXEnv(env_params)
network = ActorCritic(env.action_space_n(), activation=cfg["activation"])
def linear_schedule(count: jax.Array) -> jax.Array:
updates_done = count // (cfg["num_minibatches"] * cfg["update_epochs"])
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
return cfg["learning_rate"] * frac
def train(rng: jax.Array):
rng, init_key = jax.random.split(rng)
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
params = network.init(init_key, init_obs)
if cfg["anneal_lr"]:
tx = optax.chain(
optax.clip_by_global_norm(cfg["max_grad_norm"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(cfg["max_grad_norm"]),
optax.adam(cfg["learning_rate"], eps=1e-5),
)
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
rng, reset_key = jax.random.split(rng)
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
obs, env_state = jax.vmap(env.reset)(reset_keys)
def _update_step(runner_state, _):
def _env_step(runner_state, _):
train_state, env_state, last_obs, rng = runner_state
rng, action_key = jax.random.split(rng)
policy, value = network.apply(train_state.params, last_obs)
action = policy.sample(seed=action_key)
log_prob = policy.log_prob(action)
rng, step_key = jax.random.split(rng)
step_keys = jax.random.split(step_key, cfg["num_envs"])
nxt_obs, nxt_state, reward, done, info = jax.vmap(
env.step,
in_axes=(0, 0, 0),
)(step_keys, env_state, action)
rng, reset_key = jax.random.split(rng)
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
env_next = jax.tree_util.tree_map(
lambda keep, reset: _select_env_state(done, keep, reset),
nxt_state,
rst_state,
)
transition = Transition(
done=done,
action=action,
value=value,
reward=reward,
log_prob=log_prob,
obs=last_obs,
info=info,
)
return (train_state, env_next, obs_next, rng), transition
runner_state, traj_batch = jax.lax.scan(
_env_step,
runner_state,
None,
length=cfg["num_steps"],
)
train_state, env_state, last_obs, rng = runner_state
_, last_value = network.apply(train_state.params, last_obs)
def _compute_gae(traj_batch, last_value):
def _gae_step(carry, transition):
gae, next_value = carry
delta = (
transition.reward
+ cfg["gamma"] * next_value * (1.0 - transition.done)
- transition.value
)
gae = (
delta
+ cfg["gamma"]
* cfg["gae_lambda"]
* (1.0 - transition.done)
* gae
)
return (gae, transition.value), gae
_, advantages = jax.lax.scan(
_gae_step,
(jnp.zeros_like(last_value), last_value),
traj_batch,
reverse=True,
unroll=16,
)
targets = advantages + traj_batch.value
return advantages, targets
advantages, targets = _compute_gae(traj_batch, last_value)
def _update_epoch(update_state, _):
def _update_minibatch(train_state, batch_info):
traj_b, adv_b, tgt_b = batch_info
def _loss_fn(params, traj_b, adv_b, tgt_b):
policy, value = network.apply(params, traj_b.obs)
log_prob = policy.log_prob(traj_b.action)
value_clipped = traj_b.value + (value - traj_b.value).clip(
-cfg["clip_range"], cfg["clip_range"]
)
value_loss = (
0.5
* jnp.maximum(
jnp.square(value - tgt_b),
jnp.square(value_clipped - tgt_b),
).mean()
)
adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
ratio = jnp.exp(log_prob - traj_b.log_prob)
loss_actor = -jnp.minimum(
ratio * adv_norm,
jnp.clip(
ratio,
1.0 - cfg["clip_range"],
1.0 + cfg["clip_range"],
)
* adv_norm,
).mean()
entropy = policy.entropy().mean()
total_loss = (
loss_actor
+ cfg["vf_coef"] * value_loss
- cfg["ent_coef"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
train_state = train_state.apply_gradients(grads=grads)
return train_state, jnp.asarray(0.0, dtype=jnp.float32)
train_state, traj_batch, advantages, targets, rng = update_state
rng, perm_key = jax.random.split(rng)
batch_size = cfg["num_envs"] * cfg["num_steps"]
permutation = jax.random.permutation(perm_key, batch_size)
batch = (traj_batch, advantages, targets)
batch = jax.tree_util.tree_map(
lambda x: x.reshape((batch_size,) + x.shape[2:]),
batch,
)
shuffled = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=0),
batch,
)
minibatches = jax.tree_util.tree_map(
lambda x: x.reshape(
(cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
),
shuffled,
)
train_state, _ = jax.lax.scan(
_update_minibatch, train_state, minibatches
)
return (train_state, traj_batch, advantages, targets, rng), None
update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, _ = jax.lax.scan(
_update_epoch,
update_state,
None,
length=cfg["update_epochs"],
)
train_state = update_state[0]
rng = update_state[-1]
metric = {
"reward": jnp.mean(traj_batch.reward),
"revenue": jnp.mean(traj_batch.info["revenue"]),
"agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
}
runner_state = (train_state, env_state, last_obs, rng)
return runner_state, metric
runner_state = (train_state, env_state, obs, rng)
runner_state, metric = jax.lax.scan(
_update_step,
runner_state,
None,
length=cfg["num_updates"],
)
return {
"runner_state": runner_state,
"metrics": metric,
}
return train, network, env, cfg
def evaluate_policy(
*,
network: ActorCritic,
params: Any,
env: PHANTOMJAXEnv,
episodes: int,
seed: int,
) -> dict[str, float]:
rewards: list[float] = []
revenues: list[float] = []
key = jax.random.PRNGKey(seed)
for _ in range(int(episodes)):
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)
ep_reward = 0.0
ep_revenue = 0.0
done = False
steps = 0
while not done and steps < int(env.params.max_episode_steps):
policy, _ = network.apply(params, obs)
action = jnp.argmax(policy.logits)
key, step_key = jax.random.split(key)
obs, state, reward, done_flag, info = env.step(step_key, state, action)
ep_reward += float(np.asarray(reward))
ep_revenue += float(np.asarray(info["revenue"]))
done = bool(np.asarray(done_flag))
steps += 1
rewards.append(ep_reward)
revenues.append(ep_revenue)
return {
"eval/reward": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)),
}
def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
if not HAS_JAX_STACK:
raise ImportError(
"JAX PPO path requires jax, flax, optax, and distrax. "
"Install engine/jax/requirements.txt on this machine first."
)
run_cfg = _jax_cfg(cfg)
if run_cfg["algo"] != "ppo":
raise ValueError(
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
)
train_fn, network, env, run_cfg = make_train(run_cfg)
train_jit = jax.jit(train_fn)
rng = jax.random.PRNGKey(run_cfg["seed"])
out = train_jit(rng)
train_state = out["runner_state"][0]
metric = out["metrics"]
metrics = {
"train/reward": float(np.mean(np.asarray(metric["reward"]))),
"train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
"train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
"train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
"train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
"train/global_step": int(
run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
),
}
eval_metrics = evaluate_policy(
network=network,
params=train_state.params,
env=env,
episodes=run_cfg["eval_episodes"],
seed=run_cfg["seed"] + 7,
)
metrics.update(eval_metrics)
model_dir = Path(run_cfg["model_dir"])
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / "phantom_ppo_jax.msgpack"
model_path.write_bytes(serialization.to_bytes(train_state.params))
metrics["model/path"] = str(model_path)
return {"params": train_state.params}, metrics

119
engine/lib/callbacks.py Normal file
View File

@@ -0,0 +1,119 @@
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
import numpy as np
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
class MetricsCallback(BaseCallback):
"""Training metrics logger - reads info['economics'], logs to W&B."""
def __init__(
self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0
):
super().__init__(verbose)
self.log_histograms = log_histograms
self.log_freq = log_freq
self._episode_revenues: list[float] = []
def _on_step(self) -> bool:
if not HAS_WANDB or wandb.run is None:
return True
for info in self.locals.get("infos", []):
if "economics" not in info:
continue
econ = info["economics"]
t = self.num_timesteps
payload = {
"economics/revenue": econ["revenue"],
"economics/margin": econ["margin"],
"coi/level": econ["coi_level"],
"economics/regret": econ["regret"],
}
if "coi_mix" in econ:
payload["coi/mix"] = econ["coi_mix"]
if "coi_base" in econ:
payload["coi/base"] = econ["coi_base"]
if "coi_leakage" in econ:
payload["coi/leakage"] = econ["coi_leakage"]
if "coi_penalty" in econ:
payload["coi/penalty"] = econ["coi_penalty"]
wandb.log(payload, step=t)
self._episode_revenues.append(econ["revenue"])
# histograms at log_freq intervals
if self.log_histograms and self.num_timesteps % self.log_freq == 0:
for info in self.locals.get("infos", []):
if "prices" in info:
wandb.log(
{"distributions/prices": wandb.Histogram(info["prices"])},
step=self.num_timesteps,
)
if "demand" in info:
wandb.log(
{"distributions/demand": wandb.Histogram(info["demand"])},
step=self.num_timesteps,
)
return True
def _on_rollout_end(self) -> None:
if not HAS_WANDB or wandb.run is None or not self._episode_revenues:
return
wandb.log(
{
"episode/mean_revenue": np.mean(self._episode_revenues),
"episode/total_revenue": np.sum(self._episode_revenues),
},
step=self.num_timesteps,
)
self._episode_revenues = []
class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation - true performance without exploration noise."""
def __init__(
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
):
super().__init__(
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
)
self._eval_revenues: list[float] = []
def _on_step(self) -> bool:
result = super()._on_step()
if not HAS_WANDB or wandb.run is None:
return result
# log eval metrics after evaluation runs
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
wandb.log(
{
"eval/mean_reward": self.last_mean_reward,
"eval/mean_revenue": np.mean(self._eval_revenues)
if self._eval_revenues
else 0,
},
step=self.num_timesteps,
)
self._eval_revenues = []
return result
def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
# called after each eval episode
info = locals_.get("info", {})
if "economics" in info:
self._eval_revenues.append(info["economics"]["revenue"])

76
engine/lib/coi.py Normal file
View File

@@ -0,0 +1,76 @@
import numpy as np
from typing import Dict
def compute_agent_probability(
trajectory: list, human_transitions: Dict, agent_transitions: Dict
) -> float:
"""estimate agent probability via KL divergence between trajectory transitions and reference models
compares empirical trajectory transition distribution to human/agent prototypes
args:
trajectory: list of state/event strings from session
human_transitions: reference transition dict from human MDP (event->event->prob)
agent_transitions: reference transition dict from agent MDP (event->event->prob)
returns:
agent probability in [0, 1] via softmax over KL divergences
"""
if len(trajectory) < 2:
return 0.0 # insufficient data, assume human
# build empirical transition distribution from trajectory
trans_counts = {}
for s, s_next in zip(trajectory[:-1], trajectory[1:]):
if s not in trans_counts:
trans_counts[s] = {}
trans_counts[s][s_next] = trans_counts[s].get(s_next, 0) + 1
# normalize to probabilities
empirical = {}
for s, nxt in trans_counts.items():
total = sum(nxt.values())
empirical[s] = {s_n: cnt / total for s_n, cnt in nxt.items()}
# compute KL divergence to each prototype
def kl_div(p_dist: Dict, q_dist: Dict) -> float:
eps = 1e-10
# aggregate over all source states in empirical dist
kl = 0.0
for s in p_dist:
if s not in q_dist:
continue # skip states not in reference
p_trans, q_trans = p_dist[s], q_dist[s]
for k in p_trans:
p_val = p_trans[k] + eps
q_val = q_trans.get(k, 0.0) + eps
kl += p_val * np.log(p_val / q_val)
return kl
kl_human = kl_div(empirical, human_transitions)
kl_agent = kl_div(empirical, agent_transitions)
# convert to probability via softmax (lower KL = higher prob)
# agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent))
exp_h = np.exp(-kl_human)
exp_a = np.exp(-kl_agent)
return float(exp_a / (exp_h + exp_a + 1e-10))
def extract_purchases(trajectories: list) -> Dict[int, int]:
purchases: Dict[int, int] = {}
for traj in trajectories:
if traj and "checkout" in traj[-1] and "_product" in traj[-1]:
prod_id = int(traj[-1].rsplit("_product", 1)[1])
purchases[prod_id] = purchases.get(prod_id, 0) + 1
return purchases
def compute_uplift_coi(
prices: np.ndarray, purchases: Dict[int, int], baseline_prices: np.ndarray
) -> float:
# TODO: consider view-weighted fractional purchase for denser signal
return float(
sum(max(0.0, prices[k] - baseline_prices[k]) * n for k, n in purchases.items())
)

70
engine/lib/discrete.py Normal file
View File

@@ -0,0 +1,70 @@
from collections import defaultdict
import gymnasium as gym
from gymnasium import spaces
import numpy as np
class DiscretePriceActionWrapper(gym.ActionWrapper):
def __init__(
self,
env: gym.Env,
n_levels: int = 9,
min_scale: float = 0.8,
max_scale: float = 1.2,
):
super().__init__(env)
self.scales = np.linspace(min_scale, max_scale, n_levels, dtype=np.float32)
self.action_space = spaces.Discrete(n_levels)
def action(self, action: int):
scale = float(self.scales[int(action)])
cur = np.asarray(self.env.unwrapped._prices, dtype=np.float32)
lo, hi = self.env.unwrapped.price_bounds
return np.clip(cur * scale, lo, hi).astype(np.float32)
class EventQTable:
def __init__(
self,
n_actions: int,
n_products: int,
price_bounds: tuple,
lr: float = 0.1,
gamma: float = 0.99,
n_bins: int = 6,
):
self.n_actions = int(n_actions)
self.n_products = int(n_products)
self.lr = float(lr)
self.gamma = float(gamma)
self.q = defaultdict(lambda: np.zeros(self.n_actions, dtype=np.float32))
lo, hi = price_bounds
self.demand_bins = np.linspace(0.0, 100.0, n_bins + 1)[1:-1]
self.price_bins = np.linspace(lo, hi, n_bins + 1)[1:-1]
def encode(self, obs: np.ndarray) -> tuple:
obs = np.asarray(obs, dtype=np.float32)
d = obs[: self.n_products]
p = obs[self.n_products : 2 * self.n_products]
d_mean = float(np.mean(d)) if d.size else 0.0
d_std = float(np.std(d)) if d.size else 0.0
p_mean = float(np.mean(p)) if p.size else 0.0
return (
int(np.digitize(d_mean, self.demand_bins)),
int(np.digitize(d_std, self.demand_bins)),
int(np.digitize(p_mean, self.price_bins)),
)
def act(self, obs: np.ndarray, eps: float = 0.0) -> tuple[int, tuple]:
s = self.encode(obs)
if np.random.random() < eps:
return int(np.random.randint(self.n_actions)), s
return int(np.argmax(self.q[s])), s
def update(self, s: tuple, a: int, r: float, s2: tuple, done: bool):
target = r + (0.0 if done else self.gamma * float(np.max(self.q[s2])))
self.q[s][a] += self.lr * (target - self.q[s][a])
def predict(self, obs: np.ndarray, deterministic: bool = True):
a, _ = self.act(obs, 0.0 if deterministic else 0.05)
return a, None

182
engine/lib/providers.py Normal file
View File

@@ -0,0 +1,182 @@
"""Provider benchmarking - compare pricing strategies across contamination levels."""
from dataclasses import dataclass, field
from typing import Callable, Any
import numpy as np
import pandas as pd
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
class RandomBaseline:
"""uniform random action selection as a lower-bound baseline"""
def __init__(self, n_actions: int):
self.n = n_actions
def __call__(self, obs):
return int(np.random.randint(self.n))
def predict(self, obs, **kw):
return self(obs), None
class SurgeBaseline:
"""heuristic surge pricing: boost price when demand is above threshold, discount when below.
matches the naive pricing rule from thesis Section 3.3.2"""
def __init__(
self, n_actions: int, high_threshold: float = 60.0, low_threshold: float = 30.0
):
self.n = n_actions
self.mid = n_actions // 2 # identity action (scale ~1.0)
self.high_t = high_threshold
self.low_t = low_threshold
def __call__(self, obs):
obs = np.asarray(obs, dtype=np.float32)
n_prod = len(obs) // 2
demand_mean = float(np.mean(obs[:n_prod])) if n_prod > 0 else 0.0
if demand_mean >= self.high_t:
return min(self.mid + 2, self.n - 1) # surge: two levels above identity
if demand_mean <= self.low_t:
return max(self.mid - 2, 0) # discount: two levels below identity
return self.mid # hold
def predict(self, obs, **kw):
return self(obs), None
@dataclass
class ProviderResult:
"""Single benchmark result for one provider at one alpha level."""
name: str
alpha: float
total_revenue: float
mean_revenue: float
coi_level: float
coi_preserved_pct: float # vs alpha=0 baseline
margin_integrity: float
regret: float
episodes: int
@dataclass
class BenchmarkConfig:
"""Configuration for provider benchmark runs."""
n_episodes: int = 100
alpha_range: list[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5])
baseline_name: str = "fixed"
class ProviderBenchmark:
"""Compare pricing providers to prove margin preservation across contamination levels.
Usage:
def env_factory(alpha):
return EconomicMetricsWrapper(PHANTOM(alpha=alpha))
providers = {
"fixed": lambda obs: np.ones(10) * 50,
"learned": model.predict,
}
benchmark = ProviderBenchmark(env_factory, providers)
results = benchmark.run()
print(benchmark.summary_table())
"""
def __init__(
self,
env_factory: Callable[[float], Any],
providers: dict[str, Callable],
config: BenchmarkConfig | None = None,
):
self.env_factory = env_factory # fn(alpha) -> wrapped env
self.providers = providers # {name: fn(obs) -> action}
self.config = config or BenchmarkConfig()
self.results: list[ProviderResult] = []
def run(self) -> list[ProviderResult]:
"""Run benchmark across all providers and alpha levels."""
baseline_coi: dict[str, float] = {} # {provider: coi at alpha=0}
for alpha in self.config.alpha_range:
env = self.env_factory(alpha)
for name, policy_fn in self.providers.items():
revenues, coi_levels, margins = [], [], []
for _ in range(self.config.n_episodes):
obs, _ = env.reset()
episode_revenue = 0.0
done = False
while not done:
action = policy_fn(obs)
# handle sb3 model.predict returning tuple
if isinstance(action, tuple):
action = action[0]
obs, reward, term, trunc, info = env.step(action)
done = term or trunc
econ = info.get("economics", {})
episode_revenue += econ.get("revenue", 0)
coi_levels.append(econ.get("coi_level", 0))
margins.append(econ.get("margin", 0))
revenues.append(episode_revenue)
mean_coi = np.mean(coi_levels) if coi_levels else 0.0
if alpha == 0.0:
baseline_coi[name] = mean_coi
base = baseline_coi.get(name, mean_coi)
coi_preserved = mean_coi / base if base > 0 else 1.0
result = ProviderResult(
name=name,
alpha=alpha,
total_revenue=float(np.sum(revenues)),
mean_revenue=float(np.mean(revenues)),
coi_level=mean_coi,
coi_preserved_pct=coi_preserved * 100,
margin_integrity=float(np.mean(margins)) if margins else 0.0,
regret=0.0, # compute vs optimal if known
episodes=self.config.n_episodes,
)
self.results.append(result)
# log to wandb if available
if HAS_WANDB and wandb.run is not None:
wandb.log(
{
f"benchmark/{name}/revenue": result.mean_revenue,
f"benchmark/{name}/coi_preserved": result.coi_preserved_pct,
f"benchmark/{name}/margin": result.margin_integrity,
"benchmark/alpha": alpha,
}
)
return self.results
def to_dataframe(self) -> pd.DataFrame:
"""Convert results to pandas DataFrame."""
return pd.DataFrame([r.__dict__ for r in self.results])
def summary_table(self) -> pd.DataFrame:
"""Pivot table: providers x alpha with revenue/COI metrics."""
df = self.to_dataframe()
return df.pivot_table(
index="name",
columns="alpha",
values=["mean_revenue", "coi_preserved_pct", "margin_integrity"],
aggfunc="mean",
)

77
engine/lib/wrappers.py Normal file
View File

@@ -0,0 +1,77 @@
"""Economic metrics wrapper - calculates thesis-aligned KPIs and injects into info dict."""
import gymnasium as gym
import numpy as np
class EconomicMetricsWrapper(gym.Wrapper):
"""Calculates thesis-aligned economic metrics per step, injects into info.
Metrics follow thesis definitions:
- COI level: E[P] - p_min (Definition 1)
- Margin: (avg_price - p_min) / avg_price
- Regret: 1 - (revenue / baseline_revenue)
"""
def __init__(
self, env: gym.Env, p_min: float = 10.0, baseline_revenue: float | None = None
):
super().__init__(env)
self.p_min = p_min
self.baseline_revenue = baseline_revenue
self._price_history: list[np.ndarray] = []
self._revenue_history: list[float] = []
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
self._price_history = []
self._revenue_history = []
return obs, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
# extract from unwrapped env
prices = self.env.unwrapped._prices
demand_dict = self.env.unwrapped._demand
demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))])
alpha = self.env.unwrapped.alpha
# core calculations
revenue = float(np.sum(prices * demand))
avg_price = float(np.mean(prices))
margin = (avg_price - self.p_min) / max(avg_price, 1e-6)
coi_level = avg_price - self.p_min # E[P] - p_min per thesis Def 1
self._price_history.append(prices.copy())
self._revenue_history.append(revenue)
# regret vs baseline (golden path)
regret = 0.0
if self.baseline_revenue and self.baseline_revenue > 0:
regret = 1.0 - (revenue / self.baseline_revenue)
# inject structured metrics into info
info["economics"] = {
"revenue": revenue,
"margin": margin,
"coi_level": coi_level,
"regret": regret,
}
for key in ("coi_mix", "coi_base", "coi_leakage", "coi_penalty"):
if key in info:
info["economics"][key] = info[key]
info["prices"] = prices.copy()
info["demand"] = demand.copy()
return obs, reward, terminated, truncated, info
@property
def episode_revenue(self) -> float:
return sum(self._revenue_history)
@property
def episode_mean_price(self) -> float:
if not self._price_history:
return 0.0
return float(np.mean([np.mean(p) for p in self._price_history]))

View File

@@ -0,0 +1,84 @@
method: random
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
algo:
values: [ppo, a2c, dqn, qtable]
total_timesteps:
values: [30000, 50000, 80000]
seed:
values: [13, 42, 77]
n_products:
values: [8, 10, 12]
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
robust_points:
values: [3, 5, 7]
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
buffer_size:
values: [20000, 50000, 100000]
batch_size:
values: [128, 256, 512]
tau:
values: [0.002, 0.005, 0.01]
train_freq:
values: [1, 4, 8]
learning_starts:
values: [500, 1000, 3000]
n_steps:
values: [512, 1024, 2048]
n_epochs:
values: [5, 10, 20]
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]
target_update_interval:
values: [500, 1000, 2000]
exploration_fraction:
values: [0.1, 0.2, 0.3]
exploration_final_eps:
values: [0.01, 0.03, 0.05]
action_levels:
values: [7, 9, 11]
action_scale_low:
values: [0.75, 0.8, 0.85]
action_scale_high:
values: [1.15, 1.2, 1.25]
q_lr:
values: [0.03, 0.05, 0.1, 0.2]
eps_start:
value: 1.0
eps_end:
values: [0.02, 0.05, 0.1]
eps_decay:
values: [0.999, 0.9995, 0.9999]

View File

@@ -0,0 +1,85 @@
method: grid
metric:
name: sweep/score
goal: maximize
run_cap: 4
command:
- ${env}
- python
- -m
- engine.train
parameters:
algo:
values: [ppo, a2c, dqn, qtable]
seed:
value: 42
total_timesteps:
value: 12000
eval_episodes:
value: 3
eval_freq:
value: 500
log_freq:
value: 100
revenue_weight:
value: 0.01
n_products:
value: 8
N:
value: 80
alpha:
value: 0.3
lambda_coi:
value: 0.2
robust_radius:
value: 0.0
robust_points:
value: 1
info_value:
value: 1.0
learning_rate:
value: 0.0003
gamma:
value: 0.99
buffer_size:
value: 20000
batch_size:
value: 128
tau:
value: 0.005
train_freq:
value: 1
learning_starts:
value: 500
n_steps:
value: 512
n_epochs:
value: 10
gae_lambda:
value: 0.95
clip_range:
value: 0.2
ent_coef:
value: 0.0
target_update_interval:
value: 500
exploration_fraction:
value: 0.2
exploration_final_eps:
value: 0.05
action_levels:
value: 7
action_scale_low:
value: 0.9
action_scale_high:
value: 1.1
q_lr:
value: 0.1
q_bins:
value: 6
eps_start:
value: 1.0
eps_end:
value: 0.05
eps_decay:
value: 0.9995

View File

@@ -0,0 +1,54 @@
method: bayes
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
algo:
value: sac
total_timesteps:
values: [50000, 80000, 120000]
seed:
values: [13, 42, 77]
alpha:
distribution: uniform
min: 0.15
max: 0.55
n_products:
values: [8, 10, 12]
lambda_coi:
distribution: uniform
min: 0.05
max: 0.5
robust_radius:
distribution: uniform
min: 0.05
max: 0.3
robust_points:
values: [3, 5, 7]
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
learning_rate:
distribution: log_uniform_values
min: 3.0e-5
max: 1.0e-3
gamma:
values: [0.98, 0.99, 0.995]
buffer_size:
values: [50000, 100000, 200000]
batch_size:
values: [128, 256, 512]
tau:
values: [0.002, 0.005, 0.01]
train_freq:
values: [1, 4, 8]
learning_starts:
values: [1000, 3000, 5000]

View File

@@ -0,0 +1,86 @@
method: random
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
algo:
values: [ppo, a2c, dqn, qtable]
arch:
values: [tiny, small, medium]
activation:
values: [relu, tanh]
total_timesteps:
values: [8000, 12000, 20000]
seed:
values: [13, 42, 77]
n_products:
values: [6, 8, 10]
alpha:
distribution: uniform
min: 0.1
max: 0.5
lambda_coi:
distribution: uniform
min: 0.05
max: 0.4
robust_radius:
values: [0.0, 0.1, 0.2]
robust_points:
values: [3, 5]
info_value:
values: [0.75, 1.0, 1.5]
revenue_weight:
values: [0.005, 0.01, 0.02]
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 5.0e-4
gamma:
values: [0.98, 0.99]
buffer_size:
values: [10000, 30000, 50000]
batch_size:
values: [64, 128, 256]
tau:
values: [0.002, 0.005, 0.01]
train_freq:
values: [1, 4]
learning_starts:
values: [500, 1000, 2000]
n_steps:
values: [256, 512, 1024]
n_epochs:
values: [5, 10]
gae_lambda:
values: [0.9, 0.95]
clip_range:
values: [0.1, 0.2]
ent_coef:
values: [0.0, 0.005]
target_update_interval:
values: [500, 1000]
exploration_fraction:
values: [0.1, 0.2]
exploration_final_eps:
values: [0.02, 0.05]
action_levels:
values: [5, 7, 9]
action_scale_low:
values: [0.85, 0.9]
action_scale_high:
values: [1.1, 1.15]
q_lr:
values: [0.05, 0.1, 0.2]
q_bins:
values: [4, 6, 8]
eps_start:
value: 1.0
eps_end:
values: [0.02, 0.05]
eps_decay:
values: [0.999, 0.9995]

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import argparse import argparse
import json import json
import os
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from gymnasium.wrappers import FlattenObservation
try: try:
import wandb import wandb
@@ -20,9 +22,7 @@ try:
except ImportError: except ImportError:
HAS_SB3 = False HAS_SB3 = False
from .wrapper import PHANTOM from .jax import JAX_AVAILABLE
from .lib import EconomicMetricsWrapper, MetricsCallback
from .lib.discrete import EventQTable
DEFAULT_CFG = { DEFAULT_CFG = {
@@ -69,14 +69,34 @@ DEFAULT_CFG = {
"arch": "small", "arch": "small",
"activation": "relu", "activation": "relu",
"q_bins": 6, "q_bins": 6,
"max_steps": 100,
"margin_floor": 0.05,
"margin_floor_patience": 5,
"use_jax": False,
"jax_num_envs": 16,
"jax_num_steps": 128,
"jax_num_minibatches": 4,
"jax_update_epochs": 4,
"jax_anneal_lr": True,
} }
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _cfg(raw: dict | None = None) -> dict: def _cfg(raw: dict | None = None) -> dict:
cfg = dict(DEFAULT_CFG) cfg = dict(DEFAULT_CFG)
if raw: if raw:
cfg.update({k: v for k, v in raw.items() if v is not None}) cfg.update({k: v for k, v in raw.items() if v is not None})
cfg["algo"] = str(cfg["algo"]).lower() cfg["algo"] = str(cfg["algo"]).lower()
cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy(
os.environ.get("PHANTOM_USE_JAX")
)
return cfg return cfg
@@ -89,6 +109,11 @@ def _wandb_cfg_dict() -> dict:
def make_env(cfg: dict): def make_env(cfg: dict):
from gymnasium.wrappers import FlattenObservation
from .wrapper import PHANTOM
from .lib.wrappers import EconomicMetricsWrapper
env = PHANTOM( env = PHANTOM(
n_products=int(cfg["n_products"]), n_products=int(cfg["n_products"]),
alpha=float(cfg["alpha"]), alpha=float(cfg["alpha"]),
@@ -101,6 +126,9 @@ def make_env(cfg: dict):
action_levels=int(cfg["action_levels"]), action_levels=int(cfg["action_levels"]),
action_scale_low=float(cfg["action_scale_low"]), action_scale_low=float(cfg["action_scale_low"]),
action_scale_high=float(cfg["action_scale_high"]), action_scale_high=float(cfg["action_scale_high"]),
max_steps=int(cfg.get("max_steps", 100)),
margin_floor=float(cfg.get("margin_floor", 0.05)),
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
render_mode=None, render_mode=None,
) )
env = EconomicMetricsWrapper(env) env = EconomicMetricsWrapper(env)
@@ -235,6 +263,8 @@ def build_model(cfg: dict, env):
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]: def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
from .lib.discrete import EventQTable
np.random.seed(int(cfg["seed"])) np.random.seed(int(cfg["seed"]))
env = make_env(cfg) env = make_env(cfg)
eval_env = make_env(cfg) eval_env = make_env(cfg)
@@ -275,6 +305,8 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
def train_sb3(cfg: dict) -> tuple[object, dict]: def train_sb3(cfg: dict) -> tuple[object, dict]:
if not HAS_SB3: if not HAS_SB3:
raise ImportError("stable-baselines3 is required for SB3 models") raise ImportError("stable-baselines3 is required for SB3 models")
from .lib.callbacks import MetricsCallback
env = make_env(cfg) env = make_env(cfg)
eval_env = make_env(cfg) eval_env = make_env(cfg)
env = Monitor(env) env = Monitor(env)
@@ -303,7 +335,20 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
def train_once(cfg: dict) -> dict: def train_once(cfg: dict) -> dict:
algo = cfg["algo"] algo = cfg["algo"]
if cfg.get("use_jax"):
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
if algo == "qtable": if algo == "qtable":
raise ValueError("qtable is not supported in JAX backend")
try:
from .jax.train import train_jax
except Exception as exc: # pragma: no cover
raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
_, metrics = train_jax(cfg)
elif algo == "qtable":
_, metrics = train_qtable(cfg) _, metrics = train_qtable(cfg)
else: else:
_, metrics = train_sb3(cfg) _, metrics = train_sb3(cfg)
@@ -357,8 +402,17 @@ def main():
p.add_argument("--learning-rate", type=float) p.add_argument("--learning-rate", type=float)
p.add_argument("--gamma", type=float) p.add_argument("--gamma", type=float)
p.add_argument("--revenue-weight", type=float) p.add_argument("--revenue-weight", type=float)
p.add_argument("--max-steps", type=int)
p.add_argument("--margin-floor", type=float)
p.add_argument("--margin-floor-patience", type=int)
p.add_argument("--arch", type=str) p.add_argument("--arch", type=str)
p.add_argument("--activation", type=str) p.add_argument("--activation", type=str)
p.add_argument("--jax", action="store_true")
p.add_argument("--jax-num-envs", type=int)
p.add_argument("--jax-num-steps", type=int)
p.add_argument("--jax-num-minibatches", type=int)
p.add_argument("--jax-update-epochs", type=int)
p.add_argument("--jax-anneal-lr", type=str)
p.add_argument("--sweep-agent", action="store_true") p.add_argument("--sweep-agent", action="store_true")
p.add_argument("--sweep-id", type=str) p.add_argument("--sweep-id", type=str)
p.add_argument("--count", type=int, default=0) p.add_argument("--count", type=int, default=0)
@@ -377,8 +431,19 @@ def main():
"learning_rate": args.learning_rate, "learning_rate": args.learning_rate,
"gamma": args.gamma, "gamma": args.gamma,
"revenue_weight": args.revenue_weight, "revenue_weight": args.revenue_weight,
"max_steps": args.max_steps,
"margin_floor": args.margin_floor,
"margin_floor_patience": args.margin_floor_patience,
"arch": args.arch, "arch": args.arch,
"activation": args.activation, "activation": args.activation,
"use_jax": args.jax,
"jax_num_envs": args.jax_num_envs,
"jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs,
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
if args.jax_anneal_lr is not None
else None,
} }
overrides = {k: v for k, v in overrides.items() if v is not None} overrides = {k: v for k, v in overrides.items() if v is not None}