From a033e776973c435c3ebce672403c2d3a33b8ad9f Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Thu, 22 Jan 2026 21:02:10 +0100 Subject: [PATCH] intorducing jax for computation --- sim/rl/environment.py | 54 ++++++++++++++++++++++++++++++----- sim/rl/jax_core/simulation.py | 34 ++++++++++++++++++---- 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/sim/rl/environment.py b/sim/rl/environment.py index f1a7f53..597359f 100644 --- a/sim/rl/environment.py +++ b/sim/rl/environment.py @@ -9,6 +9,13 @@ from typing import Optional, Dict, Any, List, Tuple from lib.separability import load_artifacts, score_session, estimate_alpha from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions +try: + import jax + from sim.rl.jax_core import JAX_AVAILABLE, compile_transitions, fallback_transitions, sample_sessions, compute_metrics + from sim.rl.jax_core import session_features, compute_session_transitions, compute_divergences, estimate_alpha_batch +except ImportError: + JAX_AVAILABLE = False + # "learner" agent learning to optimize pricing # "agent" part of environment creating demand signals that learner processes @@ -20,9 +27,9 @@ class BusinessLogicConstraints(): system_max_price: float = 500.0 system_min_price: float = 1.0 product_catalogue_size: int = 100 - episode_length: int = 200 + episode_length: int = 2000 sessions_per_step: int = 250 - agent_share: float = 0.5 + agent_share: float = 0.2 agent_recon_multiplier: float = 6.0 agent_purchase_probability: float = 0.20 coi_strength: float = 0.25 @@ -423,9 +430,10 @@ class CommercePlatform: class PHANTOMEnv(gym.Env): metadata = {"render_modes": []} - def __init__(self, constraints: Optional[BusinessLogicConstraints] = None): + def __init__(self, constraints: Optional[BusinessLogicConstraints] = None, use_jax: bool = True): super().__init__() self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints() + self.use_jax = use_jax and JAX_AVAILABLE self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment, high=self.constraints.max_price_adjustment, shape=(self.constraints.product_catalogue_size,), dtype=np.float32) @@ -442,8 +450,8 @@ class PHANTOMEnv(gym.Env): dtype=np.float32), }), "market": spaces.Dict({ - "alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), # estimated agent share - "revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32), # recent revenue + "alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), + "revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32), "conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), "price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), }), @@ -458,12 +466,27 @@ class PHANTOMEnv(gym.Env): self.t = 0 self._prev_prices: Optional[np.ndarray] = None self.state: Dict[str, Any] = {} + self._jax_key = None + self._jax_trans = None + if self.use_jax: + self._jax_key = jax.random.PRNGKey(self.constraints.seed) + self._init_jax_transitions() + + def _init_jax_transitions(self): + try: + human_profile = _load_behavioral_profile("humans", np.ones(self.constraints.product_catalogue_size) * 0.1) + agent_profile = _load_behavioral_profile("agents", np.ones(self.constraints.product_catalogue_size) * 0.1) + self._jax_trans = compile_transitions(human_profile, agent_profile).to_jax() + except Exception: + self._jax_trans = fallback_transitions().to_jax() def reset(self, seed: Optional[int] = None, options: Optional[dict] = None): super().reset(seed=seed) if seed is not None: self._rng = np.random.default_rng(seed) self.commerce_platform._rng = np.random.default_rng(seed) + if self.use_jax: + self._jax_key = jax.random.PRNGKey(seed) self.commerce_platform.alpha_hat = self.constraints.agent_share self.t = 0 init_prices = self._rng.uniform( @@ -493,6 +516,20 @@ class PHANTOMEnv(gym.Env): } return self.state, {} + def _step_jax(self, new_prices: np.ndarray) -> Tuple[Dict, Dict]: + self._jax_key, subkey = jax.random.split(self._jax_key) + alpha = float(np.clip(self.commerce_platform.alpha_hat, 0.0, 0.95)) + n_agent = max(1, int(self.constraints.sessions_per_step * alpha)) + n_human = max(1, self.constraints.sessions_per_step - n_agent) + batch = sample_sessions(subkey, self._jax_trans, n_human, n_agent, len(new_prices)) + sim = compute_metrics(batch, new_prices, self.commerce_platform.unit_cost, self.commerce_platform.base_price) + result = {"revenue_observed": sim.revenue, "revenue_oracle": sim.revenue_oracle, + "agent_loss": sim.agent_loss, "coi": sim.coi, "look_to_book": sim.look_to_book, + "mean_sale_price": sim.mean_sale_price, "true_human_purchases": sim.n_human_purchases, + "true_agent_purchases": sim.n_agent_purchases} + diagnostics = {"demand_human": sim.demand_human, "demand_agent": sim.demand_agent, "alpha_hat": alpha} + return result, diagnostics + def step(self, action: np.ndarray): self.t += 1 base_prices = self.state["elasticity"]["price"].astype(np.float32) @@ -501,8 +538,11 @@ class PHANTOMEnv(gym.Env): self.constraints.system_max_price).astype(np.float32) self.state["elasticity"]["price"] = new_prices - interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices) - result = self.commerce_platform.compute_interaction_features(interactions_df) + if self.use_jax: + result, diagnostics = self._step_jax(new_prices) + else: + interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices) + result = self.commerce_platform.compute_interaction_features(interactions_df) COI = float(result.get("coi", 0.0)) demand_vector = diagnostics.get("demand_human", np.zeros_like(new_prices)) + diagnostics.get( diff --git a/sim/rl/jax_core/simulation.py b/sim/rl/jax_core/simulation.py index ee8ca6f..9532b3d 100644 --- a/sim/rl/jax_core/simulation.py +++ b/sim/rl/jax_core/simulation.py @@ -23,6 +23,11 @@ class SimResult(NamedTuple): demand_human: np.ndarray demand_agent: np.ndarray revenue: float + revenue_oracle: float + agent_loss: float + coi: float + look_to_book: float + mean_sale_price: float n_human_purchases: int n_agent_purchases: int sessions: SessionBatch @@ -81,12 +86,31 @@ def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_pr lengths[i] = t return SessionBatch(states, dwells, products, actors, lengths) -def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray) -> SimResult: +def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray, base_price: np.ndarray) -> SimResult: purchased = np.any(batch.states == PURCHASE_IDX, axis=1) human_mask, agent_mask = batch.actors == 0, batch.actors == 1 - human_purch = purchased & human_mask - agent_purch = purchased & agent_mask + human_purch, agent_purch = purchased & human_mask, purchased & agent_mask demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32) demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32) - revenue = float(np.sum(prices[batch.products[purchased]])) - return SimResult(demand_h, demand_a, revenue, int(human_purch.sum()), int(agent_purch.sum()), batch) + # revenue and oracle + purch_products = batch.products[purchased] + revenue = float(np.sum(prices[purch_products])) + revenue_oracle = float(np.sum(base_price[purch_products])) + # agent loss: base_price - price_paid for agent purchases (agents gaming the system) + agent_products = batch.products[agent_purch] + agent_loss = float(np.sum(base_price[agent_products] - prices[agent_products])) + # COI: margin - expected_premium*0.5 for human purchases + human_products = batch.products[human_purch] + if len(human_products) > 0: + margin = float(np.mean(prices[human_products] - unit_cost[human_products])) + premium = float(np.mean(base_price[human_products] - prices[human_products])) + coi = max(0.0, margin - premium * 0.5) + else: + coi = 0.0 + # look to book: views / purchases + views = float(np.sum(batch.states == 1)) # view_item_page = index 1 + n_purch = int(purchased.sum()) + look_to_book = views / (n_purch + 1e-6) + mean_sale = float(np.mean(prices[purch_products])) if n_purch > 0 else 0.0 + return SimResult(demand_h, demand_a, revenue, revenue_oracle, agent_loss, coi, look_to_book, mean_sale, + int(human_purch.sum()), int(agent_purch.sum()), batch)