Introducing JAX for computation

This commit is contained in:
2026-01-22 21:02:10 +01:00
parent 40e0b201e6
commit a033e77697
2 changed files with 76 additions and 12 deletions

View File

@@ -23,6 +23,11 @@ class SimResult(NamedTuple):
demand_human: np.ndarray
demand_agent: np.ndarray
revenue: float
revenue_oracle: float
agent_loss: float
coi: float
look_to_book: float
mean_sale_price: float
n_human_purchases: int
n_agent_purchases: int
sessions: SessionBatch
@@ -81,12 +86,31 @@ def sample_sessions(key, trans: TransitionData, n_human: int, n_agent: int, n_pr
lengths[i] = t
return SessionBatch(states, dwells, products, actors, lengths)
def compute_metrics(batch: SessionBatch, prices: np.ndarray, unit_cost: np.ndarray, base_price: np.ndarray) -> SimResult:
    """Aggregate a batch of simulated sessions into a ``SimResult``.

    A session counts as a purchase when any of its states equals
    ``PURCHASE_IDX``. Sessions are split into human (``batch.actors == 0``)
    and agent (``batch.actors == 1``) traffic.

    Args:
        batch: Simulated sessions. ``batch.states`` is reduced along axis 1,
            so it is presumably laid out as (session, step) -- TODO confirm.
            ``batch.products`` holds one product index per session and is
            used to index the per-product arrays below.
        prices: Price charged per product, indexed by product id.
        unit_cost: Unit cost per product, indexed by product id.
        base_price: Reference price per product, indexed by product id;
            used for the oracle revenue, agent loss and COI premium.

    Returns:
        A ``SimResult`` carrying per-product demand for humans and agents,
        revenue figures, the derived metrics (agent loss, COI, look-to-book,
        mean sale price), purchase counts, and the input ``batch``.
    """
    # State index of the "view_item_page" event in ``batch.states``.
    view_item_page_idx = 1

    purchased = np.any(batch.states == PURCHASE_IDX, axis=1)
    human_mask, agent_mask = batch.actors == 0, batch.actors == 1
    human_purch, agent_purch = purchased & human_mask, purchased & agent_mask

    # Per-product purchase counts; minlength keeps the vectors aligned with
    # ``prices`` even when the highest-indexed products were never bought.
    n_products = len(prices)
    demand_h = np.bincount(batch.products[human_purch], minlength=n_products).astype(np.float32)
    demand_a = np.bincount(batch.products[agent_purch], minlength=n_products).astype(np.float32)

    # Realised revenue vs. the revenue the same purchases would have
    # produced at base_price (the "oracle").
    purch_products = batch.products[purchased]
    revenue = float(np.sum(prices[purch_products]))
    revenue_oracle = float(np.sum(base_price[purch_products]))

    # Agent loss: base_price - price_paid summed over agent purchases
    # (agents gaming the system).
    agent_products = batch.products[agent_purch]
    agent_loss = float(np.sum(base_price[agent_products] - prices[agent_products]))

    # COI: mean margin minus half the mean premium over human purchases,
    # floored at zero; defined as 0.0 when no human bought anything.
    # NOTE(review): "premium" is computed as base_price - price, i.e. it is
    # positive when the price paid is BELOW base -- confirm the intended sign.
    human_products = batch.products[human_purch]
    if len(human_products) > 0:
        margin = float(np.mean(prices[human_products] - unit_cost[human_products]))
        premium = float(np.mean(base_price[human_products] - prices[human_products]))
        coi = max(0.0, margin - premium * 0.5)
    else:
        coi = 0.0

    # Look-to-book: item-page views per purchase. The epsilon only guards
    # against division by zero, so with no purchases the ratio is views * 1e6.
    views = float(np.sum(batch.states == view_item_page_idx))
    n_purch = int(purchased.sum())
    look_to_book = views / (n_purch + 1e-6)

    # np.mean of an empty selection would yield NaN, hence the explicit guard.
    mean_sale = float(np.mean(prices[purch_products])) if n_purch > 0 else 0.0

    return SimResult(
        demand_human=demand_h,
        demand_agent=demand_a,
        revenue=revenue,
        revenue_oracle=revenue_oracle,
        agent_loss=agent_loss,
        coi=coi,
        look_to_book=look_to_book,
        mean_sale_price=mean_sale,
        n_human_purchases=int(human_purch.sum()),
        n_agent_purchases=int(agent_purch.sum()),
        sessions=batch,
    )