adding naive jax and libraries and make adjustments

This commit is contained in:
2026-02-17 14:48:18 +01:00
parent 66c4a0cd1d
commit 802f31b4a1
17 changed files with 2331 additions and 6 deletions

13
engine/jax/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""JAX-compatible training and environment modules for PHANTOM."""
from __future__ import annotations
try:
import jax # noqa: F401
import jax.numpy as jnp # noqa: F401
JAX_AVAILABLE = True
except ImportError:
JAX_AVAILABLE = False
__all__ = ["JAX_AVAILABLE"]

49
engine/jax/checkpoint.py Normal file
View File

@@ -0,0 +1,49 @@
"""Orbax checkpoint helpers for JAX training runs."""
from __future__ import annotations
from pathlib import Path
from typing import Any
try:
import orbax.checkpoint as ocp
HAS_ORBAX = True
except ImportError:
HAS_ORBAX = False
def _require_orbax() -> None:
if not HAS_ORBAX:
raise ImportError(
"orbax-checkpoint is required for checkpoint support. "
"Install engine/jax/requirements.txt first."
)
def create_manager(directory: str | Path, max_to_keep: int = 5):
_require_orbax()
root = Path(directory)
root.mkdir(parents=True, exist_ok=True)
options = ocp.CheckpointManagerOptions(
max_to_keep=max(1, int(max_to_keep)), create=True
)
return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options)
def save(manager, *, step: int, payload: Any) -> bool:
_require_orbax()
return bool(manager.save(int(step), payload))
def latest_step(manager) -> int | None:
_require_orbax()
return manager.latest_step()
def restore(manager, *, target: Any, step: int | None = None) -> Any:
_require_orbax()
step_to_restore = manager.latest_step() if step is None else int(step)
if step_to_restore is None:
return target
return manager.restore(step_to_restore, items=target)

287
engine/jax/env.py Normal file
View File

@@ -0,0 +1,287 @@
"""JAX-native PHANTOM environment with robust contamination step."""
from __future__ import annotations
from typing import NamedTuple
try:
import jax
import jax.numpy as jnp
except ImportError as exc: # pragma: no cover
raise ImportError("engine.jax.env requires JAX") from exc
from .primitives import (
_sample_sessions_jax,
agent_probability_from_kl,
batch_kl,
compute_session_transitions,
load_transition_data,
purchase_flags,
reward_with_coi_penalty,
revenue_from_demand,
weighted_demand,
)
class EnvParams(NamedTuple):
n_products: int
n_sessions: int
max_episode_steps: int
max_session_steps: int
price_low: float
price_high: float
lambda_coi: float
info_value: float
robust_radius: float
margin_floor: float
margin_floor_patience: int
action_scales: jax.Array
alpha_nominal: float
alpha_candidates: jax.Array
human_T: jax.Array
agent_T: jax.Array
terminal_mask: jax.Array
purchase_mask: jax.Array
event_weights: jax.Array
start_idx: int
term_idx: int
class EnvState(NamedTuple):
prices: jax.Array
demand: jax.Array
step_count: jax.Array
low_margin_streak: jax.Array
last_agent_prob: jax.Array
last_alpha_adv: jax.Array
class CandidateEval(NamedTuple):
reward: jax.Array
revenue: jax.Array
demand: jax.Array
agent_prob: jax.Array
leakage: jax.Array
discount: jax.Array
n_purchases: jax.Array
n_agents: jax.Array
def make_env_params(
*,
n_products: int,
alpha: float,
n_sessions: int,
lambda_coi: float,
robust_radius: float,
robust_points: int,
info_value: float,
action_levels: int,
action_scale_low: float,
action_scale_high: float,
price_low: float,
price_high: float,
max_episode_steps: int,
max_session_steps: int = 40,
margin_floor: float = 0.05,
margin_floor_patience: int = 5,
prefer_behavior_data: bool = True,
) -> EnvParams:
transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()
if robust_radius <= 0.0 or robust_points <= 1:
alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32)
else:
lo = max(0.0, float(alpha) - float(robust_radius))
hi = min(1.0, float(alpha) + float(robust_radius))
alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32)
action_scales = jnp.linspace(
float(action_scale_low),
float(action_scale_high),
int(action_levels),
dtype=jnp.float32,
)
return EnvParams(
n_products=int(n_products),
n_sessions=int(n_sessions),
max_episode_steps=int(max_episode_steps),
max_session_steps=int(max_session_steps),
price_low=float(price_low),
price_high=float(price_high),
lambda_coi=float(lambda_coi),
info_value=float(info_value),
robust_radius=float(robust_radius),
margin_floor=float(margin_floor),
margin_floor_patience=int(margin_floor_patience),
action_scales=action_scales,
alpha_nominal=float(alpha),
alpha_candidates=alpha_candidates,
human_T=jnp.asarray(transition.human_T),
agent_T=jnp.asarray(transition.agent_T),
terminal_mask=jnp.asarray(transition.terminal_mask),
purchase_mask=jnp.asarray(transition.purchase_mask),
event_weights=jnp.asarray(transition.event_weights),
start_idx=int(transition.start_idx),
term_idx=int(transition.term_idx),
)
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)])
def _decode_action(
prices: jax.Array, action: jax.Array, params: EnvParams
) -> jax.Array:
idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1)
scale = params.action_scales[idx]
next_prices = prices * scale
return jnp.clip(next_prices, params.price_low, params.price_high)
def _evaluate_candidate(
key: jax.Array,
alpha_candidate: jax.Array,
prices: jax.Array,
params: EnvParams,
) -> CandidateEval:
states, products, actors, lengths = _sample_sessions_jax(
key,
params.human_T,
params.agent_T,
params.terminal_mask,
params.start_idx,
params.term_idx,
alpha_candidate,
params.n_products,
params.n_sessions,
params.max_session_steps,
int(params.human_T.shape[0]),
)
session_trans = compute_session_transitions(
states, lengths, int(params.human_T.shape[0])
)
delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
agent_probs = agent_probability_from_kl(delta_h, delta_a)
agent_prob = jnp.mean(agent_probs)
demand = weighted_demand(states, products, params.n_products, params.event_weights)
revenue = revenue_from_demand(prices, demand)
reward, leakage, discount = reward_with_coi_penalty(
revenue,
agent_prob,
params.lambda_coi,
params.info_value,
)
purchases = purchase_flags(states, params.purchase_mask)
return CandidateEval(
reward=reward,
revenue=revenue,
demand=demand,
agent_prob=agent_prob,
leakage=leakage,
discount=discount,
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
n_agents=jnp.sum(actors.astype(jnp.float32)),
)
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
prices = jax.random.uniform(
key,
shape=(params.n_products,),
minval=params.price_low,
maxval=params.price_high,
)
demand = jnp.zeros((params.n_products,), dtype=jnp.float32)
state = EnvState(
prices=prices,
demand=demand,
step_count=jnp.asarray(0, dtype=jnp.int32),
low_margin_streak=jnp.asarray(0, dtype=jnp.int32),
last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
)
return _flatten_obs(demand, prices), state
def step_env(
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams,
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
prices = _decode_action(state.prices, action, params)
n_candidates = params.alpha_candidates.shape[0]
cand_keys = jax.random.split(key, n_candidates)
evals = jax.vmap(
lambda k, a: _evaluate_candidate(k, a, prices, params),
in_axes=(0, 0),
)(cand_keys, params.alpha_candidates)
idx = jnp.argmin(evals.reward)
demand = evals.demand[idx]
reward = evals.reward[idx]
revenue = evals.revenue[idx]
agent_prob = evals.agent_prob[idx]
leakage = evals.leakage[idx]
discount = evals.discount[idx]
n_purchases = evals.n_purchases[idx]
n_agents = evals.n_agents[idx]
alpha_adv = params.alpha_candidates[idx]
step_count = state.step_count + 1
avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
avg_margin = (avg_price - params.price_low) / avg_price
next_streak = jnp.where(
avg_margin < params.margin_floor, state.low_margin_streak + 1, 0
)
margin_collapsed = next_streak >= params.margin_floor_patience
done = (step_count >= params.max_episode_steps) | margin_collapsed
next_state = EnvState(
prices=prices,
demand=demand,
step_count=step_count,
low_margin_streak=next_streak,
last_agent_prob=agent_prob,
last_alpha_adv=alpha_adv,
)
obs = _flatten_obs(demand, prices)
info = {
"revenue": revenue,
"agent_prob": agent_prob,
"alpha_adv": alpha_adv,
"coi_leakage": leakage,
"coi_discount": discount,
"n_purchases": n_purchases,
"n_agents": n_agents,
"avg_margin": avg_margin,
}
return obs, next_state, reward, done, info
class PHANTOMJAXEnv:
def __init__(self, params: EnvParams):
self.params = params
def reset(self, key: jax.Array, params: EnvParams | None = None):
return reset_env(key, self.params if params is None else params)
def step(
self,
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams | None = None,
):
return step_env(key, state, action, self.params if params is None else params)
def action_space_n(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.action_scales.shape[0])
def observation_dim(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.n_products * 2)

493
engine/jax/primitives.py Normal file
View File

@@ -0,0 +1,493 @@
"""JAX-compatible primitives for PHANTOM session simulation and separability."""
from __future__ import annotations
from dataclasses import dataclass
from functools import partial
from typing import Mapping, Sequence
import numpy as np
try:
import jax
import jax.numpy as jnp
JAX_AVAILABLE = True
except ImportError:
jax = None # type: ignore[assignment]
jnp = np # type: ignore[assignment]
JAX_AVAILABLE = False
STATE_START_KEYS = ("session_start", "start")
TERMINAL_EVENT_TOKENS = (
"session_end",
"end",
"purchase_complete",
"checkout_start",
"checkout",
)
PURCHASE_EVENT_TOKENS = (
"purchase_complete",
"purchase",
"checkout_start",
"checkout",
)
CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5}
ACTION_CATEGORIES = {
"cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"},
"dwell": {
"hover_title",
"hover_paragraph",
"hover_link",
"hover_over_title",
"hover_over_paragraph",
"hover_over_link",
"hover_over_button",
},
"nav": {
"page_view",
"view_item",
"view",
"learn_more",
"learn_more_about_item",
"view_item_page",
"session_start",
},
"filter": {
"search",
"filter_date",
"filter_price",
"sort",
"filter_for_date",
"filter_for_price",
"filter_for_amenities",
"sort_change",
},
}
DEFAULT_ACTION_WEIGHTS = {
action: CATEGORY_WEIGHTS[group]
for group, actions in ACTION_CATEGORIES.items()
for action in actions
}
@dataclass(frozen=True)
class TransitionData:
"""Dense transition kernels and per-state metadata."""
human_T: np.ndarray
agent_T: np.ndarray
terminal_mask: np.ndarray
purchase_mask: np.ndarray
event_weights: np.ndarray
event_names: tuple[str, ...]
start_idx: int
term_idx: int
def to_jax(self) -> "TransitionData":
if not JAX_AVAILABLE:
return self
return TransitionData(
human_T=jnp.asarray(self.human_T),
agent_T=jnp.asarray(self.agent_T),
terminal_mask=jnp.asarray(self.terminal_mask),
purchase_mask=jnp.asarray(self.purchase_mask),
event_weights=jnp.asarray(self.event_weights),
event_names=self.event_names,
start_idx=int(self.start_idx),
term_idx=int(self.term_idx),
)
@dataclass(frozen=True)
class SessionBatch:
states: np.ndarray
products: np.ndarray
actors: np.ndarray
lengths: np.ndarray
def _event_weight(name: str) -> float:
if name in DEFAULT_ACTION_WEIGHTS:
return float(DEFAULT_ACTION_WEIGHTS[name])
if name.startswith("hover"):
return float(CATEGORY_WEIGHTS["dwell"])
if name.startswith("filter") or name in {"search", "sort", "sort_change"}:
return float(CATEGORY_WEIGHTS["filter"])
if name.startswith("add") or name in {
"checkout",
"checkout_start",
"purchase",
"remove_item",
"purchase_complete",
}:
return float(CATEGORY_WEIGHTS["cart"])
if any(token in name for token in TERMINAL_EVENT_TOKENS):
return 0.0
return float(CATEGORY_WEIGHTS["nav"])
def _is_terminal(name: str) -> bool:
return any(token in name for token in TERMINAL_EVENT_TOKENS)
def _is_purchase(name: str) -> bool:
return any(token in name for token in PURCHASE_EVENT_TOKENS)
def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]:
names: set[str] = set()
for trans in transitions:
for src, dsts in trans.items():
names.add(src)
names.update(dsts.keys())
names.discard("__terminal__")
return tuple(sorted(names))
def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray:
row_sums = matrix.sum(axis=1, keepdims=True)
dead_rows = np.isclose(row_sums.squeeze(-1), 0.0)
if np.any(dead_rows):
matrix[dead_rows] = 0.0
matrix[dead_rows, term_idx] = 1.0
row_sums = matrix.sum(axis=1, keepdims=True)
return matrix / np.maximum(row_sums, 1e-8)
def _dense_from_dict(
transitions: Mapping[str, Mapping[str, float]],
event_to_idx: Mapping[str, int],
term_idx: int,
) -> np.ndarray:
n_states = len(event_to_idx)
matrix = np.zeros((n_states, n_states), dtype=np.float32)
for src, dsts in transitions.items():
i = event_to_idx.get(src)
if i is None:
continue
for dst, prob in dsts.items():
j = event_to_idx.get(dst)
if j is None:
continue
matrix[i, j] += float(prob)
return _normalize_rows(matrix, term_idx)
def compile_transition_data(
human_transitions: Mapping[str, Mapping[str, float]],
agent_transitions: Mapping[str, Mapping[str, float]],
) -> TransitionData:
event_names = _collect_events(human_transitions, agent_transitions)
if not event_names:
return fallback_transition_data()
event_names = tuple([*event_names, "__terminal__"])
term_idx = len(event_names) - 1
event_to_idx = {name: i for i, name in enumerate(event_names)}
human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx)
agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx)
terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool)
purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool)
event_weights = np.array(
[_event_weight(name) for name in event_names], dtype=np.float32
)
terminal_mask[term_idx] = True
for idx, is_term in enumerate(terminal_mask):
if not is_term:
continue
human_T[idx] = 0.0
agent_T[idx] = 0.0
human_T[idx, idx] = 1.0
agent_T[idx, idx] = 1.0
start_idx = 0
for key in STATE_START_KEYS:
if key in event_to_idx:
start_idx = int(event_to_idx[key])
break
return TransitionData(
human_T=human_T,
agent_T=agent_T,
terminal_mask=terminal_mask,
purchase_mask=purchase_mask,
event_weights=event_weights,
event_names=event_names,
start_idx=start_idx,
term_idx=term_idx,
)
def fallback_transition_data() -> TransitionData:
human = {
"session_start": {
"page_view": 0.80,
"view_item_page": 0.15,
"session_end": 0.05,
},
"page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20},
"view_item_page": {
"learn_more_about_item": 0.40,
"add_item_to_cart": 0.28,
"session_end": 0.32,
},
"learn_more_about_item": {
"add_item_to_cart": 0.50,
"view_item_page": 0.30,
"session_end": 0.20,
},
"add_item_to_cart": {
"checkout_start": 0.58,
"view_item_page": 0.24,
"session_end": 0.18,
},
"checkout_start": {"purchase_complete": 0.70, "session_end": 0.30},
"purchase_complete": {"session_end": 1.0},
}
agent = {
"session_start": {
"page_view": 0.90,
"view_item_page": 0.08,
"session_end": 0.02,
},
"page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25},
"view_item_page": {
"learn_more_about_item": 0.55,
"add_item_to_cart": 0.15,
"session_end": 0.30,
},
"learn_more_about_item": {
"view_item_page": 0.45,
"add_item_to_cart": 0.20,
"session_end": 0.35,
},
"add_item_to_cart": {
"checkout_start": 0.42,
"view_item_page": 0.28,
"session_end": 0.30,
},
"checkout_start": {"purchase_complete": 0.52, "session_end": 0.48},
"purchase_complete": {"session_end": 1.0},
}
return compile_transition_data(human, agent)
def load_transition_data(prefer_data: bool = True) -> TransitionData:
if not prefer_data:
return fallback_transition_data()
try:
from ..lib.behavior import get_transition_models
human_trans, agent_trans = get_transition_models()
return compile_transition_data(human_trans, agent_trans)
except Exception:
return fallback_transition_data()
if JAX_AVAILABLE:
@partial(jax.jit, static_argnums=(8, 9, 10))
def _sample_sessions_jax(
key: jax.Array,
human_T: jax.Array,
agent_T: jax.Array,
terminal_mask: jax.Array,
start_idx: int,
term_idx: int,
alpha: float,
n_products: int,
n_sessions: int,
max_steps: int,
n_states: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
k_actor, k_product, k_step = jax.random.split(key, 3)
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
actors = (actor_draw < alpha).astype(jnp.int32)
products = jax.random.randint(
k_product, (n_sessions,), 0, n_products, dtype=jnp.int32
)
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
def _scan_step(carry, _):
states, active, rng = carry
rng, k = jax.random.split(rng)
probs_h = human_T[states]
probs_a = agent_T[states]
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
next_state = jnp.where(active, next_state, int(term_idx))
emitted = jnp.where(active, next_state, -1)
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
next_active = active & (~is_terminal)
carry_states = jnp.where(next_active, next_state, int(term_idx))
return (carry_states, next_active, rng), emitted
_, state_t = jax.lax.scan(
_scan_step, (state_init, active_init, k_step), None, length=max_steps
)
states = state_t.T
lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32)
return states, products, actors, lengths
def sample_sessions(
key,
transition_data: TransitionData,
alpha: float,
n_products: int,
n_sessions: int,
max_steps: int,
) -> SessionBatch:
if JAX_AVAILABLE:
td = transition_data.to_jax()
states, products, actors, lengths = _sample_sessions_jax(
key,
td.human_T,
td.agent_T,
td.terminal_mask,
int(td.start_idx),
int(td.term_idx),
float(alpha),
int(n_products),
int(n_sessions),
int(max_steps),
int(td.human_T.shape[0]),
)
return SessionBatch(
states=states, products=products, actors=actors, lengths=lengths
)
rng = np.random.default_rng(int(np.asarray(key).reshape(-1)[0]))
n_states = transition_data.human_T.shape[0]
products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32)
actors = (rng.random(size=n_sessions) < alpha).astype(np.int32)
states = np.full((n_sessions, max_steps), -1, dtype=np.int32)
lengths = np.zeros((n_sessions,), dtype=np.int32)
for i in range(n_sessions):
current = int(transition_data.start_idx)
mat = transition_data.agent_T if actors[i] == 1 else transition_data.human_T
for t in range(max_steps):
nxt = int(rng.choice(n_states, p=mat[current]))
states[i, t] = nxt
if transition_data.terminal_mask[nxt]:
lengths[i] = t + 1
break
current = nxt
if lengths[i] == 0:
lengths[i] = max_steps
return SessionBatch(
states=states, products=products, actors=actors, lengths=lengths
)
if JAX_AVAILABLE:
@partial(jax.jit, static_argnums=(2,))
def compute_session_transitions(states, lengths, n_states: int):
src = states[:, :-1]
dst = states[:, 1:]
time_idx = jnp.arange(src.shape[1])[None, :]
valid = (src >= 0) & (dst >= 0) & (time_idx < (lengths[:, None] - 1))
src_clip = jnp.clip(src, 0, n_states - 1)
dst_clip = jnp.clip(dst, 0, n_states - 1)
src_oh = jax.nn.one_hot(src_clip, n_states)
dst_oh = jax.nn.one_hot(dst_clip, n_states)
counts = jnp.einsum(
"nti,ntj,nt->nij", src_oh, dst_oh, valid.astype(jnp.float32)
)
row_sums = jnp.sum(counts, axis=-1, keepdims=True)
return counts / (row_sums + 1e-10)
else:
def compute_session_transitions(states, lengths, n_states: int):
trans = np.zeros((states.shape[0], n_states, n_states), dtype=np.float32)
for i in range(states.shape[0]):
for t in range(max(int(lengths[i]) - 1, 0)):
s = int(states[i, t])
d = int(states[i, t + 1])
if s >= 0 and d >= 0:
trans[i, s, d] += 1.0
row_sums = trans.sum(axis=-1, keepdims=True)
return trans / (row_sums + 1e-10)
def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10):
p = P + eps
p = p / jnp.sum(p, axis=-1, keepdims=True)
qh = Q_human[None, ...] + eps
qa = Q_agent[None, ...] + eps
delta_h = jnp.sum(p * jnp.log(p / qh), axis=(1, 2))
delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2))
return delta_h, delta_a
if JAX_AVAILABLE:
batch_kl = jax.jit(batch_kl)
def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0):
t = jnp.maximum(float(temperature), 1e-6)
exp_h = jnp.exp(-delta_h / t)
exp_a = jnp.exp(-delta_a / t)
return exp_a / (exp_h + exp_a + 1e-10)
def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0):
logits = beta * (delta_h - delta_a)
return 1.0 / (1.0 + jnp.exp(-logits))
def weighted_demand(states, products, n_products: int, event_weights):
valid = states >= 0
state_clip = jnp.clip(states, 0, event_weights.shape[0] - 1)
weights = event_weights[state_clip] * valid
per_session = jnp.sum(weights, axis=1)
demand = jnp.zeros((n_products,), dtype=jnp.float32)
demand = demand.at[products].add(per_session)
total = jnp.sum(demand)
return jnp.where(total > 0.0, (demand / total) * 100.0, demand)
if JAX_AVAILABLE:
weighted_demand = jax.jit(weighted_demand, static_argnums=(2,))
def purchase_flags(states, purchase_mask):
state_clip = jnp.clip(states, 0, purchase_mask.shape[0] - 1)
hits = purchase_mask[state_clip] & (states >= 0)
return jnp.any(hits, axis=1)
if JAX_AVAILABLE:
purchase_flags = jax.jit(purchase_flags)
def revenue_from_demand(prices, demand):
return jnp.dot(prices, demand)
if JAX_AVAILABLE:
revenue_from_demand = jax.jit(revenue_from_demand)
def reward_with_coi_penalty(
revenue, agent_prob: float, lambda_coi: float, info_value: float
):
leakage = agent_prob * info_value
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
return revenue * discount, leakage, discount
if JAX_AVAILABLE:
reward_with_coi_penalty = jax.jit(reward_with_coi_penalty)

View File

@@ -0,0 +1,5 @@
flax>=0.8.0
optax>=0.2.0
distrax>=0.1.5
orbax-checkpoint>=0.5.0
chex>=0.1.8

471
engine/jax/train.py Normal file
View File

@@ -0,0 +1,471 @@
"""Pure JAX PPO trainer for the PHANTOM environment."""
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
try:
import jax
import jax.numpy as jnp
import distrax
import flax.linen as nn
import optax
from flax import serialization
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
HAS_JAX_STACK = True
except ImportError:
jax = None # type: ignore[assignment]
jnp = None # type: ignore[assignment]
distrax = None # type: ignore[assignment]
optax = None # type: ignore[assignment]
serialization = None # type: ignore[assignment]
class _ModuleStub:
pass
class _NNStub:
Module = _ModuleStub
@staticmethod
def compact(fn):
return fn
nn = _NNStub() # type: ignore[assignment]
def constant(*_args, **_kwargs): # type: ignore[override]
return None
def orthogonal(*_args, **_kwargs): # type: ignore[override]
return None
class TrainState: # type: ignore[override]
pass
HAS_JAX_STACK = False
from .env import PHANTOMJAXEnv, make_env_params
class ActorCritic(nn.Module):
action_dim: int
activation: str = "tanh"
@nn.compact
def __call__(self, x):
activation_fn = nn.relu if self.activation == "relu" else nn.tanh
actor = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(x)
actor = activation_fn(actor)
actor = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(actor)
actor = activation_fn(actor)
logits = nn.Dense(
self.action_dim,
kernel_init=orthogonal(0.01),
bias_init=constant(0.0),
)(actor)
critic = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(x)
critic = activation_fn(critic)
critic = nn.Dense(
64,
kernel_init=orthogonal(np.sqrt(2.0)),
bias_init=constant(0.0),
)(critic)
critic = activation_fn(critic)
value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
critic
)
return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1)
class Transition(NamedTuple):
done: jax.Array
action: jax.Array
value: jax.Array
reward: jax.Array
log_prob: jax.Array
obs: jax.Array
info: dict[str, jax.Array]
def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
out = {
"algo": str(cfg.get("algo", "ppo")).lower(),
"seed": int(cfg.get("seed", 42)),
"learning_rate": float(cfg.get("learning_rate", 3e-4)),
"gamma": float(cfg.get("gamma", 0.99)),
"gae_lambda": float(cfg.get("gae_lambda", 0.95)),
"clip_range": float(cfg.get("clip_range", 0.2)),
"ent_coef": float(cfg.get("ent_coef", 0.01)),
"vf_coef": float(cfg.get("vf_coef", 0.5)),
"max_grad_norm": float(cfg.get("max_grad_norm", 0.5)),
"activation": str(cfg.get("activation", "relu")),
"total_timesteps": int(cfg.get("total_timesteps", 50_000)),
"eval_episodes": int(cfg.get("eval_episodes", 5)),
"model_dir": str(cfg.get("model_dir", "engine/models")),
"n_products": int(cfg.get("n_products", 10)),
"N": int(cfg.get("N", 100)),
"alpha": float(cfg.get("alpha", 0.3)),
"lambda_coi": float(cfg.get("lambda_coi", 0.2)),
"robust_radius": float(cfg.get("robust_radius", 0.15)),
"robust_points": int(cfg.get("robust_points", 5)),
"info_value": float(cfg.get("info_value", 1.0)),
"price_low": float(cfg.get("price_low", 10.0)),
"price_high": float(cfg.get("price_high", 150.0)),
"action_levels": int(cfg.get("action_levels", 9)),
"action_scale_low": float(cfg.get("action_scale_low", 0.8)),
"action_scale_high": float(cfg.get("action_scale_high", 1.2)),
"max_episode_steps": int(cfg.get("max_steps", 100)),
"max_session_steps": int(cfg.get("max_session_steps", 40)),
"margin_floor": float(cfg.get("margin_floor", 0.05)),
"margin_floor_patience": int(cfg.get("margin_floor_patience", 5)),
"prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)),
"num_envs": int(cfg.get("jax_num_envs", 16)),
"num_steps": int(cfg.get("jax_num_steps", 128)),
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
}
rollout = out["num_envs"] * out["num_steps"]
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1))
return out
def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array:
mask = done
while mask.ndim < keep.ndim:
mask = mask[..., None]
return jnp.where(mask, reset, keep)
def make_train(config: dict[str, Any]):
cfg = _jax_cfg(config)
env_params = make_env_params(
n_products=cfg["n_products"],
alpha=cfg["alpha"],
n_sessions=cfg["N"],
lambda_coi=cfg["lambda_coi"],
robust_radius=cfg["robust_radius"],
robust_points=cfg["robust_points"],
info_value=cfg["info_value"],
action_levels=cfg["action_levels"],
action_scale_low=cfg["action_scale_low"],
action_scale_high=cfg["action_scale_high"],
price_low=cfg["price_low"],
price_high=cfg["price_high"],
max_episode_steps=cfg["max_episode_steps"],
max_session_steps=cfg["max_session_steps"],
margin_floor=cfg["margin_floor"],
margin_floor_patience=cfg["margin_floor_patience"],
prefer_behavior_data=cfg["prefer_behavior_data"],
)
env = PHANTOMJAXEnv(env_params)
network = ActorCritic(env.action_space_n(), activation=cfg["activation"])
def linear_schedule(count: jax.Array) -> jax.Array:
updates_done = count // (cfg["num_minibatches"] * cfg["update_epochs"])
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
return cfg["learning_rate"] * frac
def train(rng: jax.Array):
rng, init_key = jax.random.split(rng)
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
params = network.init(init_key, init_obs)
if cfg["anneal_lr"]:
tx = optax.chain(
optax.clip_by_global_norm(cfg["max_grad_norm"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(cfg["max_grad_norm"]),
optax.adam(cfg["learning_rate"], eps=1e-5),
)
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
rng, reset_key = jax.random.split(rng)
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
obs, env_state = jax.vmap(env.reset)(reset_keys)
def _update_step(runner_state, _):
def _env_step(runner_state, _):
train_state, env_state, last_obs, rng = runner_state
rng, action_key = jax.random.split(rng)
policy, value = network.apply(train_state.params, last_obs)
action = policy.sample(seed=action_key)
log_prob = policy.log_prob(action)
rng, step_key = jax.random.split(rng)
step_keys = jax.random.split(step_key, cfg["num_envs"])
nxt_obs, nxt_state, reward, done, info = jax.vmap(
env.step,
in_axes=(0, 0, 0),
)(step_keys, env_state, action)
rng, reset_key = jax.random.split(rng)
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
env_next = jax.tree_util.tree_map(
lambda keep, reset: _select_env_state(done, keep, reset),
nxt_state,
rst_state,
)
transition = Transition(
done=done,
action=action,
value=value,
reward=reward,
log_prob=log_prob,
obs=last_obs,
info=info,
)
return (train_state, env_next, obs_next, rng), transition
runner_state, traj_batch = jax.lax.scan(
_env_step,
runner_state,
None,
length=cfg["num_steps"],
)
train_state, env_state, last_obs, rng = runner_state
_, last_value = network.apply(train_state.params, last_obs)
def _compute_gae(traj_batch, last_value):
def _gae_step(carry, transition):
gae, next_value = carry
delta = (
transition.reward
+ cfg["gamma"] * next_value * (1.0 - transition.done)
- transition.value
)
gae = (
delta
+ cfg["gamma"]
* cfg["gae_lambda"]
* (1.0 - transition.done)
* gae
)
return (gae, transition.value), gae
_, advantages = jax.lax.scan(
_gae_step,
(jnp.zeros_like(last_value), last_value),
traj_batch,
reverse=True,
unroll=16,
)
targets = advantages + traj_batch.value
return advantages, targets
advantages, targets = _compute_gae(traj_batch, last_value)
def _update_epoch(update_state, _):
def _update_minibatch(train_state, batch_info):
traj_b, adv_b, tgt_b = batch_info
def _loss_fn(params, traj_b, adv_b, tgt_b):
policy, value = network.apply(params, traj_b.obs)
log_prob = policy.log_prob(traj_b.action)
value_clipped = traj_b.value + (value - traj_b.value).clip(
-cfg["clip_range"], cfg["clip_range"]
)
value_loss = (
0.5
* jnp.maximum(
jnp.square(value - tgt_b),
jnp.square(value_clipped - tgt_b),
).mean()
)
adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
ratio = jnp.exp(log_prob - traj_b.log_prob)
loss_actor = -jnp.minimum(
ratio * adv_norm,
jnp.clip(
ratio,
1.0 - cfg["clip_range"],
1.0 + cfg["clip_range"],
)
* adv_norm,
).mean()
entropy = policy.entropy().mean()
total_loss = (
loss_actor
+ cfg["vf_coef"] * value_loss
- cfg["ent_coef"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
train_state = train_state.apply_gradients(grads=grads)
return train_state, jnp.asarray(0.0, dtype=jnp.float32)
train_state, traj_batch, advantages, targets, rng = update_state
rng, perm_key = jax.random.split(rng)
batch_size = cfg["num_envs"] * cfg["num_steps"]
permutation = jax.random.permutation(perm_key, batch_size)
batch = (traj_batch, advantages, targets)
batch = jax.tree_util.tree_map(
lambda x: x.reshape((batch_size,) + x.shape[2:]),
batch,
)
shuffled = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=0),
batch,
)
minibatches = jax.tree_util.tree_map(
lambda x: x.reshape(
(cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
),
shuffled,
)
train_state, _ = jax.lax.scan(
_update_minibatch, train_state, minibatches
)
return (train_state, traj_batch, advantages, targets, rng), None
update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, _ = jax.lax.scan(
_update_epoch,
update_state,
None,
length=cfg["update_epochs"],
)
train_state = update_state[0]
rng = update_state[-1]
metric = {
"reward": jnp.mean(traj_batch.reward),
"revenue": jnp.mean(traj_batch.info["revenue"]),
"agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
}
runner_state = (train_state, env_state, last_obs, rng)
return runner_state, metric
runner_state = (train_state, env_state, obs, rng)
runner_state, metric = jax.lax.scan(
_update_step,
runner_state,
None,
length=cfg["num_updates"],
)
return {
"runner_state": runner_state,
"metrics": metric,
}
return train, network, env, cfg
def evaluate_policy(
*,
network: ActorCritic,
params: Any,
env: PHANTOMJAXEnv,
episodes: int,
seed: int,
) -> dict[str, float]:
rewards: list[float] = []
revenues: list[float] = []
key = jax.random.PRNGKey(seed)
for _ in range(int(episodes)):
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)
ep_reward = 0.0
ep_revenue = 0.0
done = False
steps = 0
while not done and steps < int(env.params.max_episode_steps):
policy, _ = network.apply(params, obs)
action = jnp.argmax(policy.logits)
key, step_key = jax.random.split(key)
obs, state, reward, done_flag, info = env.step(step_key, state, action)
ep_reward += float(np.asarray(reward))
ep_revenue += float(np.asarray(info["revenue"]))
done = bool(np.asarray(done_flag))
steps += 1
rewards.append(ep_reward)
revenues.append(ep_revenue)
return {
"eval/reward": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)),
}
def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
if not HAS_JAX_STACK:
raise ImportError(
"JAX PPO path requires jax, flax, optax, and distrax. "
"Install engine/jax/requirements.txt on this machine first."
)
run_cfg = _jax_cfg(cfg)
if run_cfg["algo"] != "ppo":
raise ValueError(
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
)
train_fn, network, env, run_cfg = make_train(run_cfg)
train_jit = jax.jit(train_fn)
rng = jax.random.PRNGKey(run_cfg["seed"])
out = train_jit(rng)
train_state = out["runner_state"][0]
metric = out["metrics"]
metrics = {
"train/reward": float(np.mean(np.asarray(metric["reward"]))),
"train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
"train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
"train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
"train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
"train/global_step": int(
run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
),
}
eval_metrics = evaluate_policy(
network=network,
params=train_state.params,
env=env,
episodes=run_cfg["eval_episodes"],
seed=run_cfg["seed"] + 7,
)
metrics.update(eval_metrics)
model_dir = Path(run_cfg["model_dir"])
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / "phantom_ppo_jax.msgpack"
model_path.write_bytes(serialization.to_bytes(train_state.params))
metrics["model/path"] = str(model_path)
return {"params": train_state.params}, metrics