introducing jax for computation

This commit is contained in:
2026-01-22 21:02:10 +01:00
parent 40e0b201e6
commit a033e77697
2 changed files with 76 additions and 12 deletions

View File

@@ -9,6 +9,13 @@ from typing import Optional, Dict, Any, List, Tuple
from lib.separability import load_artifacts, score_session, estimate_alpha
from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions
try:
import jax
from sim.rl.jax_core import JAX_AVAILABLE, compile_transitions, fallback_transitions, sample_sessions, compute_metrics
from sim.rl.jax_core import session_features, compute_session_transitions, compute_divergences, estimate_alpha_batch
except ImportError:
JAX_AVAILABLE = False
# "learner" agent learning to optimize pricing
# "agent" part of environment creating demand signals that learner processes
@@ -20,9 +27,9 @@ class BusinessLogicConstraints():
system_max_price: float = 500.0
system_min_price: float = 1.0
product_catalogue_size: int = 100
episode_length: int = 200
episode_length: int = 2000
sessions_per_step: int = 250
agent_share: float = 0.5
agent_share: float = 0.2
agent_recon_multiplier: float = 6.0
agent_purchase_probability: float = 0.20
coi_strength: float = 0.25
@@ -423,9 +430,10 @@ class CommercePlatform:
class PHANTOMEnv(gym.Env):
metadata = {"render_modes": []}
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None):
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None, use_jax: bool = True):
super().__init__()
self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints()
self.use_jax = use_jax and JAX_AVAILABLE
self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment,
high=self.constraints.max_price_adjustment,
shape=(self.constraints.product_catalogue_size,), dtype=np.float32)
@@ -442,8 +450,8 @@ class PHANTOMEnv(gym.Env):
dtype=np.float32),
}),
"market": spaces.Dict({
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), # estimated agent share
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32), # recent revenue
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32),
"conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
"price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
}),
@@ -458,12 +466,27 @@ class PHANTOMEnv(gym.Env):
self.t = 0
self._prev_prices: Optional[np.ndarray] = None
self.state: Dict[str, Any] = {}
self._jax_key = None
self._jax_trans = None
if self.use_jax:
self._jax_key = jax.random.PRNGKey(self.constraints.seed)
self._init_jax_transitions()
def _init_jax_transitions(self) -> None:
    """Populate ``self._jax_trans`` with the JAX transition tables.

    Loads the "humans" and "agents" behavioral profiles, compiles them with
    ``compile_transitions`` and stores the ``.to_jax()`` result. If any step
    of that fails, the tables from ``fallback_transitions()`` are stored
    instead, so construction of the environment never fails here.

    The failure is logged (with traceback) rather than swallowed silently:
    otherwise a broken or missing profile would go unnoticed while the
    simulation quietly runs on fallback transitions.
    """
    import logging  # local import so this block does not depend on a file-level one

    n_products = self.constraints.product_catalogue_size
    try:
        # Second argument is a flat 0.1 vector, one entry per product.
        # NOTE(review): presumably a default/prior used by the loader — confirm
        # against _load_behavioral_profile. A fresh array is built per call so
        # the two profiles never share a buffer.
        human_profile = _load_behavioral_profile("humans", np.ones(n_products) * 0.1)
        agent_profile = _load_behavioral_profile("agents", np.ones(n_products) * 0.1)
        self._jax_trans = compile_transitions(human_profile, agent_profile).to_jax()
    except Exception:
        # Deliberate best-effort: keep the environment usable, but make the
        # degradation visible to whoever is running the experiment.
        logging.getLogger(__name__).warning(
            "Could not build JAX transitions from behavioral profiles; "
            "using fallback transitions instead.",
            exc_info=True,
        )
        self._jax_trans = fallback_transitions().to_jax()
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
if seed is not None:
self._rng = np.random.default_rng(seed)
self.commerce_platform._rng = np.random.default_rng(seed)
if self.use_jax:
self._jax_key = jax.random.PRNGKey(seed)
self.commerce_platform.alpha_hat = self.constraints.agent_share
self.t = 0
init_prices = self._rng.uniform(
@@ -493,6 +516,20 @@ class PHANTOMEnv(gym.Env):
}
return self.state, {}
def _step_jax(self, new_prices: np.ndarray) -> Tuple[Dict, Dict]:
    """Run one simulation step through the JAX code path.

    Splits the stored PRNG key, divides ``sessions_per_step`` between agent
    and human sessions according to the platform's ``alpha_hat`` (clipped to
    [0, 0.95]), samples a session batch and reduces it with
    ``compute_metrics``.

    Args:
        new_prices: Price vector for this step; its length is passed to
            ``sample_sessions`` as the product count.

    Returns:
        ``(result, diagnostics)``: ``result`` holds the revenue/purchase
        metrics taken from the ``compute_metrics`` output, ``diagnostics``
        holds the human and agent demand plus the clipped alpha used.
    """
    # Keep one half of the split as the new stored key; the other half is
    # consumed by this step's sampling.
    next_key, step_key = jax.random.split(self._jax_key)
    self._jax_key = next_key

    platform = self.commerce_platform
    session_budget = self.constraints.sessions_per_step

    agent_fraction = float(np.clip(platform.alpha_hat, 0.0, 0.95))
    # Both populations are floored at one session.
    agent_sessions = max(1, int(session_budget * agent_fraction))
    human_sessions = max(1, session_budget - agent_sessions)

    sampled = sample_sessions(
        step_key, self._jax_trans, human_sessions, agent_sessions, len(new_prices)
    )
    metrics = compute_metrics(sampled, new_prices, platform.unit_cost, platform.base_price)

    outcome = {
        "revenue_observed": metrics.revenue,
        "revenue_oracle": metrics.revenue_oracle,
        "agent_loss": metrics.agent_loss,
        "coi": metrics.coi,
        "look_to_book": metrics.look_to_book,
        "mean_sale_price": metrics.mean_sale_price,
        "true_human_purchases": metrics.n_human_purchases,
        "true_agent_purchases": metrics.n_agent_purchases,
    }
    extras = {
        "demand_human": metrics.demand_human,
        "demand_agent": metrics.demand_agent,
        "alpha_hat": agent_fraction,
    }
    return outcome, extras
def step(self, action: np.ndarray):
self.t += 1
base_prices = self.state["elasticity"]["price"].astype(np.float32)
@@ -501,8 +538,11 @@ class PHANTOMEnv(gym.Env):
self.constraints.system_max_price).astype(np.float32)
self.state["elasticity"]["price"] = new_prices
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
result = self.commerce_platform.compute_interaction_features(interactions_df)
if self.use_jax:
result, diagnostics = self._step_jax(new_prices)
else:
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
result = self.commerce_platform.compute_interaction_features(interactions_df)
COI = float(result.get("coi", 0.0))
demand_vector = diagnostics.get("demand_human", np.zeros_like(new_prices)) + diagnostics.get(