"""Vectorized Markov chain session sampling with JAX.

Sessions are random walks over a small state graph: each actor type has a
transition matrix (row ``s`` gives next-state probabilities) and per-state
gamma dwell-time parameters ``(shape, scale)``, both taken from
``TransitionData``. When JAX is installed every session is sampled in parallel
with ``lax.scan``; otherwise a pure-NumPy loop is used. Both backends return
the same ``SessionBatch`` layout.
"""
from functools import partial
from typing import NamedTuple, Tuple  # Tuple unused here; kept in case other modules import it

import numpy as np

try:
    import jax
    import jax.numpy as jnp
    from jax import lax
    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False

# CART_IDX is unused in this module; kept in case other modules import it from here.
from .transitions import TransitionData, N_STATES, TERM_IDX, PURCHASE_IDX, CART_IDX

# State index of view_item_page (per the original inline comment "index 1");
# NOTE(review): confirm against .transitions.
VIEW_ITEM_IDX = 1


class SessionBatch(NamedTuple):
    """A batch of sampled sessions padded to a common length ``max_len``."""

    states: np.ndarray  # (n_sess, max_len) visited state indices (start state included, TERM excluded), -1=padding
    dwells: np.ndarray  # (n_sess, max_len) dwell time spent in states[:, t]; 0 for padding
    products: np.ndarray  # (n_sess,) product index per session
    actors: np.ndarray  # (n_sess,) 0=human, 1=agent
    lengths: np.ndarray  # (n_sess,) number of non-padding steps per session


class SimResult(NamedTuple):
    """Aggregate metrics for one sampled batch, as produced by ``compute_metrics``."""

    demand_human: np.ndarray  # per-product count of purchasing human sessions (float32)
    demand_agent: np.ndarray  # per-product count of purchasing agent sessions (float32)
    revenue: float  # sum of prices over all purchasing sessions
    revenue_oracle: float  # the same purchases valued at base_price
    agent_loss: float  # sum of (base_price - price) over agent purchases
    coi: float  # max(0, mean human margin - 0.5 * mean human premium); 0 if no human purchases
    look_to_book: float  # view-item steps divided by (purchases + 1e-6)
    mean_sale_price: float  # mean price over purchasing sessions; 0 if none
    n_human_purchases: int
    n_agent_purchases: int
    sessions: SessionBatch  # the batch the metrics were computed from


if JAX_AVAILABLE:

    @partial(jax.jit, static_argnums=(5, 6, 7))
    def _sample_sessions_jax(key, T_human, T_agent, dwell_human, dwell_agent,
                             n_human, n_agent, max_steps):
        """Sample ``n_human + n_agent`` Markov-chain sessions in parallel.

        Every session starts in state 0 and walks until it transitions into
        ``TERM_IDX`` or ``max_steps`` steps have been recorded. Human sessions
        occupy the first ``n_human`` rows, agents the remaining ``n_agent``.

        Returns ``(states, dwells, actors, lengths)`` with the same layout as
        the NumPy fallback: ``states[i, t]`` is the state occupied at step
        ``t``, ``dwells[i, t]`` the dwell time spent in it, and padding is
        ``-1`` / ``0.0``. ``n_human``, ``n_agent`` and ``max_steps`` are
        static (jit recompiles when they change).
        """
        n = n_human + n_agent
        # Only the third subkey drives the walk; splitting into four keeps the
        # random stream identical to earlier versions of this function.
        scan_key = jax.random.split(key, 4)[2]
        actors = jnp.concatenate([jnp.zeros(n_human, dtype=jnp.int32),
                                  jnp.ones(n_agent, dtype=jnp.int32)])
        is_human = actors[:, None, None] == 0
        T = jnp.where(is_human, T_human[None], T_agent[None])  # (n, S, S)
        dwell_p = jnp.where(is_human, dwell_human[None], dwell_agent[None])  # (n, S, 2)
        rows = jnp.arange(n)

        def step(carry, _):
            s, active, k = carry
            k, k_next, k_dwell = jax.random.split(k, 3)
            # Finished sessions carry s == TERM_IDX or -1 (which would wrap to
            # the last row); gather from a valid row instead so no garbage or
            # NaN dwell is drawn for them. Their outputs are masked below.
            idx = jnp.where(active, s, 0)
            probs = T[rows, idx]  # (n, S)
            nxt = jax.random.categorical(k_next, jnp.log(probs + 1e-10)).astype(jnp.int32)
            nxt = jnp.where(active, nxt, -1)
            shape = dwell_p[rows, idx, 0]
            scale = dwell_p[rows, idx, 1]
            dwell = jnp.maximum(0.3, jax.random.gamma(k_dwell, shape) * scale)
            # Record the state being left together with its own dwell time so
            # states[:, t] and dwells[:, t] describe the same step; the
            # terminal state itself is never recorded.
            rec_state = jnp.where(active, s, -1)
            rec_dwell = jnp.where(active, dwell, 0.0)
            still = active & (nxt != TERM_IDX) & (nxt >= 0)
            return (nxt, still, k), (rec_state, rec_dwell)

        init = (jnp.zeros(n, dtype=jnp.int32), jnp.ones(n, dtype=jnp.bool_), scan_key)
        _, (states, dwells) = lax.scan(step, init, None, length=max_steps)
        states, dwells = states.T, dwells.T  # (n, max_steps)
        # Recorded steps form a contiguous prefix, so their count is the length.
        lengths = jnp.sum(states >= 0, axis=1).astype(jnp.int32)
        return states, dwells, actors, lengths


def _numpy_rng(key) -> np.random.Generator:
    """Build the NumPy generator used by the fallback sampler from ``key``.

    * An integer is used directly as the seed (previously it was ignored and
      the seed was always 42).
    * An array-like key contributes all of its words as seed entropy.
      Seeding from ``key[0]`` alone collided for two-word keys such as
      ``jax.random.PRNGKey(seed)``, whose first word is 0 for small seeds.
    * Anything else (e.g. ``None``) falls back to the fixed seed 42.
    """
    if isinstance(key, (int, np.integer)):
        # SeedSequence needs non-negative entropy; fold negatives into 64 bits.
        return np.random.default_rng(int(key) % 2**64)
    if hasattr(key, '__getitem__'):
        words = [int(w) % 2**64 for w in np.asarray(key).ravel()]
        if words:
            return np.random.default_rng(words)
    return np.random.default_rng(42)


def _sample_sessions_numpy(rng: np.random.Generator, trans: TransitionData,
                           n_human: int, n_agent: int, n_products: int,
                           max_steps: int) -> SessionBatch:
    """Pure-NumPy session sampler (one Python loop per session) used without JAX."""
    n = n_human + n_agent
    actors = np.concatenate([np.zeros(n_human, dtype=np.int32),
                             np.ones(n_agent, dtype=np.int32)])
    products = rng.integers(0, n_products, size=n)
    states = np.full((n, max_steps), -1, dtype=np.int32)
    dwells = np.zeros((n, max_steps), dtype=np.float32)
    lengths = np.zeros(n, dtype=np.int32)
    params = {
        0: (np.asarray(trans.human_T), np.asarray(trans.human_dwell)),
        1: (np.asarray(trans.agent_T), np.asarray(trans.agent_dwell)),
    }
    for i, actor in enumerate(actors):
        T, dp = params[int(actor)]
        s, t = 0, 0
        while t < max_steps and s != TERM_IDX:
            states[i, t] = s
            dwells[i, t] = max(0.3, rng.gamma(dp[s, 0], dp[s, 1]))
            s = rng.choice(N_STATES, p=T[s])
            t += 1
        lengths[i] = t
    return SessionBatch(states, dwells, products, actors, lengths)


def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int,
                    n_products: int, max_steps: int = 40) -> SessionBatch:
    """Sample ``n_human`` human and ``n_agent`` agent sessions.

    Args:
        key: a JAX PRNG key when JAX is available (an integer seed is also
            accepted and converted with ``jax.random.PRNGKey``); without JAX,
            an integer seed or array-like key (see ``_numpy_rng``).
        trans: transition matrices and dwell parameters for both actor types.
        n_human: number of human sessions (rows ``[0, n_human)``).
        n_agent: number of agent sessions (the rows after the humans).
        n_products: each session gets a uniform random product index in
            ``[0, n_products)``.
        max_steps: maximum number of recorded steps per session.

    Returns:
        A ``SessionBatch`` of NumPy arrays.
    """
    if not JAX_AVAILABLE:
        return _sample_sessions_numpy(_numpy_rng(key), trans, n_human, n_agent,
                                      n_products, max_steps)
    if isinstance(key, (int, np.integer)):
        key = jax.random.PRNGKey(int(key))
    k_walk, k_prod = jax.random.split(key)
    states, dwells, actors, lengths = _sample_sessions_jax(
        k_walk, trans.human_T, trans.agent_T, trans.human_dwell, trans.agent_dwell,
        n_human, n_agent, max_steps)
    products = jax.random.randint(k_prod, (n_human + n_agent,), 0, n_products)
    return SessionBatch(np.asarray(states), np.asarray(dwells), np.asarray(products),
                        np.asarray(actors), np.asarray(lengths))


def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray,
                    base_price: np.ndarray, view_idx: int = VIEW_ITEM_IDX) -> SimResult:
    """Aggregate demand, revenue and behaviour metrics for a sampled batch.

    A session counts as one purchase of its product if it visits
    ``PURCHASE_IDX`` at least once.

    Args:
        batch: sessions from ``sample_sessions``.
        prices: price charged per product, indexed by product.
        unit_cost: unit cost per product, indexed by product.
        base_price: reference price per product, indexed by product.
        view_idx: state index counted as an item view for look-to-book.

    Returns:
        A ``SimResult``. ``look_to_book`` divides by ``purchases + 1e-6``, so it
        becomes very large (not infinite) when nothing was purchased.
    """
    prices = np.asarray(prices)
    unit_cost = np.asarray(unit_cost)
    base_price = np.asarray(base_price)

    purchased = np.any(batch.states == PURCHASE_IDX, axis=1)
    human_purch = purchased & (batch.actors == 0)
    agent_purch = purchased & (batch.actors == 1)
    demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32)
    demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32)

    # Revenue at the charged prices vs. the same purchases at base prices.
    purch_products = batch.products[purchased]
    revenue = float(np.sum(prices[purch_products]))
    revenue_oracle = float(np.sum(base_price[purch_products]))

    # Agent loss: base_price - price_paid, summed over agent purchases.
    agent_products = batch.products[agent_purch]
    agent_loss = float(np.sum(base_price[agent_products] - prices[agent_products]))

    # COI: mean human margin minus half the mean human premium
    # (base_price - price), floored at 0.
    human_products = batch.products[human_purch]
    if len(human_products) > 0:
        margin = float(np.mean(prices[human_products] - unit_cost[human_products]))
        premium = float(np.mean(base_price[human_products] - prices[human_products]))
        coi = max(0.0, margin - premium * 0.5)
    else:
        coi = 0.0

    # Look-to-book: item views per purchase (padding is -1, so never counted).
    views = float(np.sum(batch.states == view_idx))
    n_purch = int(purchased.sum())
    look_to_book = views / (n_purch + 1e-6)
    mean_sale = float(np.mean(prices[purch_products])) if n_purch > 0 else 0.0

    return SimResult(demand_h, demand_a, revenue, revenue_oracle, agent_loss, coi,
                     look_to_book, mean_sale, int(human_purch.sum()),
                     int(agent_purch.sum()), batch)