mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
intorducing jax for computation
This commit is contained in:
@@ -9,6 +9,13 @@ from typing import Optional, Dict, Any, List, Tuple
|
||||
from lib.separability import load_artifacts, score_session, estimate_alpha
|
||||
from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions
|
||||
|
||||
try:
|
||||
import jax
|
||||
from sim.rl.jax_core import JAX_AVAILABLE, compile_transitions, fallback_transitions, sample_sessions, compute_metrics
|
||||
from sim.rl.jax_core import session_features, compute_session_transitions, compute_divergences, estimate_alpha_batch
|
||||
except ImportError:
|
||||
JAX_AVAILABLE = False
|
||||
|
||||
# "learner" agent learning to optimize pricing
|
||||
# "agent" part of environment creating demand signals that learner processes
|
||||
|
||||
@@ -20,9 +27,9 @@ class BusinessLogicConstraints():
|
||||
system_max_price: float = 500.0
|
||||
system_min_price: float = 1.0
|
||||
product_catalogue_size: int = 100
|
||||
episode_length: int = 200
|
||||
episode_length: int = 2000
|
||||
sessions_per_step: int = 250
|
||||
agent_share: float = 0.5
|
||||
agent_share: float = 0.2
|
||||
agent_recon_multiplier: float = 6.0
|
||||
agent_purchase_probability: float = 0.20
|
||||
coi_strength: float = 0.25
|
||||
@@ -423,9 +430,10 @@ class CommercePlatform:
|
||||
class PHANTOMEnv(gym.Env):
|
||||
metadata = {"render_modes": []}
|
||||
|
||||
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None):
|
||||
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None, use_jax: bool = True):
|
||||
super().__init__()
|
||||
self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints()
|
||||
self.use_jax = use_jax and JAX_AVAILABLE
|
||||
self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment,
|
||||
high=self.constraints.max_price_adjustment,
|
||||
shape=(self.constraints.product_catalogue_size,), dtype=np.float32)
|
||||
@@ -442,8 +450,8 @@ class PHANTOMEnv(gym.Env):
|
||||
dtype=np.float32),
|
||||
}),
|
||||
"market": spaces.Dict({
|
||||
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), # estimated agent share
|
||||
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32), # recent revenue
|
||||
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
||||
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32),
|
||||
"conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
||||
"price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
||||
}),
|
||||
@@ -458,12 +466,27 @@ class PHANTOMEnv(gym.Env):
|
||||
self.t = 0
|
||||
self._prev_prices: Optional[np.ndarray] = None
|
||||
self.state: Dict[str, Any] = {}
|
||||
self._jax_key = None
|
||||
self._jax_trans = None
|
||||
if self.use_jax:
|
||||
self._jax_key = jax.random.PRNGKey(self.constraints.seed)
|
||||
self._init_jax_transitions()
|
||||
|
||||
def _init_jax_transitions(self):
|
||||
try:
|
||||
human_profile = _load_behavioral_profile("humans", np.ones(self.constraints.product_catalogue_size) * 0.1)
|
||||
agent_profile = _load_behavioral_profile("agents", np.ones(self.constraints.product_catalogue_size) * 0.1)
|
||||
self._jax_trans = compile_transitions(human_profile, agent_profile).to_jax()
|
||||
except Exception:
|
||||
self._jax_trans = fallback_transitions().to_jax()
|
||||
|
||||
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
if seed is not None:
|
||||
self._rng = np.random.default_rng(seed)
|
||||
self.commerce_platform._rng = np.random.default_rng(seed)
|
||||
if self.use_jax:
|
||||
self._jax_key = jax.random.PRNGKey(seed)
|
||||
self.commerce_platform.alpha_hat = self.constraints.agent_share
|
||||
self.t = 0
|
||||
init_prices = self._rng.uniform(
|
||||
@@ -493,6 +516,20 @@ class PHANTOMEnv(gym.Env):
|
||||
}
|
||||
return self.state, {}
|
||||
|
||||
def _step_jax(self, new_prices: np.ndarray) -> Tuple[Dict, Dict]:
|
||||
self._jax_key, subkey = jax.random.split(self._jax_key)
|
||||
alpha = float(np.clip(self.commerce_platform.alpha_hat, 0.0, 0.95))
|
||||
n_agent = max(1, int(self.constraints.sessions_per_step * alpha))
|
||||
n_human = max(1, self.constraints.sessions_per_step - n_agent)
|
||||
batch = sample_sessions(subkey, self._jax_trans, n_human, n_agent, len(new_prices))
|
||||
sim = compute_metrics(batch, new_prices, self.commerce_platform.unit_cost, self.commerce_platform.base_price)
|
||||
result = {"revenue_observed": sim.revenue, "revenue_oracle": sim.revenue_oracle,
|
||||
"agent_loss": sim.agent_loss, "coi": sim.coi, "look_to_book": sim.look_to_book,
|
||||
"mean_sale_price": sim.mean_sale_price, "true_human_purchases": sim.n_human_purchases,
|
||||
"true_agent_purchases": sim.n_agent_purchases}
|
||||
diagnostics = {"demand_human": sim.demand_human, "demand_agent": sim.demand_agent, "alpha_hat": alpha}
|
||||
return result, diagnostics
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
self.t += 1
|
||||
base_prices = self.state["elasticity"]["price"].astype(np.float32)
|
||||
@@ -501,8 +538,11 @@ class PHANTOMEnv(gym.Env):
|
||||
self.constraints.system_max_price).astype(np.float32)
|
||||
|
||||
self.state["elasticity"]["price"] = new_prices
|
||||
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
|
||||
result = self.commerce_platform.compute_interaction_features(interactions_df)
|
||||
if self.use_jax:
|
||||
result, diagnostics = self._step_jax(new_prices)
|
||||
else:
|
||||
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
|
||||
result = self.commerce_platform.compute_interaction_features(interactions_df)
|
||||
COI = float(result.get("coi", 0.0))
|
||||
|
||||
demand_vector = diagnostics.get("demand_human", np.zeros_like(new_prices)) + diagnostics.get(
|
||||
|
||||
Reference in New Issue
Block a user