feat: baseline setup for RL modeling

This commit is contained in:
2026-01-22 12:52:41 +01:00
parent fa89347c4e
commit a6e6cc5d60
3 changed files with 152 additions and 37 deletions

View File

@@ -22,7 +22,7 @@ class BusinessLogicConstraints():
product_catalogue_size: int = 100
episode_length: int = 200
sessions_per_step: int = 250
agent_share: float = 0.25
agent_share: float = 0.5
agent_recon_multiplier: float = 6.0
agent_purchase_probability: float = 0.20
coi_strength: float = 0.25
@@ -43,13 +43,45 @@ def _sigmoid(x: np.ndarray) -> np.ndarray:
EVENT_PAGE_MAP = {
"session_start": "/",
"page_view": "/",
"view_item_page": "/products",
"learn_more_about_item": "/products/details",
"add_item_to_cart": "/cart",
"checkout_start": "/checkout",
"purchase_complete": "/checkout",
"session_end": "/checkout/success",
}
# map real collected event names to canonical simulation states
EVENT_CANONICAL_MAP = {
"page_view": "session_start",
"hover_over_paragraph": "view_item_page",
"hover_over_title": "view_item_page",
"view_item_page": "view_item_page",
"learn_more_about_item": "learn_more_about_item",
"add_item_to_cart": "add_item_to_cart",
"checkout_start": "purchase_complete",
"remove_item": "view_item_page",
}
def _canonicalize_transitions(raw_trans: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""Map real event transition names to canonical simulation states."""
canonical: Dict[str, Dict[str, float]] = {}
for src, dsts in raw_trans.items():
src_canon = EVENT_CANONICAL_MAP.get(src, src)
if src_canon not in canonical:
canonical[src_canon] = {}
for dst, prob in dsts.items():
dst_canon = EVENT_CANONICAL_MAP.get(dst, dst)
canonical[src_canon][dst_canon] = canonical[src_canon].get(dst_canon, 0.0) + prob
# re-normalize after aggregation
for src in canonical:
total = sum(canonical[src].values())
if total > 0:
canonical[src] = {k: v / total for k, v in canonical[src].items()}
return canonical
class BehavioralProfile:
"""Synthetic Markov profile used to generate interaction sessions.
@@ -68,11 +100,23 @@ class BehavioralProfile:
]
model = AgentBehaviorModel(agent_dir) if actor == "agents" else BehaviorModel(human_dir)
mdp = model.build_MDP()
self.transitions = aggregate_event_transitions(mdp) if mdp.get("transitions") else self._fallback_transitions()
raw_trans = aggregate_event_transitions(mdp) if mdp.get("transitions") else {}
self.transitions = _canonicalize_transitions(raw_trans) if raw_trans else self._fallback_transitions()
self._ensure_terminal_states()
self.dwell_params = self._extract_dwell_params(mdp)
def _ensure_terminal_states(self):
# guarantee purchase_complete leads to session_end and session_start exists
if "purchase_complete" not in self.transitions:
self.transitions["purchase_complete"] = {"session_end": 1.0}
elif "session_end" not in self.transitions.get("purchase_complete", {}):
self.transitions["purchase_complete"]["session_end"] = 1.0
total = sum(self.transitions["purchase_complete"].values())
self.transitions["purchase_complete"] = {k: v/total for k, v in self.transitions["purchase_complete"].items()}
if "session_start" not in self.transitions:
self.transitions["session_start"] = {"view_item_page": 0.7, "learn_more_about_item": 0.2, "session_end": 0.1}
def _fallback_transitions(self) -> Dict[str, Dict[str, float]]:
# sensible defaults if no data available
return {
"session_start": {"view_item_page": 0.85, "session_end": 0.15},
"view_item_page": {"learn_more_about_item": 0.4, "add_item_to_cart": 0.3, "view_item_page": 0.2, "session_end": 0.1},
@@ -82,12 +126,16 @@ class BehavioralProfile:
}
def _extract_dwell_params(self, mdp: Dict) -> Dict[str, Tuple[float, float]]:
# derive gamma params (shape, scale) from state_rewards which encode temporal progression
state_vals = mdp.get("state_values", {})
params = {}
for state in self.states:
# try canonical and raw state names
val = state_vals.get(state, 0.5)
shape = 1.5 + val * 2.0 # higher progression -> longer dwell
for raw, canon in EVENT_CANONICAL_MAP.items():
if canon == state and raw in state_vals:
val = state_vals[raw]
break
shape = 1.5 + val * 2.0
scale = 0.8 + (1.0 - val) * 1.2
params[state] = (shape, scale)
return params
@@ -434,7 +482,14 @@ class PHANTOMEnv(gym.Env):
"elasticity": {
"price": init_prices,
"demand": np.zeros((self.constraints.product_catalogue_size,), dtype=np.float32),
}
},
"market": {
"alpha_hat": np.array([self.constraints.agent_share], dtype=np.float32),
"revenue_rate": np.array([0.0], dtype=np.float32),
"conversion_rate": np.array([0.0], dtype=np.float32),
"price_volatility": np.array([0.0], dtype=np.float32),
},
"cost": self.commerce_platform.unit_cost.astype(np.float32),
}
return self.state, {}
@@ -459,6 +514,18 @@ class PHANTOMEnv(gym.Env):
float(np.mean(np.abs((new_prices - self._prev_prices) / (self._prev_prices + 1e-6))))
self._prev_prices = new_prices.copy()
# update market observation features
total_demand = float(np.sum(demand_vector))
total_purchases = float(result.get("true_human_purchases", 0.0) + result.get("true_agent_purchases", 0.0))
conv_rate = total_purchases / max(total_demand, 1.0)
self.state["market"] = {
"alpha_hat": np.array([float(diagnostics.get("alpha_hat", self.commerce_platform.alpha_hat))], dtype=np.float32),
"revenue_rate": np.array([float(result.get("revenue_observed", 0.0))], dtype=np.float32),
"conversion_rate": np.array([float(np.clip(conv_rate, 0.0, 1.0))], dtype=np.float32),
"price_volatility": np.array([float(volatility)], dtype=np.float32),
}
self.state["cost"] = self.commerce_platform.unit_cost.astype(np.float32)
# extract metrics with safe defaults for incomplete simulation
revenue_observed = float(result.get("revenue_observed", 0.0))
agent_loss = float(result.get("agent_loss", 0.0))