From a217d53556fbad0dab100eb5c8d601b63c971803 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Thu, 22 Jan 2026 13:10:01 +0100 Subject: [PATCH] feat: translating features to jax --- sim/rl/jax_core/features.py | 69 +++++++++++++++++++++++++ sim/rl/jax_core/separability.py | 43 +++++++++++++++ sim/rl/jax_core/simulation.py | 92 +++++++++++++++++++++++++++++++++ sim/rl/jax_core/transitions.py | 47 +++++++++++++++++ 4 files changed, 251 insertions(+) create mode 100644 sim/rl/jax_core/features.py create mode 100644 sim/rl/jax_core/separability.py create mode 100644 sim/rl/jax_core/simulation.py create mode 100644 sim/rl/jax_core/transitions.py diff --git a/sim/rl/jax_core/features.py b/sim/rl/jax_core/features.py new file mode 100644 index 0000000..d5af957 --- /dev/null +++ b/sim/rl/jax_core/features.py @@ -0,0 +1,69 @@ +"""Vectorized session feature extraction.""" +import numpy as np +from .transitions import N_STATES, PURCHASE_IDX, CART_IDX +from .simulation import SessionBatch + +try: + import jax.numpy as jnp + from jax import jit + JAX_AVAILABLE = True +except ImportError: + jnp, JAX_AVAILABLE = np, False + def jit(f): return f + +@jit +def extract_features(states, dwells, lengths): + """Extract per-session features. Returns (n_sess, 9) array.""" + n, max_len = states.shape + mask = jnp.arange(max_len)[None,:] < lengths[:,None] + duration = jnp.sum(dwells * mask, axis=1) + total = lengths.astype(jnp.float32) + count = lambda idx: jnp.sum((states == idx) & mask, axis=1).astype(jnp.float32) + views, learn, carts, purchases = count(1), count(2), count(3), count(4) + velocity = total / (duration + 1e-6) + conversion = purchases / (views + 1e-6) + avg_dwell = duration / (total + 1e-6) + return jnp.stack([duration, avg_dwell, total, velocity, views, carts, purchases, learn, conversion], axis=1) + +def session_features(batch: SessionBatch) -> np.ndarray: + if JAX_AVAILABLE: + return np.asarray(extract_features(jnp.array(batch.states), jnp.array(batch.dwells), jnp.array(batch.lengths))) + # numpy fallback + n, max_len = batch.states.shape + mask = np.arange(max_len)[None,:] < batch.lengths[:,None] + duration = np.sum(batch.dwells * mask, axis=1) + total = batch.lengths.astype(np.float32) + count = lambda idx: np.sum((batch.states == idx) & mask, axis=1).astype(np.float32) + views, learn, carts, purchases = count(1), count(2), count(3), count(4) + return np.stack([duration, duration/(total+1e-6), total, total/(duration+1e-6), views, carts, purchases, learn, purchases/(views+1e-6)], axis=1) + +@jit +def session_transitions(states, lengths, n_states=N_STATES): + """Compute empirical transition counts per session. Returns (n_sess, n_states, n_states).""" + n, max_len = states.shape + mask = jnp.arange(max_len - 1)[None,:] < (lengths[:,None] - 1) + src, dst = states[:, :-1], states[:, 1:] + # handle -1 padding by clamping to valid range + src_c, dst_c = jnp.clip(src, 0, n_states-1), jnp.clip(dst, 0, n_states-1) + valid = mask & (src >= 0) & (dst >= 0) + def per_session(i): + s, d, v = src_c[i], dst_c[i], valid[i] + trans = (jnp.eye(n_states)[s,:,None] * jnp.eye(n_states)[d,None,:]).sum(0) * v[:,None,None] + return trans.sum(0) + # vmap not ideal here, use manual loop for clarity + trans = jnp.stack([per_session(i) for i in range(n)]) + row_sums = trans.sum(axis=-1, keepdims=True) + return trans / (row_sums + 1e-10) + +def compute_session_transitions(batch: SessionBatch) -> np.ndarray: + if JAX_AVAILABLE: + return np.asarray(session_transitions(jnp.array(batch.states), jnp.array(batch.lengths))) + # numpy fallback + n, max_len = batch.states.shape + trans = np.zeros((n, N_STATES, N_STATES), dtype=np.float32) + for i in range(n): + for t in range(batch.lengths[i] - 1): + s, d = batch.states[i, t], batch.states[i, t+1] + if s >= 0 and d >= 0: trans[i, s, d] += 1 + row_sums = trans.sum(axis=-1, keepdims=True) + return trans / (row_sums + 1e-10) diff --git a/sim/rl/jax_core/separability.py b/sim/rl/jax_core/separability.py new file mode 100644 index 0000000..c0c0293 --- /dev/null +++ b/sim/rl/jax_core/separability.py @@ -0,0 +1,43 @@ +"""Vectorized KL divergence for separability scoring.""" +import numpy as np +from typing import Tuple + +try: + import jax.numpy as jnp + from jax import jit + JAX_AVAILABLE = True +except ImportError: + jnp, JAX_AVAILABLE = np, False + def jit(f): return f + +@jit +def batch_kl(P, Q_human, Q_agent, eps=1e-10): + """Compute KL(P||Q) for batched P. P:(n,s,s), Q:(s,s). Returns (delta_h, delta_a) each (n,).""" + p = P + eps + p = p / p.sum(axis=-1, keepdims=True) + qh, qa = Q_human[None] + eps, Q_agent[None] + eps + delta_h = jnp.sum(p * jnp.log(p / qh), axis=(1, 2)) + delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2)) + return delta_h, delta_a + +def compute_divergences(session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Compute KL divergence of each session from human/agent prototypes.""" + if JAX_AVAILABLE: + dh, da = batch_kl(jnp.array(session_trans), jnp.array(ref_human), jnp.array(ref_agent)) + return np.asarray(dh), np.asarray(da) + # numpy fallback + eps = 1e-10 + p = session_trans + eps + p = p / p.sum(axis=-1, keepdims=True) + qh, qa = ref_human[None] + eps, ref_agent[None] + eps + delta_h = np.sum(p * np.log(p / qh), axis=(1, 2)) + delta_a = np.sum(p * np.log(p / qa), axis=(1, 2)) + return delta_h, delta_a + +def estimate_alpha_batch(prob_agent: np.ndarray, delta_h: np.ndarray, delta_a: np.ndarray, temp: float = 1.0) -> np.ndarray: + """Vectorized alpha estimation from classifier probs and divergences.""" + mass = delta_h + delta_a + ratio = np.where(mass > 1e-8, delta_a / mass, 0.5) + blended = 0.5 * prob_agent + 0.5 * ratio + if temp <= 0: return np.clip(blended, 0.0, 1.0) + return np.clip(1.0 / (1.0 + np.exp(-temp * (blended - 0.5))), 0.0, 1.0) diff --git a/sim/rl/jax_core/simulation.py b/sim/rl/jax_core/simulation.py new file mode 100644 index 0000000..ee8ca6f --- /dev/null +++ b/sim/rl/jax_core/simulation.py @@ -0,0 +1,92 @@ +"""Vectorized Markov chain session sampling with JAX.""" +from typing import NamedTuple, Tuple +import numpy as np +from functools import partial + +try: + import jax, jax.numpy as jnp + from jax import lax + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + +from .transitions import TransitionData, N_STATES, TERM_IDX, PURCHASE_IDX, CART_IDX + +class SessionBatch(NamedTuple): + states: np.ndarray # (n_sess, max_len) state indices, -1=padding + dwells: np.ndarray # (n_sess, max_len) dwell times + products: np.ndarray # (n_sess,) product index per session + actors: np.ndarray # (n_sess,) 0=human, 1=agent + lengths: np.ndarray # (n_sess,) actual session length + +class SimResult(NamedTuple): + demand_human: np.ndarray + demand_agent: np.ndarray + revenue: float + n_human_purchases: int + n_agent_purchases: int + sessions: SessionBatch + +if JAX_AVAILABLE: + @partial(jax.jit, static_argnums=(5,6,7)) + def _sample_sessions_jax(key, T_human, T_agent, dwell_human, dwell_agent, n_human, n_agent, max_steps): + n = n_human + n_agent + k1, k2, k3, k4 = jax.random.split(key, 4) + actors = jnp.concatenate([jnp.zeros(n_human, dtype=jnp.int32), jnp.ones(n_agent, dtype=jnp.int32)]) + T = jnp.where(actors[:,None,None]==0, T_human[None], T_agent[None]) # (n,6,6) + dwell_p = jnp.where(actors[:,None,None]==0, dwell_human[None], dwell_agent[None]) # (n,6,2) + + def step(carry, _): + s, active, k = carry + k, k1, k2 = jax.random.split(k, 3) + probs = T[jnp.arange(n), s] # (n,6) + nxt = jax.random.categorical(k1, jnp.log(probs + 1e-10)) + nxt = jnp.where(active, nxt, -1) + shape = dwell_p[jnp.arange(n), s, 0] + scale = dwell_p[jnp.arange(n), s, 1] + dwell = jnp.maximum(0.3, jax.random.gamma(k2, shape) * scale) + still = active & (nxt != TERM_IDX) & (nxt >= 0) + return (nxt, still, k), (nxt, dwell) + + init = (jnp.zeros(n, dtype=jnp.int32), jnp.ones(n, dtype=jnp.bool_), k3) + _, (states, dwells) = lax.scan(step, init, None, length=max_steps) + states, dwells = states.T, dwells.T # (n, max_steps) + is_term = (states == -1) | (states == TERM_IDX) + lengths = jnp.argmax(is_term, axis=1) + 1 + lengths = jnp.where(jnp.any(is_term, axis=1), lengths, max_steps) + return states, dwells, actors, lengths + +def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_products: int, max_steps: int = 40) -> SessionBatch: + if JAX_AVAILABLE: + k1, k2 = jax.random.split(key) + states, dwells, actors, lengths = _sample_sessions_jax(k1, trans.human_T, trans.agent_T, trans.human_dwell, trans.agent_dwell, n_human, n_agent, max_steps) + products = jax.random.randint(k2, (n_human + n_agent,), 0, n_products) + return SessionBatch(np.asarray(states), np.asarray(dwells), np.asarray(products), np.asarray(actors), np.asarray(lengths)) + # numpy fallback + rng = np.random.default_rng(int(key[0]) if hasattr(key, '__getitem__') else 42) + n = n_human + n_agent + actors = np.concatenate([np.zeros(n_human, dtype=np.int32), np.ones(n_agent, dtype=np.int32)]) + products = rng.integers(0, n_products, size=n) + states, dwells = np.full((n, max_steps), -1, dtype=np.int32), np.zeros((n, max_steps), dtype=np.float32) + lengths = np.zeros(n, dtype=np.int32) + for i in range(n): + T = trans.human_T if actors[i] == 0 else trans.agent_T + dp = trans.human_dwell if actors[i] == 0 else trans.agent_dwell + s, t = 0, 0 + while t < max_steps and s != TERM_IDX: + states[i, t] = s + dwells[i, t] = max(0.3, rng.gamma(dp[s, 0], dp[s, 1])) + s = rng.choice(N_STATES, p=T[s]) + t += 1 + lengths[i] = t + return SessionBatch(states, dwells, products, actors, lengths) + +def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray) -> SimResult: + purchased = np.any(batch.states == PURCHASE_IDX, axis=1) + human_mask, agent_mask = batch.actors == 0, batch.actors == 1 + human_purch = purchased & human_mask + agent_purch = purchased & agent_mask + demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32) + demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32) + revenue = float(np.sum(prices[batch.products[purchased]])) + return SimResult(demand_h, demand_a, revenue, int(human_purch.sum()), int(agent_purch.sum()), batch) diff --git a/sim/rl/jax_core/transitions.py b/sim/rl/jax_core/transitions.py new file mode 100644 index 0000000..6aec650 --- /dev/null +++ b/sim/rl/jax_core/transitions.py @@ -0,0 +1,47 @@ +"""Dense transition matrices for JAX Markov chain sampling.""" +from dataclasses import dataclass +import numpy as np + +try: + import jax.numpy as jnp + JAX_AVAILABLE = True +except ImportError: + jnp, JAX_AVAILABLE = np, False + +STATES = ["session_start", "view_item_page", "learn_more_about_item", "add_item_to_cart", "purchase_complete", "session_end"] +S2I = {s: i for i, s in enumerate(STATES)} +N_STATES, TERM_IDX, PURCHASE_IDX, CART_IDX = len(STATES), 5, 4, 3 + +@dataclass +class TransitionData: + human_T: np.ndarray # (6,6) transition probs + agent_T: np.ndarray # (6,6) + human_dwell: np.ndarray # (6,2) shape,scale + agent_dwell: np.ndarray # (6,2) + + def to_jax(self): + if not JAX_AVAILABLE: return self + return TransitionData(*[jnp.array(x) for x in [self.human_T, self.agent_T, self.human_dwell, self.agent_dwell]]) + +def dict_to_dense(d): + m = np.zeros((N_STATES, N_STATES), dtype=np.float32) + for src, dsts in d.items(): + if (i := S2I.get(src)) is not None: + for dst, p in dsts.items(): + if (j := S2I.get(dst)) is not None: m[i,j] = p + m /= np.maximum(m.sum(1, keepdims=True), 1e-8) + m[TERM_IDX] = 0; m[TERM_IDX, TERM_IDX] = 1.0 + return m + +def compile_transitions(human_profile, agent_profile): + def dwell_arr(params): return np.array([[params.get(s, (2.0, 1.0)) for s in STATES]], dtype=np.float32).reshape(N_STATES, 2) + return TransitionData(dict_to_dense(human_profile.transitions), dict_to_dense(agent_profile.transitions), + dwell_arr(human_profile.dwell_params), dwell_arr(agent_profile.dwell_params)) + +def fallback_transitions(): + H = {"session_start": {"view_item_page": .85, "session_end": .15}, "view_item_page": {"learn_more_about_item": .4, "add_item_to_cart": .3, "view_item_page": .2, "session_end": .1}, + "learn_more_about_item": {"add_item_to_cart": .5, "view_item_page": .3, "session_end": .2}, "add_item_to_cart": {"purchase_complete": .6, "view_item_page": .25, "session_end": .15}, "purchase_complete": {"session_end": 1.0}} + A = {"session_start": {"view_item_page": .9, "session_end": .1}, "view_item_page": {"learn_more_about_item": .5, "add_item_to_cart": .25, "view_item_page": .15, "session_end": .1}, + "learn_more_about_item": {"add_item_to_cart": .4, "view_item_page": .4, "session_end": .2}, "add_item_to_cart": {"purchase_complete": .5, "view_item_page": .3, "session_end": .2}, "purchase_complete": {"session_end": 1.0}} + dwell = np.full((N_STATES, 2), [2.0, 1.0], dtype=np.float32) + return TransitionData(dict_to_dense(H), dict_to_dense(A), dwell.copy(), dwell.copy())