mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
feat: translating features to jax
This commit is contained in:
69
sim/rl/jax_core/features.py
Normal file
69
sim/rl/jax_core/features.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Vectorized session feature extraction."""
|
||||
import numpy as np
|
||||
from .transitions import N_STATES, PURCHASE_IDX, CART_IDX
|
||||
from .simulation import SessionBatch
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
JAX_AVAILABLE = True
|
||||
except ImportError:
|
||||
jnp, JAX_AVAILABLE = np, False
|
||||
def jit(f): return f
|
||||
|
||||
@jit
def extract_features(states, dwells, lengths):
    """Build the per-session feature matrix.

    Steps at positions ``>= lengths[i]`` are treated as padding and ignored.

    Args:
        states: (n_sess, max_len) integer state indices.
        dwells: (n_sess, max_len) dwell times.
        lengths: (n_sess,) number of valid steps per session.

    Returns:
        (n_sess, 9) array with columns, in order: duration, avg_dwell,
        total steps, velocity, views, carts, purchases, learn, conversion.
    """
    _, horizon = states.shape
    in_session = jnp.arange(horizon)[None, :] < lengths[:, None]

    def occurrences(state_idx):
        # Number of valid steps each session spends in ``state_idx``.
        hits = (states == state_idx) & in_session
        return jnp.sum(hits, axis=1).astype(jnp.float32)

    time_spent = jnp.sum(dwells * in_session, axis=1)
    n_steps = lengths.astype(jnp.float32)
    n_view = occurrences(1)
    n_learn = occurrences(2)
    n_cart = occurrences(3)
    n_buy = occurrences(4)

    # The 1e-6 terms keep the ratios finite for empty sessions.
    steps_per_time = n_steps / (time_spent + 1e-6)
    buy_per_view = n_buy / (n_view + 1e-6)
    mean_dwell = time_spent / (n_steps + 1e-6)

    columns = [
        time_spent,
        mean_dwell,
        n_steps,
        steps_per_time,
        n_view,
        n_cart,
        n_buy,
        n_learn,
        buy_per_view,
    ]
    return jnp.stack(columns, axis=1)
|
||||
|
||||
def session_features(batch: SessionBatch) -> np.ndarray:
    """Return the (n_sess, 9) feature matrix of ``batch`` as a NumPy array.

    Uses the jitted ``extract_features`` when JAX is importable; otherwise
    computes the same nine columns with plain NumPy.
    """
    if JAX_AVAILABLE:
        device_args = [jnp.array(a) for a in (batch.states, batch.dwells, batch.lengths)]
        return np.asarray(extract_features(*device_args))

    # NumPy fallback: same column order as extract_features.
    states, dwells, lengths = batch.states, batch.dwells, batch.lengths
    _, horizon = states.shape
    in_session = np.arange(horizon)[None, :] < lengths[:, None]
    time_spent = np.sum(dwells * in_session, axis=1)
    n_steps = lengths.astype(np.float32)

    def occurrences(state_idx):
        # Number of valid steps each session spends in ``state_idx``.
        return np.sum((states == state_idx) & in_session, axis=1).astype(np.float32)

    n_view, n_learn, n_cart, n_buy = (occurrences(k) for k in (1, 2, 3, 4))
    columns = [
        time_spent,
        time_spent / (n_steps + 1e-6),
        n_steps,
        n_steps / (time_spent + 1e-6),
        n_view,
        n_cart,
        n_buy,
        n_learn,
        n_buy / (n_view + 1e-6),
    ]
    return np.stack(columns, axis=1)
|
||||
|
||||
@jit
def session_transitions(states, lengths, n_states=N_STATES):
    """Compute row-normalised empirical transition matrices, one per session.

    A step ``t -> t+1`` is counted only when both positions lie inside the
    session (``t + 1 < lengths[i]``) and both state indices are non-negative
    (``-1`` marks padding).

    Args:
        states: (n_sess, max_len) integer state indices, -1 for padding.
        lengths: (n_sess,) number of valid steps per session.
        n_states: size of the state space. Leave at its default when calling
            the jitted function: it sizes ``jnp.eye`` and therefore must be
            a concrete Python int.

    Returns:
        (n_sess, n_states, n_states) array. Each row with at least one
        observed transition sums to ~1; rows without any stay all zero.
    """
    _, max_len = states.shape
    src, dst = states[:, :-1], states[:, 1:]
    in_session = jnp.arange(max_len - 1)[None, :] < (lengths[:, None] - 1)
    valid = in_session & (src >= 0) & (dst >= 0)

    # Clamp only so the one-hot lookup never indexes out of range; clamped
    # padding steps are removed by ``valid`` below.
    src_c = jnp.clip(src, 0, n_states - 1)
    dst_c = jnp.clip(dst, 0, n_states - 1)

    eye = jnp.eye(n_states)
    # The mask must be applied per step, *before* summing over time.
    # Applying it after the sum would count padded steps as 0 -> 0
    # transitions and merely rescale the whole matrix.
    src_onehot = eye[src_c] * valid[..., None]  # (n_sess, max_len-1, n_states)
    dst_onehot = eye[dst_c]                     # (n_sess, max_len-1, n_states)

    # Batched sum of outer products over the time axis; avoids unrolling a
    # Python loop over sessions inside the jitted trace.
    counts = jnp.einsum("nts,ntd->nsd", src_onehot, dst_onehot)
    row_sums = counts.sum(axis=-1, keepdims=True)
    return counts / (row_sums + 1e-10)
|
||||
|
||||
def compute_session_transitions(batch: SessionBatch) -> np.ndarray:
    """Return per-session row-normalised transition matrices as NumPy.

    Delegates to ``session_transitions`` when JAX is importable; otherwise
    counts transitions with an explicit loop. The result has shape
    (n_sess, N_STATES, N_STATES).
    """
    if JAX_AVAILABLE:
        states_j = jnp.array(batch.states)
        lengths_j = jnp.array(batch.lengths)
        return np.asarray(session_transitions(states_j, lengths_j))

    # NumPy fallback
    n_sess, _ = batch.states.shape
    counts = np.zeros((n_sess, N_STATES, N_STATES), dtype=np.float32)
    for i in range(n_sess):
        row = batch.states[i]
        for t in range(batch.lengths[i] - 1):
            here, there = row[t], row[t + 1]
            # Negative indices mark padding and never contribute.
            if here < 0 or there < 0:
                continue
            counts[i, here, there] += 1
    totals = counts.sum(axis=-1, keepdims=True)
    return counts / (totals + 1e-10)
|
||||
43
sim/rl/jax_core/separability.py
Normal file
43
sim/rl/jax_core/separability.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Vectorized KL divergence for separability scoring."""
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
JAX_AVAILABLE = True
|
||||
except ImportError:
|
||||
jnp, JAX_AVAILABLE = np, False
|
||||
def jit(f): return f
|
||||
|
||||
@jit
def batch_kl(P, Q_human, Q_agent, eps=1e-10):
    """Compute KL(P || Q) of every matrix in a batch against two references.

    ``P`` is smoothed by ``eps`` and re-normalised along the last axis; the
    references are smoothed by ``eps`` but not re-normalised.

    Args:
        P: (n, s, s) batch of transition matrices.
        Q_human: (s, s) human reference matrix.
        Q_agent: (s, s) agent reference matrix.
        eps: additive smoothing applied to P and to both references.

    Returns:
        Tuple ``(delta_h, delta_a)``, each of shape (n,).
    """
    smoothed = P + eps
    p = smoothed / smoothed.sum(axis=-1, keepdims=True)

    def kl_to(reference):
        # Broadcast the single reference across the batch dimension.
        q = reference[None] + eps
        return jnp.sum(p * jnp.log(p / q), axis=(1, 2))

    return kl_to(Q_human), kl_to(Q_agent)
|
||||
|
||||
def compute_divergences(session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return each session's KL divergence from the human and agent prototypes.

    Uses the jitted ``batch_kl`` when JAX is importable, otherwise an
    equivalent NumPy computation. Both results have one entry per session.
    """
    if JAX_AVAILABLE:
        kl_h, kl_a = batch_kl(
            jnp.array(session_trans),
            jnp.array(ref_human),
            jnp.array(ref_agent),
        )
        return np.asarray(kl_h), np.asarray(kl_a)

    # numpy fallback
    eps = 1e-10
    smoothed = session_trans + eps
    p = smoothed / smoothed.sum(axis=-1, keepdims=True)

    def kl_to(reference):
        # References are smoothed by eps but not re-normalised.
        q = reference[None] + eps
        return np.sum(p * np.log(p / q), axis=(1, 2))

    return kl_to(ref_human), kl_to(ref_agent)
|
||||
|
||||
def estimate_alpha_batch(prob_agent: np.ndarray, delta_h: np.ndarray, delta_a: np.ndarray, temp: float = 1.0) -> np.ndarray:
    """Vectorized alpha estimation from classifier probs and divergences.

    The divergence ratio ``delta_a / (delta_h + delta_a)`` is averaged with
    ``prob_agent``; entries whose total divergence is not above 1e-8 use a
    neutral ratio of 0.5 instead.

    Args:
        prob_agent: classifier agent probabilities.
        delta_h: divergences from the human prototype.
        delta_a: divergences from the agent prototype.
        temp: sigmoid temperature. When ``temp <= 0`` the blended value is
            returned directly (clipped to [0, 1]) without the sigmoid.

    Returns:
        Array of alpha estimates clipped to [0, 1].
    """
    mass = delta_h + delta_a
    has_mass = mass > 1e-8
    # np.where evaluates both branches eagerly, so dividing by the raw mass
    # would emit divide/invalid warnings for zero-mass entries even though
    # their quotient is discarded. Divide by a harmless 1.0 there instead.
    safe_mass = np.where(has_mass, mass, 1.0)
    ratio = np.where(has_mass, delta_a / safe_mass, 0.5)
    blended = 0.5 * prob_agent + 0.5 * ratio
    if temp <= 0:
        return np.clip(blended, 0.0, 1.0)
    # Logistic squashing centred on 0.5.
    return np.clip(1.0 / (1.0 + np.exp(-temp * (blended - 0.5))), 0.0, 1.0)
|
||||
92
sim/rl/jax_core/simulation.py
Normal file
92
sim/rl/jax_core/simulation.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Vectorized Markov chain session sampling with JAX."""
|
||||
from typing import NamedTuple, Tuple
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
try:
|
||||
import jax, jax.numpy as jnp
|
||||
from jax import lax
|
||||
JAX_AVAILABLE = True
|
||||
except ImportError:
|
||||
JAX_AVAILABLE = False
|
||||
|
||||
from .transitions import TransitionData, N_STATES, TERM_IDX, PURCHASE_IDX, CART_IDX
|
||||
|
||||
class SessionBatch(NamedTuple):
    """Padded batch of sampled sessions; every field has one row per session."""
    states: np.ndarray    # (n_sess, max_len) state indices, -1=padding
    dwells: np.ndarray    # (n_sess, max_len) dwell times
    products: np.ndarray  # (n_sess,) product index per session
    actors: np.ndarray    # (n_sess,) 0=human, 1=agent
    lengths: np.ndarray   # (n_sess,) actual session length
|
||||
|
||||
class SimResult(NamedTuple):
    """Aggregate outcome of one simulated batch, as assembled by compute_metrics."""
    demand_human: np.ndarray  # per-product count of purchasing human sessions
    demand_agent: np.ndarray  # per-product count of purchasing agent sessions
    revenue: float            # sum of prices over all purchasing sessions
    n_human_purchases: int    # number of human sessions containing a purchase
    n_agent_purchases: int    # number of agent sessions containing a purchase
    sessions: SessionBatch    # the batch the metrics were computed from
|
||||
|
||||
if JAX_AVAILABLE:
    # n_human, n_agent and max_steps are static: they determine array shapes
    # and the scan length, so each distinct combination is compiled separately.
    @partial(jax.jit, static_argnums=(5,6,7))
    def _sample_sessions_jax(key, T_human, T_agent, dwell_human, dwell_agent, n_human, n_agent, max_steps):
        """Sample all sessions in lockstep with a fixed-length ``lax.scan``.

        Human sessions occupy the first ``n_human`` rows, agent sessions the rest.
        Every session starts in state 0.

        Returns ``(states, dwells, actors, lengths)`` where ``states`` and
        ``dwells`` have shape (n, max_steps), and ``actors`` / ``lengths`` have
        shape (n,).

        NOTE(review): each emitted entry of ``states`` is the state *entered* at
        that step (the initial state 0 is never recorded), while the matching
        ``dwells`` entry is drawn from the parameters of the state being *left*.
        The NumPy fallback in ``sample_sessions`` records the state being left
        instead, so the two paths are shifted by one step relative to each other
        -- confirm which convention downstream consumers expect.
        """
        n = n_human + n_agent
        # Only k3 is used below; the four-way split fixes the random stream.
        k1, k2, k3, k4 = jax.random.split(key, 4)
        actors = jnp.concatenate([jnp.zeros(n_human, dtype=jnp.int32), jnp.ones(n_agent, dtype=jnp.int32)])
        # Per-session copies of the actor-specific parameters.
        T = jnp.where(actors[:,None,None]==0, T_human[None], T_agent[None])  # (n,6,6)
        dwell_p = jnp.where(actors[:,None,None]==0, dwell_human[None], dwell_agent[None])  # (n,6,2)

        def step(carry, _):
            # carry: current state per session, still-running flag, PRNG key.
            s, active, k = carry
            k, k1, k2 = jax.random.split(k, 3)
            probs = T[jnp.arange(n), s]  # (n,6)
            # The 1e-10 keeps log() finite for zero-probability transitions.
            nxt = jax.random.categorical(k1, jnp.log(probs + 1e-10))
            # Sessions that already finished emit -1 (padding) from here on.
            nxt = jnp.where(active, nxt, -1)
            shape = dwell_p[jnp.arange(n), s, 0]
            scale = dwell_p[jnp.arange(n), s, 1]
            # Dwell is floored at 0.3 and is NOT masked by ``active``, so padded
            # positions still carry a positive dwell value.
            dwell = jnp.maximum(0.3, jax.random.gamma(k2, shape) * scale)
            still = active & (nxt != TERM_IDX) & (nxt >= 0)
            return (nxt, still, k), (nxt, dwell)

        init = (jnp.zeros(n, dtype=jnp.int32), jnp.ones(n, dtype=jnp.bool_), k3)
        _, (states, dwells) = lax.scan(step, init, None, length=max_steps)
        # scan stacks its outputs along time; transpose to session-major.
        states, dwells = states.T, dwells.T  # (n, max_steps)
        is_term = (states == -1) | (states == TERM_IDX)
        # Length runs up to and including the first terminal/padding entry ...
        lengths = jnp.argmax(is_term, axis=1) + 1
        # ... or is max_steps when the session never terminated.
        lengths = jnp.where(jnp.any(is_term, axis=1), lengths, max_steps)
        return states, dwells, actors, lengths
|
||||
|
||||
def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_products: int, max_steps: int = 40) -> SessionBatch:
    """Sample a batch of human and agent sessions from the chains in ``trans``.

    Args:
        key: With JAX available, a JAX PRNG key. In the NumPy fallback, an
            integer seed or an array-like whose first element is the seed;
            a key that cannot be read as an integer falls back to seed 42.
        trans: transition matrices and dwell (shape, scale) parameters.
        n_human: number of human sessions (placed first in the batch).
        n_agent: number of agent sessions (placed after the human ones).
        n_products: products are drawn uniformly from ``[0, n_products)``.
        max_steps: maximum number of recorded steps per session.

    Returns:
        A ``SessionBatch`` of NumPy arrays.
    """
    if JAX_AVAILABLE:
        chain_key, product_key = jax.random.split(key)
        states, dwells, actors, lengths = _sample_sessions_jax(
            chain_key, trans.human_T, trans.agent_T, trans.human_dwell, trans.agent_dwell,
            n_human, n_agent, max_steps,
        )
        products = jax.random.randint(product_key, (n_human + n_agent,), 0, n_products)
        return SessionBatch(np.asarray(states), np.asarray(dwells), np.asarray(products), np.asarray(actors), np.asarray(lengths))

    # numpy fallback
    # Accept plain ints and NumPy scalars as seeds, not only indexable keys:
    # a bare int has no __getitem__ and would otherwise be silently replaced
    # by the default seed, and a NumPy scalar cannot be indexed with [0].
    try:
        flat_key = np.asarray(key).ravel()
        seed = int(flat_key[0]) if flat_key.size else 42
    except (TypeError, ValueError):
        seed = 42
    rng = np.random.default_rng(seed)

    n = n_human + n_agent
    actors = np.concatenate([np.zeros(n_human, dtype=np.int32), np.ones(n_agent, dtype=np.int32)])
    products = rng.integers(0, n_products, size=n)
    states = np.full((n, max_steps), -1, dtype=np.int32)
    dwells = np.zeros((n, max_steps), dtype=np.float32)
    lengths = np.zeros(n, dtype=np.int32)
    for i in range(n):
        is_human = actors[i] == 0
        T = trans.human_T if is_human else trans.agent_T
        dp = trans.human_dwell if is_human else trans.agent_dwell
        # Walk the chain from state 0, recording each state that is left
        # together with its dwell; the terminal state itself is not recorded.
        s, t = 0, 0
        while t < max_steps and s != TERM_IDX:
            states[i, t] = s
            dwells[i, t] = max(0.3, rng.gamma(dp[s, 0], dp[s, 1]))
            s = rng.choice(N_STATES, p=T[s])
            t += 1
        lengths[i] = t
    return SessionBatch(states, dwells, products, actors, lengths)
|
||||
|
||||
def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray) -> SimResult:
    """Aggregate purchases in ``batch`` into demand, revenue and counts.

    A session counts as a purchase when any of its recorded states equals
    ``PURCHASE_IDX``. Revenue is the sum of ``prices`` over the products of
    purchasing sessions. ``unit_cost`` is accepted but not used here.
    """
    is_human = batch.actors == 0
    is_agent = batch.actors == 1
    bought = np.any(batch.states == PURCHASE_IDX, axis=1)
    human_buys = bought & is_human
    agent_buys = bought & is_agent

    n_products = len(prices)

    def demand(selection):
        # Purchases per product index, padded out to one bin per price.
        picked = batch.products[selection]
        return np.bincount(picked, minlength=n_products).astype(np.float32)

    demand_human = demand(human_buys)
    demand_agent = demand(agent_buys)
    revenue = float(np.sum(prices[batch.products[bought]]))
    return SimResult(
        demand_human,
        demand_agent,
        revenue,
        int(human_buys.sum()),
        int(agent_buys.sum()),
        batch,
    )
|
||||
47
sim/rl/jax_core/transitions.py
Normal file
47
sim/rl/jax_core/transitions.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Dense transition matrices for JAX Markov chain sampling."""
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
JAX_AVAILABLE = True
|
||||
except ImportError:
|
||||
jnp, JAX_AVAILABLE = np, False
|
||||
|
||||
# Ordered state space of the session Markov chain; a state's list position is
# its integer index everywhere else in the package.
STATES = [
    "session_start",
    "view_item_page",
    "learn_more_about_item",
    "add_item_to_cart",
    "purchase_complete",
    "session_end",
]
# Reverse lookup: state name -> index.
S2I = dict(zip(STATES, range(len(STATES))))
N_STATES = len(STATES)
# Indices of the states that other modules refer to by role.
TERM_IDX = S2I["session_end"]
PURCHASE_IDX = S2I["purchase_complete"]
CART_IDX = S2I["add_item_to_cart"]
|
||||
|
||||
@dataclass
class TransitionData:
    """Dense Markov-chain parameters for the human and agent actors."""
    human_T: np.ndarray      # (6,6) transition probs
    agent_T: np.ndarray      # (6,6)
    human_dwell: np.ndarray  # (6,2) shape,scale
    agent_dwell: np.ndarray  # (6,2)

    def to_jax(self):
        """Return a copy whose arrays are JAX arrays, or ``self`` without JAX."""
        if JAX_AVAILABLE:
            return TransitionData(
                human_T=jnp.array(self.human_T),
                agent_T=jnp.array(self.agent_T),
                human_dwell=jnp.array(self.human_dwell),
                agent_dwell=jnp.array(self.agent_dwell),
            )
        return self
|
||||
|
||||
def dict_to_dense(d):
    """Convert a nested ``{src: {dst: prob}}`` mapping to a dense matrix.

    State names missing from ``S2I`` are skipped. Rows are normalised to sum
    to 1; rows with no outgoing mass stay all zero. The terminal row is then
    overwritten so that the terminal state is absorbing.

    Returns:
        (N_STATES, N_STATES) float32 array.
    """
    dense = np.zeros((N_STATES, N_STATES), dtype=np.float32)
    for src_name, targets in d.items():
        row = S2I.get(src_name)
        if row is None:
            continue
        for dst_name, prob in targets.items():
            col = S2I.get(dst_name)
            if col is not None:
                dense[row, col] = prob

    # The floor on the denominator avoids dividing empty rows by zero.
    row_mass = dense.sum(1, keepdims=True)
    dense /= np.maximum(row_mass, 1e-8)

    dense[TERM_IDX] = 0
    dense[TERM_IDX, TERM_IDX] = 1.0
    return dense
|
||||
|
||||
def compile_transitions(human_profile, agent_profile):
    """Build a ``TransitionData`` from a human and an agent profile.

    Each profile supplies ``transitions`` (nested name -> name -> prob
    mapping) and ``dwell_params`` (state name -> (shape, scale)). States
    without dwell parameters default to ``(2.0, 1.0)``.
    """
    default_dwell = (2.0, 1.0)

    def dwell_arr(params):
        # One (shape, scale) row per state, in STATES order.
        rows = [params.get(state, default_dwell) for state in STATES]
        return np.array([rows], dtype=np.float32).reshape(N_STATES, 2)

    human_T = dict_to_dense(human_profile.transitions)
    agent_T = dict_to_dense(agent_profile.transitions)
    human_dwell = dwell_arr(human_profile.dwell_params)
    agent_dwell = dwell_arr(agent_profile.dwell_params)
    return TransitionData(human_T, agent_T, human_dwell, agent_dwell)
|
||||
|
||||
def fallback_transitions():
    """Return built-in ``TransitionData`` with hard-coded default chains.

    Both actors share the same dwell parameters, (2.0, 1.0) for every state;
    each actor receives its own copy of that array.
    """
    human = {
        "session_start": {"view_item_page": .85, "session_end": .15},
        "view_item_page": {
            "learn_more_about_item": .4,
            "add_item_to_cart": .3,
            "view_item_page": .2,
            "session_end": .1,
        },
        "learn_more_about_item": {
            "add_item_to_cart": .5,
            "view_item_page": .3,
            "session_end": .2,
        },
        "add_item_to_cart": {
            "purchase_complete": .6,
            "view_item_page": .25,
            "session_end": .15,
        },
        "purchase_complete": {"session_end": 1.0},
    }
    agent = {
        "session_start": {"view_item_page": .9, "session_end": .1},
        "view_item_page": {
            "learn_more_about_item": .5,
            "add_item_to_cart": .25,
            "view_item_page": .15,
            "session_end": .1,
        },
        "learn_more_about_item": {
            "add_item_to_cart": .4,
            "view_item_page": .4,
            "session_end": .2,
        },
        "add_item_to_cart": {
            "purchase_complete": .5,
            "view_item_page": .3,
            "session_end": .2,
        },
        "purchase_complete": {"session_end": 1.0},
    }
    shared_dwell = np.full((N_STATES, 2), [2.0, 1.0], dtype=np.float32)
    human_T = dict_to_dense(human)
    agent_T = dict_to_dense(agent)
    return TransitionData(human_T, agent_T, shared_dwell.copy(), shared_dwell.copy())
|
||||
Reference in New Issue
Block a user