# PHANTOM/sim/rl/jax_core/simulation.py

"""Vectorized Markov chain session sampling with JAX."""
from typing import NamedTuple, Tuple
import numpy as np
from functools import partial
# JAX is optional: JAX_AVAILABLE records whether the jitted backend can be
# used; sample_sessions checks it and otherwise runs its NumPy fallback.
try:
    import jax
    import jax.numpy as jnp
    from jax import lax
except ImportError:
    JAX_AVAILABLE = False
else:
    JAX_AVAILABLE = True
from .transitions import TransitionData, N_STATES, TERM_IDX, PURCHASE_IDX, CART_IDX
class SessionBatch(NamedTuple):
    """Batch of sampled sessions stored as padded, row-per-session arrays.

    Row ``i`` of every field describes session ``i``.
    """

    # (n_sess, max_len) state indices; -1 marks padding.
    states: np.ndarray
    # (n_sess, max_len) dwell times.
    dwells: np.ndarray
    # (n_sess,) product index assigned to each session.
    products: np.ndarray
    # (n_sess,) actor type: 0 = human, 1 = agent.
    actors: np.ndarray
    # (n_sess,) actual session length.
    lengths: np.ndarray
class SimResult(NamedTuple):
    """Aggregate outcome of one simulated batch of sessions.

    The field descriptions follow how ``compute_metrics`` fills them in.
    """

    # Per-product number of purchasing human / agent sessions (float32 arrays).
    demand_human: np.ndarray
    demand_agent: np.ndarray
    # Sum of the charged price over all purchasing sessions.
    revenue: float
    # The same sum valued at base_price instead of the charged price.
    revenue_oracle: float
    # Sum of (base_price - charged price) over agent purchases.
    agent_loss: float
    # max(0, mean human margin - 0.5 * mean human premium); 0.0 with no human purchases.
    coi: float
    # Count of item-view states divided by the number of purchasing sessions.
    look_to_book: float
    # Mean charged price over purchasing sessions; 0.0 when there are none.
    mean_sale_price: float
    # Number of purchasing sessions per actor type.
    n_human_purchases: int
    n_agent_purchases: int
    # The batch these metrics were computed from.
    sessions: SessionBatch
if JAX_AVAILABLE:

    @partial(jax.jit, static_argnums=(5, 6, 7))
    def _sample_sessions_jax(key, T_human, T_agent, dwell_human, dwell_agent, n_human, n_agent, max_steps):
        """Sample ``n_human + n_agent`` Markov-chain sessions in parallel.

        Every session starts in state 0. At each step it records its current
        state together with a dwell time ``max(0.3, Gamma(shape) * scale)`` drawn
        from that same state's parameters, then samples the next state. Once a
        session transitions into ``TERM_IDX`` it stops recording (the terminal
        state itself is not stored); otherwise it runs for ``max_steps`` steps.

        Args:
            key: JAX PRNG key.
            T_human, T_agent: Transition probability matrices indexed
                ``[from_state, to_state]``.
            dwell_human, dwell_agent: Per-state dwell parameters indexed
                ``[state, 0]`` (gamma shape) and ``[state, 1]`` (scale).
            n_human, n_agent, max_steps: Static sizes; changing them triggers
                recompilation.

        Returns:
            ``(states, dwells, actors, lengths)``. ``states`` is
            ``(n, max_steps)`` int32 with -1 padding. ``dwells`` is the matching
            array with 0.0 padding. ``actors`` is 0 for humans and 1 for agents.
            ``lengths`` is the number of recorded (non-padding) steps.
        """
        n = n_human + n_agent
        rows = jnp.arange(n)
        actors = jnp.concatenate([jnp.zeros(n_human, dtype=jnp.int32), jnp.ones(n_agent, dtype=jnp.int32)])
        is_human = (actors == 0)[:, None, None]
        # Per-session copies of the actor-specific parameters. The log is taken
        # once here instead of on every scan step.
        log_T = jnp.log(jnp.where(is_human, T_human[None], T_agent[None]) + 1e-10)
        dwell_p = jnp.where(is_human, dwell_human[None], dwell_agent[None])

        def step(carry, _):
            s, active, k = carry
            k, k_next, k_dwell = jax.random.split(k, 3)
            # Finished sessions carry s == -1; clamp it so the gathers below do
            # not silently wrap around to the last state.
            s_safe = jnp.where(active, s, 0)
            # Cast keeps the carry int32 even when x64 mode is enabled.
            nxt = jax.random.categorical(k_next, log_T[rows, s_safe]).astype(jnp.int32)
            shape = dwell_p[rows, s_safe, 0]
            scale = dwell_p[rows, s_safe, 1]
            dwell = jnp.maximum(0.3, jax.random.gamma(k_dwell, shape) * scale)
            # Record the state being left together with its own dwell time.
            rec_state = jnp.where(active, s, -1)
            rec_dwell = jnp.where(active, dwell, 0.0)
            still = active & (nxt != TERM_IDX)
            return (jnp.where(still, nxt, -1), still, k), (rec_state, rec_dwell)

        init = (jnp.zeros(n, dtype=jnp.int32), jnp.ones(n, dtype=jnp.bool_), key)
        _, (states, dwells) = lax.scan(step, init, None, length=max_steps)
        states, dwells = states.T, dwells.T  # (n, max_steps)
        # Recorded steps form a prefix (inactive never becomes active again),
        # so their count is the session length.
        lengths = jnp.sum(states >= 0, axis=1).astype(jnp.int32)
        return states, dwells, actors, lengths
def _fallback_rng(key) -> np.random.Generator:
    """Build the NumPy generator used by the no-JAX fallback from ``key``.

    Integer keys (Python or NumPy) seed the generator directly. Indexable keys
    contribute all of their entries as entropy. Any other value uses seed 42.
    """
    if isinstance(key, (int, np.integer)):
        return np.random.default_rng(int(key))
    if hasattr(key, '__getitem__'):
        # Use the whole key rather than key[0]: for JAX-style [0, seed] keys the
        # first word is 0, which would give every seed the same stream.
        return np.random.default_rng([int(v) for v in np.asarray(key).ravel()])
    return np.random.default_rng(42)


def _sample_sessions_numpy(rng: np.random.Generator, trans: TransitionData, n_human: int, n_agent: int, n_products: int, max_steps: int) -> SessionBatch:
    """Sequential NumPy session sampler used when JAX is unavailable.

    Each session starts in state 0. At every step it records the current state
    and a dwell time ``max(0.3, Gamma(shape, scale))`` drawn from that state's
    parameters, then draws the next state from the actor's transition row.
    Recording stops on reaching ``TERM_IDX`` (not recorded) or after
    ``max_steps`` steps. Unused slots hold state -1 and dwell 0.0.
    """
    n = n_human + n_agent
    actors = np.concatenate([np.zeros(n_human, dtype=np.int32), np.ones(n_agent, dtype=np.int32)])
    products = rng.integers(0, n_products, size=n)
    states = np.full((n, max_steps), -1, dtype=np.int32)
    dwells = np.zeros((n, max_steps), dtype=np.float32)
    lengths = np.zeros(n, dtype=np.int32)
    for i, actor in enumerate(actors):
        if actor == 0:
            T, dp = trans.human_T, trans.human_dwell
        else:
            T, dp = trans.agent_T, trans.agent_dwell
        s, t = 0, 0
        while t < max_steps and s != TERM_IDX:
            states[i, t] = s
            dwells[i, t] = max(0.3, rng.gamma(dp[s, 0], dp[s, 1]))
            s = rng.choice(N_STATES, p=T[s])
            t += 1
        lengths[i] = t
    return SessionBatch(states, dwells, products, actors, lengths)


def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_products: int, max_steps: int = 40) -> SessionBatch:
    """Sample ``n_human + n_agent`` sessions and assign each a random product.

    Sessions ``[0, n_human)`` are human (actor 0) and the rest are agents
    (actor 1). Products are drawn uniformly from ``[0, n_products)``.
    The jitted JAX kernel is used when JAX imports successfully; otherwise a
    sequential NumPy fallback runs.

    Args:
        key: A JAX PRNG key when JAX is installed. For the NumPy fallback, an
            int seed, an indexable of non-negative ints (all entries are used
            as entropy), or any other value for the fixed seed 42.
        trans: Supplies ``human_T``/``agent_T`` transition matrices and
            ``human_dwell``/``agent_dwell`` per-state dwell parameters.
        n_human: Number of human sessions.
        n_agent: Number of agent sessions.
        n_products: Number of products to draw from.
        max_steps: Length of the padded per-session arrays.

    Returns:
        A SessionBatch whose fields are NumPy arrays.
    """
    if JAX_AVAILABLE:
        k_sess, k_prod = jax.random.split(key)
        states, dwells, actors, lengths = _sample_sessions_jax(
            k_sess, trans.human_T, trans.agent_T, trans.human_dwell, trans.agent_dwell,
            n_human, n_agent, max_steps,
        )
        products = jax.random.randint(k_prod, (n_human + n_agent,), 0, n_products)
        return SessionBatch(np.asarray(states), np.asarray(dwells), np.asarray(products), np.asarray(actors), np.asarray(lengths))
    return _sample_sessions_numpy(_fallback_rng(key), trans, n_human, n_agent, n_products, max_steps)
def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray, base_price: np.ndarray, *, view_idx: int = 1) -> SimResult:
    """Aggregate demand, revenue and pricing metrics from a batch of sessions.

    A session counts as one purchase if ``PURCHASE_IDX`` appears anywhere in
    its state sequence. Several purchase states in one session still count once.

    Args:
        batch: Sampled sessions.
        prices: Per-product price charged, indexed by product index.
        unit_cost: Per-product unit cost, used for the human margin.
        base_price: Per-product reference price, used for oracle revenue,
            agent loss and the human premium.
        view_idx: State index counted as an item view for look-to-book.
            Defaults to 1 (view_item_page).

    Returns:
        SimResult with per-product demand split by actor, revenue vs. oracle
        revenue, agent loss, COI, look-to-book, mean sale price, purchase
        counts, and the input batch.
    """
    purchased = np.any(batch.states == PURCHASE_IDX, axis=1)
    human_mask, agent_mask = batch.actors == 0, batch.actors == 1
    human_purch, agent_purch = purchased & human_mask, purchased & agent_mask
    n_products = len(prices)
    demand_h = np.bincount(batch.products[human_purch], minlength=n_products).astype(np.float32)
    demand_a = np.bincount(batch.products[agent_purch], minlength=n_products).astype(np.float32)

    # Revenue actually charged vs. revenue had every sale happened at base_price.
    purch_products = batch.products[purchased]
    revenue = float(np.sum(prices[purch_products]))
    revenue_oracle = float(np.sum(base_price[purch_products]))

    # Agent loss: base_price - price paid, summed over agent purchases
    # (negative when agents paid more than base_price).
    agent_products = batch.products[agent_purch]
    agent_loss = float(np.sum(base_price[agent_products] - prices[agent_products]))

    # COI: mean human margin minus half the mean premium, floored at 0.
    human_products = batch.products[human_purch]
    if human_products.size > 0:
        margin = float(np.mean(prices[human_products] - unit_cost[human_products]))
        premium = float(np.mean(base_price[human_products] - prices[human_products]))
        coi = max(0.0, margin - premium * 0.5)
    else:
        coi = 0.0

    # Look-to-book: item views per purchasing session. The 1e-6 guard avoids a
    # ZeroDivisionError; with no purchases the value becomes views * 1e6.
    views = float(np.sum(batch.states == view_idx))
    n_purch = int(purchased.sum())
    look_to_book = views / (n_purch + 1e-6)
    mean_sale = float(np.mean(prices[purch_products])) if n_purch > 0 else 0.0
    return SimResult(demand_h, demand_a, revenue, revenue_oracle, agent_loss, coi, look_to_book, mean_sale,
                     int(human_purch.sum()), int(agent_purch.sum()), batch)