mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
intorducing jax for computation
This commit is contained in:
@@ -9,6 +9,13 @@ from typing import Optional, Dict, Any, List, Tuple
|
|||||||
from lib.separability import load_artifacts, score_session, estimate_alpha
|
from lib.separability import load_artifacts, score_session, estimate_alpha
|
||||||
from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions
|
from sim.rl.behavior_loader.models import AgentBehaviorModel, BehaviorModel, aggregate_event_transitions
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
from sim.rl.jax_core import JAX_AVAILABLE, compile_transitions, fallback_transitions, sample_sessions, compute_metrics
|
||||||
|
from sim.rl.jax_core import session_features, compute_session_transitions, compute_divergences, estimate_alpha_batch
|
||||||
|
except ImportError:
|
||||||
|
JAX_AVAILABLE = False
|
||||||
|
|
||||||
# "learner" agent learning to optimize pricing
|
# "learner" agent learning to optimize pricing
|
||||||
# "agent" part of environment creating demand signals that learner processes
|
# "agent" part of environment creating demand signals that learner processes
|
||||||
|
|
||||||
@@ -20,9 +27,9 @@ class BusinessLogicConstraints():
|
|||||||
system_max_price: float = 500.0
|
system_max_price: float = 500.0
|
||||||
system_min_price: float = 1.0
|
system_min_price: float = 1.0
|
||||||
product_catalogue_size: int = 100
|
product_catalogue_size: int = 100
|
||||||
episode_length: int = 200
|
episode_length: int = 2000
|
||||||
sessions_per_step: int = 250
|
sessions_per_step: int = 250
|
||||||
agent_share: float = 0.5
|
agent_share: float = 0.2
|
||||||
agent_recon_multiplier: float = 6.0
|
agent_recon_multiplier: float = 6.0
|
||||||
agent_purchase_probability: float = 0.20
|
agent_purchase_probability: float = 0.20
|
||||||
coi_strength: float = 0.25
|
coi_strength: float = 0.25
|
||||||
@@ -423,9 +430,10 @@ class CommercePlatform:
|
|||||||
class PHANTOMEnv(gym.Env):
|
class PHANTOMEnv(gym.Env):
|
||||||
metadata = {"render_modes": []}
|
metadata = {"render_modes": []}
|
||||||
|
|
||||||
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None):
|
def __init__(self, constraints: Optional[BusinessLogicConstraints] = None, use_jax: bool = True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints()
|
self.constraints = constraints if isinstance(constraints, BusinessLogicConstraints) else BusinessLogicConstraints()
|
||||||
|
self.use_jax = use_jax and JAX_AVAILABLE
|
||||||
self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment,
|
self.action_space = spaces.Box(low=-self.constraints.max_price_adjustment,
|
||||||
high=self.constraints.max_price_adjustment,
|
high=self.constraints.max_price_adjustment,
|
||||||
shape=(self.constraints.product_catalogue_size,), dtype=np.float32)
|
shape=(self.constraints.product_catalogue_size,), dtype=np.float32)
|
||||||
@@ -442,8 +450,8 @@ class PHANTOMEnv(gym.Env):
|
|||||||
dtype=np.float32),
|
dtype=np.float32),
|
||||||
}),
|
}),
|
||||||
"market": spaces.Dict({
|
"market": spaces.Dict({
|
||||||
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), # estimated agent share
|
"alpha_hat": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
||||||
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32), # recent revenue
|
"revenue_rate": spaces.Box(low=0.0, high=1e6, shape=(1,), dtype=np.float32),
|
||||||
"conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
"conversion_rate": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
||||||
"price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
"price_volatility": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
|
||||||
}),
|
}),
|
||||||
@@ -458,12 +466,27 @@ class PHANTOMEnv(gym.Env):
|
|||||||
self.t = 0
|
self.t = 0
|
||||||
self._prev_prices: Optional[np.ndarray] = None
|
self._prev_prices: Optional[np.ndarray] = None
|
||||||
self.state: Dict[str, Any] = {}
|
self.state: Dict[str, Any] = {}
|
||||||
|
self._jax_key = None
|
||||||
|
self._jax_trans = None
|
||||||
|
if self.use_jax:
|
||||||
|
self._jax_key = jax.random.PRNGKey(self.constraints.seed)
|
||||||
|
self._init_jax_transitions()
|
||||||
|
|
||||||
|
def _init_jax_transitions(self):
|
||||||
|
try:
|
||||||
|
human_profile = _load_behavioral_profile("humans", np.ones(self.constraints.product_catalogue_size) * 0.1)
|
||||||
|
agent_profile = _load_behavioral_profile("agents", np.ones(self.constraints.product_catalogue_size) * 0.1)
|
||||||
|
self._jax_trans = compile_transitions(human_profile, agent_profile).to_jax()
|
||||||
|
except Exception:
|
||||||
|
self._jax_trans = fallback_transitions().to_jax()
|
||||||
|
|
||||||
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
self._rng = np.random.default_rng(seed)
|
self._rng = np.random.default_rng(seed)
|
||||||
self.commerce_platform._rng = np.random.default_rng(seed)
|
self.commerce_platform._rng = np.random.default_rng(seed)
|
||||||
|
if self.use_jax:
|
||||||
|
self._jax_key = jax.random.PRNGKey(seed)
|
||||||
self.commerce_platform.alpha_hat = self.constraints.agent_share
|
self.commerce_platform.alpha_hat = self.constraints.agent_share
|
||||||
self.t = 0
|
self.t = 0
|
||||||
init_prices = self._rng.uniform(
|
init_prices = self._rng.uniform(
|
||||||
@@ -493,6 +516,20 @@ class PHANTOMEnv(gym.Env):
|
|||||||
}
|
}
|
||||||
return self.state, {}
|
return self.state, {}
|
||||||
|
|
||||||
|
def _step_jax(self, new_prices: np.ndarray) -> Tuple[Dict, Dict]:
|
||||||
|
self._jax_key, subkey = jax.random.split(self._jax_key)
|
||||||
|
alpha = float(np.clip(self.commerce_platform.alpha_hat, 0.0, 0.95))
|
||||||
|
n_agent = max(1, int(self.constraints.sessions_per_step * alpha))
|
||||||
|
n_human = max(1, self.constraints.sessions_per_step - n_agent)
|
||||||
|
batch = sample_sessions(subkey, self._jax_trans, n_human, n_agent, len(new_prices))
|
||||||
|
sim = compute_metrics(batch, new_prices, self.commerce_platform.unit_cost, self.commerce_platform.base_price)
|
||||||
|
result = {"revenue_observed": sim.revenue, "revenue_oracle": sim.revenue_oracle,
|
||||||
|
"agent_loss": sim.agent_loss, "coi": sim.coi, "look_to_book": sim.look_to_book,
|
||||||
|
"mean_sale_price": sim.mean_sale_price, "true_human_purchases": sim.n_human_purchases,
|
||||||
|
"true_agent_purchases": sim.n_agent_purchases}
|
||||||
|
diagnostics = {"demand_human": sim.demand_human, "demand_agent": sim.demand_agent, "alpha_hat": alpha}
|
||||||
|
return result, diagnostics
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
self.t += 1
|
self.t += 1
|
||||||
base_prices = self.state["elasticity"]["price"].astype(np.float32)
|
base_prices = self.state["elasticity"]["price"].astype(np.float32)
|
||||||
@@ -501,6 +538,9 @@ class PHANTOMEnv(gym.Env):
|
|||||||
self.constraints.system_max_price).astype(np.float32)
|
self.constraints.system_max_price).astype(np.float32)
|
||||||
|
|
||||||
self.state["elasticity"]["price"] = new_prices
|
self.state["elasticity"]["price"] = new_prices
|
||||||
|
if self.use_jax:
|
||||||
|
result, diagnostics = self._step_jax(new_prices)
|
||||||
|
else:
|
||||||
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
|
interactions_df, diagnostics = self.commerce_platform._simulate_sessions(new_prices)
|
||||||
result = self.commerce_platform.compute_interaction_features(interactions_df)
|
result = self.commerce_platform.compute_interaction_features(interactions_df)
|
||||||
COI = float(result.get("coi", 0.0))
|
COI = float(result.get("coi", 0.0))
|
||||||
|
|||||||
@@ -23,6 +23,11 @@ class SimResult(NamedTuple):
|
|||||||
demand_human: np.ndarray
|
demand_human: np.ndarray
|
||||||
demand_agent: np.ndarray
|
demand_agent: np.ndarray
|
||||||
revenue: float
|
revenue: float
|
||||||
|
revenue_oracle: float
|
||||||
|
agent_loss: float
|
||||||
|
coi: float
|
||||||
|
look_to_book: float
|
||||||
|
mean_sale_price: float
|
||||||
n_human_purchases: int
|
n_human_purchases: int
|
||||||
n_agent_purchases: int
|
n_agent_purchases: int
|
||||||
sessions: SessionBatch
|
sessions: SessionBatch
|
||||||
@@ -81,12 +86,31 @@ def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_pr
|
|||||||
lengths[i] = t
|
lengths[i] = t
|
||||||
return SessionBatch(states, dwells, products, actors, lengths)
|
return SessionBatch(states, dwells, products, actors, lengths)
|
||||||
|
|
||||||
def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray) -> SimResult:
|
def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray, base_price: np.ndarray) -> SimResult:
|
||||||
purchased = np.any(batch.states == PURCHASE_IDX, axis=1)
|
purchased = np.any(batch.states == PURCHASE_IDX, axis=1)
|
||||||
human_mask, agent_mask = batch.actors == 0, batch.actors == 1
|
human_mask, agent_mask = batch.actors == 0, batch.actors == 1
|
||||||
human_purch = purchased & human_mask
|
human_purch, agent_purch = purchased & human_mask, purchased & agent_mask
|
||||||
agent_purch = purchased & agent_mask
|
|
||||||
demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32)
|
demand_h = np.bincount(batch.products[human_purch], minlength=len(prices)).astype(np.float32)
|
||||||
demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32)
|
demand_a = np.bincount(batch.products[agent_purch], minlength=len(prices)).astype(np.float32)
|
||||||
revenue = float(np.sum(prices[batch.products[purchased]]))
|
# revenue and oracle
|
||||||
return SimResult(demand_h, demand_a, revenue, int(human_purch.sum()), int(agent_purch.sum()), batch)
|
purch_products = batch.products[purchased]
|
||||||
|
revenue = float(np.sum(prices[purch_products]))
|
||||||
|
revenue_oracle = float(np.sum(base_price[purch_products]))
|
||||||
|
# agent loss: base_price - price_paid for agent purchases (agents gaming the system)
|
||||||
|
agent_products = batch.products[agent_purch]
|
||||||
|
agent_loss = float(np.sum(base_price[agent_products] - prices[agent_products]))
|
||||||
|
# COI: margin - expected_premium*0.5 for human purchases
|
||||||
|
human_products = batch.products[human_purch]
|
||||||
|
if len(human_products) > 0:
|
||||||
|
margin = float(np.mean(prices[human_products] - unit_cost[human_products]))
|
||||||
|
premium = float(np.mean(base_price[human_products] - prices[human_products]))
|
||||||
|
coi = max(0.0, margin - premium * 0.5)
|
||||||
|
else:
|
||||||
|
coi = 0.0
|
||||||
|
# look to book: views / purchases
|
||||||
|
views = float(np.sum(batch.states == 1)) # view_item_page = index 1
|
||||||
|
n_purch = int(purchased.sum())
|
||||||
|
look_to_book = views / (n_purch + 1e-6)
|
||||||
|
mean_sale = float(np.mean(prices[purch_products])) if n_purch > 0 else 0.0
|
||||||
|
return SimResult(demand_h, demand_a, revenue, revenue_oracle, agent_loss, coi, look_to_book, mean_sale,
|
||||||
|
int(human_purch.sum()), int(agent_purch.sum()), batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user