mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
70 lines
3.1 KiB
Python
70 lines
3.1 KiB
Python
"""Vectorized session feature extraction."""
|
|
import numpy as np
|
|
from .transitions import N_STATES, PURCHASE_IDX, CART_IDX
|
|
from .simulation import SessionBatch
|
|
|
|
# Prefer JAX when it is installed: ``jnp`` is ``jax.numpy`` and ``jit`` is
# ``jax.jit``. Otherwise fall back to plain NumPy and make ``jit`` an identity
# decorator, so the array code below runs unchanged (eagerly) on NumPy.
try:
    import jax.numpy as jnp
    from jax import jit
except ImportError:
    jnp = np
    JAX_AVAILABLE = False

    def jit(f):
        """Identity stand-in for ``jax.jit`` used when JAX is not installed."""
        return f
else:
    JAX_AVAILABLE = True
|
|
|
|
@jit
def extract_features(states, dwells, lengths):
    """Build the per-session feature matrix.

    Args:
        states: ``(n_sess, max_len)`` array of state indices. Only the first
            ``lengths[i]`` entries of row ``i`` are considered.
        dwells: ``(n_sess, max_len)`` array of per-step dwell values, masked
            the same way as ``states``.
        lengths: ``(n_sess,)`` array with the number of valid steps per session.

    Returns:
        ``(n_sess, 9)`` array whose columns are, in order: duration, avg_dwell,
        total, velocity, views, carts, purchases, learn, conversion.
    """
    _, steps = states.shape
    # True where step t is inside session i (t < lengths[i]).
    live = jnp.arange(steps)[None, :] < lengths[:, None]

    def tally(state_idx):
        # Valid steps spent in ``state_idx``, per session, as float32.
        return jnp.sum((states == state_idx) & live, axis=1).astype(jnp.float32)

    duration = jnp.sum(dwells * live, axis=1)
    total = lengths.astype(jnp.float32)
    # NOTE(review): CART_IDX / PURCHASE_IDX are imported by this module but the
    # literal indices 3 / 4 are used here; confirm the two agree.
    views = tally(1)
    learn = tally(2)
    carts = tally(3)
    purchases = tally(4)

    eps = 1e-6  # keeps the ratios finite for empty/zero-dwell sessions
    columns = (
        duration,
        duration / (total + eps),   # avg_dwell: dwell per valid step
        total,
        total / (duration + eps),   # velocity: valid steps per unit dwell
        views,
        carts,
        purchases,
        learn,
        purchases / (views + eps),  # conversion: purchases per view
    )
    return jnp.stack(columns, axis=1)
|
|
|
|
def session_features(batch: SessionBatch) -> np.ndarray:
    """Return the ``(n_sess, 9)`` per-session feature matrix for ``batch``.

    Always routes through ``extract_features``. When JAX is missing, ``jnp`` is
    NumPy and ``jit`` is an identity decorator, so the same code simply runs
    eagerly. No separate NumPy copy of the feature logic is needed, and none can
    drift out of sync.

    Returns:
        NumPy array whose columns are, in order: duration, avg_dwell, total,
        velocity, views, carts, purchases, learn, conversion.
    """
    features = extract_features(
        jnp.asarray(batch.states),
        jnp.asarray(batch.dwells),
        jnp.asarray(batch.lengths),
    )
    return np.asarray(features)
|
|
|
|
@jit
def session_transitions(states, lengths, n_states=N_STATES):
    """Compute row-normalised empirical transition matrices per session.

    Args:
        states: ``(n_sess, max_len)`` array of state indices; ``-1`` marks
            padding.
        lengths: ``(n_sess,)`` array with the number of valid steps per session.
        n_states: side length of each transition matrix.

    Returns:
        ``(n_sess, n_states, n_states)`` array. Entry ``[i, s, d]`` is session
        ``i``'s count of valid ``s -> d`` steps divided by its count of valid
        steps leaving ``s`` (plus ``1e-10``), so rows with no outgoing steps
        are zero.
    """
    _, max_len = states.shape
    src, dst = states[:, :-1], states[:, 1:]
    # A step t -> t+1 counts only if t+1 < lengths[i] and neither end is padding.
    in_range = jnp.arange(max_len - 1)[None, :] < (lengths[:, None] - 1)
    valid = (in_range & (src >= 0) & (dst >= 0)).astype(jnp.float32)
    # Clipping only keeps the one-hot lookup in bounds; clipped padding steps
    # carry weight 0 through ``valid``.
    eye = jnp.eye(n_states, dtype=jnp.float32)
    src_oh = eye[jnp.clip(src, 0, n_states - 1)]  # (n_sess, T, n_states)
    dst_oh = eye[jnp.clip(dst, 0, n_states - 1)]  # (n_sess, T, n_states)
    # Weight every step by its validity *before* reducing over time, so
    # padding and out-of-length steps never contribute to the counts.
    counts = jnp.einsum("nti,ntj,nt->nij", src_oh, dst_oh, valid)
    row_sums = counts.sum(axis=-1, keepdims=True)
    return counts / (row_sums + 1e-10)
|
|
|
|
def compute_session_transitions(batch: SessionBatch) -> np.ndarray:
    """Return per-session row-normalised transition matrices for ``batch``.

    With JAX installed this defers to the jit-compiled ``session_transitions``.
    Otherwise it counts transitions with one vectorised NumPy scatter-add
    instead of looping over every session and step in Python.

    A step ``t -> t+1`` of session ``i`` is counted only when
    ``t + 1 < batch.lengths[i]`` and both states are non-negative (``-1`` is
    padding). Each row is divided by its total (plus ``1e-10``), so rows with
    no outgoing transitions stay zero.

    Returns:
        ``(n_sess, N_STATES, N_STATES)`` array.
    """
    if JAX_AVAILABLE:
        return np.asarray(session_transitions(jnp.array(batch.states), jnp.array(batch.lengths)))

    # NumPy fallback.
    states = np.asarray(batch.states)
    lengths = np.asarray(batch.lengths)
    n, max_len = states.shape
    src, dst = states[:, :-1], states[:, 1:]
    valid = (
        (np.arange(max_len - 1)[None, :] < (lengths[:, None] - 1))
        & (src >= 0)
        & (dst >= 0)
    )
    sess_idx = np.broadcast_to(np.arange(n)[:, None], src.shape)
    trans = np.zeros((n, N_STATES, N_STATES), dtype=np.float32)
    # np.add.at accumulates repeated (session, src, dst) triples; a plain
    # fancy-indexed ``+= 1`` would count each duplicate only once.
    np.add.at(trans, (sess_idx[valid], src[valid], dst[valid]), 1)
    row_sums = trans.sum(axis=-1, keepdims=True)
    return trans / (row_sums + 1e-10)
|