intorducing jax for computation

This commit is contained in:
2026-01-22 21:02:10 +01:00
parent 40e0b201e6
commit a033e77697
2 changed files with 76 additions and 12 deletions

View File

@@ -9,6 +9,13 @@ from typing import Optional, Dict, Any, List, Tuple
from lib.separability import load_artifacts, score_session, estimate_alpha from lib.separability import load_artifacts, score_session, estimate_alpha
from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions
try:
import jax
from sim.rl.jax_core import JAX_AVAILABLE, compile_transitions, fallback_transitions, sample_sessions, compute_metrics
from sim.rl.jax_core import session_features, compute_session_transitions, compute_divergences, estimate_alpha_batch
except ImportError:
JAX_AVAILABLE = False
# "learner" agent learning to optimize pricing # "learner" agent learning to optimize pricing
# "agent" part of environment creating demand signals that learner processes # "agent" part of environment creating demand signals that learner processes
@@ -20,9 +27,9 @@ class BusinessLogicConstraints():
system_max_price: float = 500.0 system_max_price: float = 500.0
system_min_price: float = 1.0 system_min_price: float = 1.0
product_catalogue_size: int = 100 product_catalogue_size: int = 100
episode_length: int = 200 episode_length: int = 2000
sessions_per_step: int = 250 sessions_per_step: int = 250
agent_share: float = 0.5 agent_share: float = 0.2
agent_recon_multiplier: float = 6.0 agent_recon_multiplier: float = 6.0
agent_purchase_probability: float = 0.20 agent_purchase_probability: float = 0.20
coi_strength: float = 0.25 coi_strength: float = 0.25
@@ -423,9 +430,10 @@ class CommercePlatform:
class PHANTOMEnv(gym.Env): class PHANTOMEnv(gym.Env):
metadata = {"render_modes": []} metadata = {"render_modes": []}
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None): def __init__(self, constraints: Optional[BusinessLogicConstraints] = None, use_jax: bool = True):
super().__init__() super().__init__()
self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints() self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints()
self.use_jax = use_jax and JAX_AVAILABLE
self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment, self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment,
high=self.constraints.max_price_adjustment, high=self.constraints.max_price_adjustment,
shape=(self.constraints.product_catalogue_size,), dtype=np.float32) shape=(self.constraints.product_catalogue_size,), dtype=np.float32)
@@ -442,8 +450,8 @@ class PHANTOMEnv(gym.Env):
dtype=np.float32), dtype=np.float32),
}), }),
"market": spaces.Dict({ "market": spaces.Dict({
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), # estimated agent share "alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32), # recent revenue "revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32),
"conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), "conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
"price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), "price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
}), }),
@@ -458,12 +466,27 @@ class PHANTOMEnv(gym.Env):
self.t = 0 self.t = 0
self._prev_prices: Optional[np.ndarray] = None self._prev_prices: Optional[np.ndarray] = None
self.state: Dict[str, Any] = {} self.state: Dict[str, Any] = {}
self._jax_key = None
self._jax_trans = None
if self.use_jax:
self._jax_key = jax.random.PRNGKey(self.constraints.seed)
self._init_jax_transitions()
def _init_jax_transitions(self):
try:
human_profile = _load_behavioral_profile("humans", np.ones(self.constraints.product_catalogue_size) * 0.1)
agent_profile = _load_behavioral_profile("agents", np.ones(self.constraints.product_catalogue_size) * 0.1)
self._jax_trans = compile_transitions(human_profile, agent_profile).to_jax()
except Exception:
self._jax_trans = fallback_transitions().to_jax()
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed) super().reset(seed=seed)
if seed is not None: if seed is not None:
self._rng = np.random.default_rng(seed) self._rng = np.random.default_rng(seed)
self.commerce_platform._rng = np.random.default_rng(seed) self.commerce_platform._rng = np.random.default_rng(seed)
if self.use_jax:
self._jax_key = jax.random.PRNGKey(seed)
self.commerce_platform.alpha_hat = self.constraints.agent_share self.commerce_platform.alpha_hat = self.constraints.agent_share
self.t = 0 self.t = 0
init_prices = self._rng.uniform( init_prices = self._rng.uniform(
@@ -493,6 +516,20 @@ class PHANTOMEnv(gym.Env):
} }
return self.state, {} return self.state, {}
def _step_jax(self, new_prices: np.ndarray) -> Tuple[Dict, Dict]:
self._jax_key, subkey = jax.random.split(self._jax_key)
alpha = float(np.clip(self.commerce_platform.alpha_hat, 0.0, 0.95))
n_agent = max(1, int(self.constraints.sessions_per_step * alpha))
n_human = max(1, self.constraints.sessions_per_step - n_agent)
batch = sample_sessions(subkey, self._jax_trans, n_human, n_agent, len(new_prices))
sim = compute_metrics(batch, new_prices, self.commerce_platform.unit_cost, self.commerce_platform.base_price)
result = {"revenue_observed": sim.revenue, "revenue_oracle": sim.revenue_oracle,
"agent_loss": sim.agent_loss, "coi": sim.coi, "look_to_book": sim.look_to_book,
"mean_sale_price": sim.mean_sale_price, "true_human_purchases": sim.n_human_purchases,
"true_agent_purchases": sim.n_agent_purchases}
diagnostics = {"demand_human": sim.demand_human, "demand_agent": sim.demand_agent, "alpha_hat": alpha}
return result, diagnostics
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
self.t += 1 self.t += 1
base_prices = self.state["elasticity"]["price"].astype(np.float32) base_prices = self.state["elasticity"]["price"].astype(np.float32)
@@ -501,6 +538,9 @@ class PHANTOMEnv(gym.Env):
self.constraints.system_max_price).astype(np.float32) self.constraints.system_max_price).astype(np.float32)
self.state["elasticity"]["price"] = new_prices self.state["elasticity"]["price"] = new_prices
if self.use_jax:
result, diagnostics = self._step_jax(new_prices)
else:
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices) interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
result = self.commerce_platform.compute_interaction_features(interactions_df) result = self.commerce_platform.compute_interaction_features(interactions_df)
COI = float(result.get("coi", 0.0)) COI = float(result.get("coi", 0.0))

View File

@@ -23,6 +23,11 @@ class SimResult(NamedTuple):
demand_human: np.ndarray demand_human: np.ndarray
demand_agent: np.ndarray demand_agent: np.ndarray
revenue: float revenue: float
revenue_oracle: float
agent_loss: float
coi: float
look_to_book: float
mean_sale_price: float
n_human_purchases: int n_human_purchases: int
n_agent_purchases: int n_agent_purchases: int
sessions: SessionBatch sessions: SessionBatch
@@ -81,12 +86,31 @@ def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_pr
lengths[i] = t lengths[i] = t
return SessionBatch(states, dwells, products, actors, lengths) return SessionBatch(states, dwells, products, actors, lengths)
def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray) -> SimResult: def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray, base_price: np.ndarray) -> SimResult:
purchased = np.any(batch.states == PURCHASE_IDX, axis=1) purchased = np.any(batch.states == PURCHASE_IDX, axis=1)
human_mask, agent_mask = batch.actors == 0, batch.actors == 1 human_mask, agent_mask = batch.actors == 0, batch.actors == 1
human_purch = purchased & human_mask human_purch, agent_purch = purchased & human_mask, purchased & agent_mask
agent_purch = purchased & agent_mask
demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32) demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32)
demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32) demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32)
revenue = float(np.sum(prices[batch.products[purchased]])) # revenue and oracle
return SimResult(demand_h, demand_a, revenue, int(human_purch.sum()), int(agent_purch.sum()), batch) purch_products = batch.products[purchased]
revenue = float(np.sum(prices[purch_products]))
revenue_oracle = float(np.sum(base_price[purch_products]))
# agent loss: base_price - price_paid for agent purchases (agents gaming the system)
agent_products = batch.products[agent_purch]
agent_loss = float(np.sum(base_price[agent_products] - prices[agent_products]))
# COI: margin - expected_premium*0.5 for human purchases
human_products = batch.products[human_purch]
if len(human_products) > 0:
margin = float(np.mean(prices[human_products] - unit_cost[human_products]))
premium = float(np.mean(base_price[human_products] - prices[human_products]))
coi = max(0.0, margin - premium * 0.5)
else:
coi = 0.0
# look to book: views / purchases
views = float(np.sum(batch.states == 1)) # view_item_page = index 1
n_purch = int(purchased.sum())
look_to_book = views / (n_purch + 1e-6)
mean_sale = float(np.mean(prices[purch_products])) if n_purch > 0 else 0.0
return SimResult(demand_h, demand_a, revenue, revenue_oracle, agent_loss, coi, look_to_book, mean_sale,
int(human_purch.sum()), int(agent_purch.sum()), batch)