mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: baseline setup for RL modeling
This commit is contained in:
@@ -1,7 +1,13 @@
|
|||||||
import pandas as pd
|
from __future__ import annotations
|
||||||
import random
|
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from lib.separability import estimate_alpha, load_artifacts, score_session
|
||||||
|
|
||||||
# use relative import when in package context, fallback for standalone
|
# use relative import when in package context, fallback for standalone
|
||||||
try:
|
try:
|
||||||
@@ -15,6 +21,11 @@ except ImportError:
|
|||||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||||
AGENT_DATA_DIR = Path(os.getenv('PHANTOM_AGENT_DATA_DIR', PROJECT_ROOT / "experiments" / "agents" / "collected_data"))
|
AGENT_DATA_DIR = Path(os.getenv('PHANTOM_AGENT_DATA_DIR', PROJECT_ROOT / "experiments" / "agents" / "collected_data"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
SEPARABILITY_ARTIFACTS = load_artifacts()
|
||||||
|
except FileNotFoundError:
|
||||||
|
SEPARABILITY_ARTIFACTS = None
|
||||||
|
|
||||||
|
|
||||||
def remap_schema(df: pd.DataFrame, mapping: dict, on: str = "event_type") -> pd.DataFrame:
|
def remap_schema(df: pd.DataFrame, mapping: dict, on: str = "event_type") -> pd.DataFrame:
|
||||||
"""remap column values according to mapping dict, preserving unmapped values"""
|
"""remap column values according to mapping dict, preserving unmapped values"""
|
||||||
@@ -23,6 +34,24 @@ def remap_schema(df: pd.DataFrame, mapping: dict, on: str = "event_type") -> pd.
|
|||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _states_to_events(states: list[str]) -> list[SimpleNamespace]:
|
||||||
|
events: list[SimpleNamespace] = []
|
||||||
|
for idx, state in enumerate(states):
|
||||||
|
parts = state.split("|") if isinstance(state, str) else ["page", "product", str(state)]
|
||||||
|
page = f"/{parts[0]}" if parts else "/"
|
||||||
|
product = parts[1] if len(parts) > 1 else "unknown"
|
||||||
|
event_name = parts[2] if len(parts) > 2 else parts[-1]
|
||||||
|
events.append(
|
||||||
|
SimpleNamespace(
|
||||||
|
eventName=event_name,
|
||||||
|
page=page,
|
||||||
|
productId=product,
|
||||||
|
ts=float(idx),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
def contaminate_dataset(df: pd.DataFrame, on: str = "event_type",
|
def contaminate_dataset(df: pd.DataFrame, on: str = "event_type",
|
||||||
contamination_rate: float = 0.1,
|
contamination_rate: float = 0.1,
|
||||||
agent_data_dir: Path = None) -> pd.DataFrame:
|
agent_data_dir: Path = None) -> pd.DataFrame:
|
||||||
@@ -48,6 +77,7 @@ def contaminate_dataset(df: pd.DataFrame, on: str = "event_type",
|
|||||||
|
|
||||||
# generate synthetic trajectories
|
# generate synthetic trajectories
|
||||||
new_rows = []
|
new_rows = []
|
||||||
|
alpha_estimates = []
|
||||||
for start_event in start_events:
|
for start_event in start_events:
|
||||||
# sample trajectory from agent model, using a state that contains the event type
|
# sample trajectory from agent model, using a state that contains the event type
|
||||||
mdp_states = model.mdp.get('states', []) if model.mdp else []
|
mdp_states = model.mdp.get('states', []) if model.mdp else []
|
||||||
@@ -56,11 +86,28 @@ def contaminate_dataset(df: pd.DataFrame, on: str = "event_type",
|
|||||||
continue # skip if no matching start state
|
continue # skip if no matching start state
|
||||||
start_state = random.choice(matching_starts)
|
start_state = random.choice(matching_starts)
|
||||||
trajectory = model.sample_traj(start_state, max_len=20)
|
trajectory = model.sample_traj(start_state, max_len=20)
|
||||||
|
score_payload: list[SimpleNamespace] = []
|
||||||
|
score: dict[str, float] = {}
|
||||||
|
if SEPARABILITY_ARTIFACTS:
|
||||||
|
score_payload = _states_to_events(trajectory)
|
||||||
|
score = score_session(score_payload, SEPARABILITY_ARTIFACTS)
|
||||||
|
alpha_estimates.append(
|
||||||
|
estimate_alpha(score["prob_agent"], score["delta_h"], score["delta_a"], temperature=2.0)
|
||||||
|
)
|
||||||
|
|
||||||
for state in trajectory:
|
for state in trajectory:
|
||||||
parts = state.split('|') # page|productId|eventName format
|
parts = state.split('|') if isinstance(state, str) else [start_event]
|
||||||
new_rows.append({on: parts[-1] if parts else start_event, 'source': 'synthetic_agent'})
|
new_rows.append({
|
||||||
|
on: parts[-1] if parts else start_event,
|
||||||
|
'source': 'synthetic_agent',
|
||||||
|
'prob_agent': score.get('prob_agent') if SEPARABILITY_ARTIFACTS and score_payload else None,
|
||||||
|
'delta_h': score.get('delta_h') if SEPARABILITY_ARTIFACTS and score_payload else None,
|
||||||
|
'delta_a': score.get('delta_a') if SEPARABILITY_ARTIFACTS and score_payload else None,
|
||||||
|
})
|
||||||
|
|
||||||
if new_rows:
|
if new_rows:
|
||||||
contaminate_df = pd.DataFrame(new_rows)
|
contaminate_df = pd.DataFrame(new_rows)
|
||||||
df = pd.concat([df, contaminate_df], ignore_index=True)
|
df = pd.concat([df, contaminate_df], ignore_index=True)
|
||||||
|
if alpha_estimates:
|
||||||
|
df['estimated_alpha'] = sum(alpha_estimates) / len(alpha_estimates)
|
||||||
return df
|
return df
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
from os import kill
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from environment import BusinessLogicConstraints
|
from sim.rl.environment import BusinessLogicConstraints
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -32,9 +31,11 @@ class BasePricingEngine(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
def update(self, observation: Dict[str, Any], reward: float, done: bool, info: Dict[str, Any]) -> None:
|
||||||
def update(obs, reward, done, info):
|
"""Default no-op update. Engines can override as needed."""
|
||||||
pass
|
self.last_observation = observation
|
||||||
|
self.last_reward = reward
|
||||||
|
self.last_info = info
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -48,14 +49,14 @@ class WildPricingEngine(BasePricingEngine):
|
|||||||
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
|
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
|
||||||
super().__init__(constraints, seed)
|
super().__init__(constraints, seed)
|
||||||
# per-product unit costs (unknown to customers; known to platform)
|
# per-product unit costs (unknown to customers; known to platform)
|
||||||
self.unit_cost = self.rng.uniform(8.0, 40.0, size=self.c.product_catelogue_size).astype(np.float32)
|
self.unit_cost = self.rng.uniform(8.0, 40.0, size=self.c.product_catalogue_size).astype(np.float32)
|
||||||
# online elasticity estimate (start moderately elastic)
|
# online elasticity estimate (start moderately elastic)
|
||||||
self.e_hat = np.full((self.c.product_catelogue_size,), -1.3, dtype=np.float32)
|
self.e_hat = np.full((self.c.product_catalogue_size,), -1.3, dtype=np.float32)
|
||||||
# EWMA state for log-log regression
|
# EWMA state for log-log regression
|
||||||
self.mu_logp = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
self.mu_logp = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
self.mu_logq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
self.mu_logq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
self.cov_pq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
self.cov_pq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
self.var_p = np.ones(self.c.product_catelogue_size, dtype=np.float32)
|
self.var_p = np.ones(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
# knobs typical in production
|
# knobs typical in production
|
||||||
self.lr = 0.08
|
self.lr = 0.08
|
||||||
self.ewma = 0.05
|
self.ewma = 0.05
|
||||||
@@ -67,16 +68,16 @@ class WildPricingEngine(BasePricingEngine):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
super().reset()
|
super().reset()
|
||||||
self.e_hat = np.full((self.c.product_catelogue_size,), -1.3, dtype=np.float32)
|
self.e_hat = np.full((self.c.product_catalogue_size,), -1.3, dtype=np.float32)
|
||||||
self.mu_logp = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
self.mu_logp = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
self.mu_logq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
self.mu_logq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
self.cov_pq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
self.cov_pq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
self.var_p = np.ones(self.c.product_catelogue_size, dtype=np.float32)
|
self.var_p = np.ones(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
|
|
||||||
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
||||||
self.step_count += 1
|
self.step_count += 1
|
||||||
# extract demand signal (from env observation) as proxy for sales
|
# extract demand signal (from env observation) as proxy for sales
|
||||||
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
|
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||||
return self._update_from_demand(current_prices, demand)
|
return self._update_from_demand(current_prices, demand)
|
||||||
|
|
||||||
def _update_from_demand(self, prices: np.ndarray, sold: np.ndarray) -> np.ndarray:
|
def _update_from_demand(self, prices: np.ndarray, sold: np.ndarray) -> np.ndarray:
|
||||||
@@ -140,7 +141,7 @@ class SimpleDemandEngine(BasePricingEngine):
|
|||||||
|
|
||||||
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
||||||
self.step_count += 1
|
self.step_count += 1
|
||||||
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
|
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||||
if self.prev_demand is None:
|
if self.prev_demand is None:
|
||||||
self.prev_demand = demand.copy()
|
self.prev_demand = demand.copy()
|
||||||
return current_prices.copy()
|
return current_prices.copy()
|
||||||
@@ -187,15 +188,15 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
|||||||
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
|
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
|
||||||
super().__init__(constraints, seed)
|
super().__init__(constraints, seed)
|
||||||
self.n_price_levels = 5
|
self.n_price_levels = 5
|
||||||
self.alpha = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
self.alpha = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||||
self.beta = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
self.beta = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||||
self.price_grid = None
|
self.price_grid = None
|
||||||
self.last_actions = None
|
self.last_actions = None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
super().reset()
|
super().reset()
|
||||||
self.alpha = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
self.alpha = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||||
self.beta = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
self.beta = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||||
self.price_grid = None
|
self.price_grid = None
|
||||||
self.last_actions = None
|
self.last_actions = None
|
||||||
|
|
||||||
@@ -206,10 +207,10 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
|||||||
lo = current_prices * 0.7
|
lo = current_prices * 0.7
|
||||||
hi = current_prices * 1.3
|
hi = current_prices * 1.3
|
||||||
self.price_grid = np.linspace(lo, hi, self.n_price_levels).T
|
self.price_grid = np.linspace(lo, hi, self.n_price_levels).T
|
||||||
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
|
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||||
# update beliefs based on last action
|
# update beliefs based on last action
|
||||||
if self.last_actions is not None:
|
if self.last_actions is not None:
|
||||||
for i in range(self.c.product_catelogue_size):
|
for i in range(self.c.product_catalogue_size):
|
||||||
a = self.last_actions[i]
|
a = self.last_actions[i]
|
||||||
reward = demand[i]
|
reward = demand[i]
|
||||||
if reward > 0.5:
|
if reward > 0.5:
|
||||||
@@ -217,9 +218,9 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
|||||||
else:
|
else:
|
||||||
self.beta[i, a] += 1.0
|
self.beta[i, a] += 1.0
|
||||||
# thompson sampling: sample from posterior, pick best
|
# thompson sampling: sample from posterior, pick best
|
||||||
new_prices = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
new_prices = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||||
actions = np.zeros(self.c.product_catelogue_size, dtype=int)
|
actions = np.zeros(self.c.product_catalogue_size, dtype=int)
|
||||||
for i in range(self.c.product_catelogue_size):
|
for i in range(self.c.product_catalogue_size):
|
||||||
theta = self.rng.beta(self.alpha[i], self.beta[i]).astype(np.float32)
|
theta = self.rng.beta(self.alpha[i], self.beta[i]).astype(np.float32)
|
||||||
actions[i] = int(np.argmax(theta))
|
actions[i] = int(np.argmax(theta))
|
||||||
new_prices[i] = self.price_grid[i, actions[i]]
|
new_prices[i] = self.price_grid[i, actions[i]]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class BusinessLogicConstraints():
|
|||||||
product_catalogue_size: int = 100
|
product_catalogue_size: int = 100
|
||||||
episode_length: int = 200
|
episode_length: int = 200
|
||||||
sessions_per_step: int = 250
|
sessions_per_step: int = 250
|
||||||
agent_share: float = 0.25
|
agent_share: float = 0.5
|
||||||
agent_recon_multiplier: float = 6.0
|
agent_recon_multiplier: float = 6.0
|
||||||
agent_purchase_probability: float = 0.20
|
agent_purchase_probability: float = 0.20
|
||||||
coi_strength: float = 0.25
|
coi_strength: float = 0.25
|
||||||
@@ -43,13 +43,45 @@ def _sigmoid(x: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
EVENT_PAGE_MAP = {
|
EVENT_PAGE_MAP = {
|
||||||
"session_start": "/",
|
"session_start": "/",
|
||||||
|
"page_view": "/",
|
||||||
"view_item_page": "/products",
|
"view_item_page": "/products",
|
||||||
"learn_more_about_item": "/products/details",
|
"learn_more_about_item": "/products/details",
|
||||||
"add_item_to_cart": "/cart",
|
"add_item_to_cart": "/cart",
|
||||||
|
"checkout_start": "/checkout",
|
||||||
"purchase_complete": "/checkout",
|
"purchase_complete": "/checkout",
|
||||||
"session_end": "/checkout/success",
|
"session_end": "/checkout/success",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# map real collected event names to canonical simulation states
|
||||||
|
EVENT_CANONICAL_MAP = {
|
||||||
|
"page_view": "session_start",
|
||||||
|
"hover_over_paragraph": "view_item_page",
|
||||||
|
"hover_over_title": "view_item_page",
|
||||||
|
"view_item_page": "view_item_page",
|
||||||
|
"learn_more_about_item": "learn_more_about_item",
|
||||||
|
"add_item_to_cart": "add_item_to_cart",
|
||||||
|
"checkout_start": "purchase_complete",
|
||||||
|
"remove_item": "view_item_page",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _canonicalize_transitions(raw_trans: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
|
||||||
|
"""Map real event transition names to canonical simulation states."""
|
||||||
|
canonical: Dict[str, Dict[str, float]] = {}
|
||||||
|
for src, dsts in raw_trans.items():
|
||||||
|
src_canon = EVENT_CANONICAL_MAP.get(src, src)
|
||||||
|
if src_canon not in canonical:
|
||||||
|
canonical[src_canon] = {}
|
||||||
|
for dst, prob in dsts.items():
|
||||||
|
dst_canon = EVENT_CANONICAL_MAP.get(dst, dst)
|
||||||
|
canonical[src_canon][dst_canon] = canonical[src_canon].get(dst_canon, 0.0) + prob
|
||||||
|
# re-normalize after aggregation
|
||||||
|
for src in canonical:
|
||||||
|
total = sum(canonical[src].values())
|
||||||
|
if total > 0:
|
||||||
|
canonical[src] = {k: v / total for k, v in canonical[src].items()}
|
||||||
|
return canonical
|
||||||
|
|
||||||
|
|
||||||
class BehavioralProfile:
|
class BehavioralProfile:
|
||||||
"""Synthetic Markov profile used to generate interaction sessions.
|
"""Synthetic Markov profile used to generate interaction sessions.
|
||||||
@@ -68,11 +100,23 @@ class BehavioralProfile:
|
|||||||
]
|
]
|
||||||
model = AgentBehaviorModel(agent_dir) if actor == "agents" else BehaviorModel(human_dir)
|
model = AgentBehaviorModel(agent_dir) if actor == "agents" else BehaviorModel(human_dir)
|
||||||
mdp = model.build_MDP()
|
mdp = model.build_MDP()
|
||||||
self.transitions = aggregate_event_transitions(mdp) if mdp.get("transitions") else self._fallback_transitions()
|
raw_trans = aggregate_event_transitions(mdp) if mdp.get("transitions") else {}
|
||||||
|
self.transitions = _canonicalize_transitions(raw_trans) if raw_trans else self._fallback_transitions()
|
||||||
|
self._ensure_terminal_states()
|
||||||
self.dwell_params = self._extract_dwell_params(mdp)
|
self.dwell_params = self._extract_dwell_params(mdp)
|
||||||
|
|
||||||
|
def _ensure_terminal_states(self):
|
||||||
|
# guarantee purchase_complete leads to session_end and session_start exists
|
||||||
|
if "purchase_complete" not in self.transitions:
|
||||||
|
self.transitions["purchase_complete"] = {"session_end": 1.0}
|
||||||
|
elif "session_end" not in self.transitions.get("purchase_complete", {}):
|
||||||
|
self.transitions["purchase_complete"]["session_end"] = 1.0
|
||||||
|
total = sum(self.transitions["purchase_complete"].values())
|
||||||
|
self.transitions["purchase_complete"] = {k: v/total for k, v in self.transitions["purchase_complete"].items()}
|
||||||
|
if "session_start" not in self.transitions:
|
||||||
|
self.transitions["session_start"] = {"view_item_page": 0.7, "learn_more_about_item": 0.2, "session_end": 0.1}
|
||||||
|
|
||||||
def _fallback_transitions(self) -> Dict[str, Dict[str, float]]:
|
def _fallback_transitions(self) -> Dict[str, Dict[str, float]]:
|
||||||
# sensible defaults if no data available
|
|
||||||
return {
|
return {
|
||||||
"session_start": {"view_item_page": 0.85, "session_end": 0.15},
|
"session_start": {"view_item_page": 0.85, "session_end": 0.15},
|
||||||
"view_item_page": {"learn_more_about_item": 0.4, "add_item_to_cart": 0.3, "view_item_page": 0.2, "session_end": 0.1},
|
"view_item_page": {"learn_more_about_item": 0.4, "add_item_to_cart": 0.3, "view_item_page": 0.2, "session_end": 0.1},
|
||||||
@@ -82,12 +126,16 @@ class BehavioralProfile:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _extract_dwell_params(self, mdp: Dict) -> Dict[str, Tuple[float, float]]:
|
def _extract_dwell_params(self, mdp: Dict) -> Dict[str, Tuple[float, float]]:
|
||||||
# derive gamma params (shape, scale) from state_rewards which encode temporal progression
|
|
||||||
state_vals = mdp.get("state_values", {})
|
state_vals = mdp.get("state_values", {})
|
||||||
params = {}
|
params = {}
|
||||||
for state in self.states:
|
for state in self.states:
|
||||||
|
# try canonical and raw state names
|
||||||
val = state_vals.get(state, 0.5)
|
val = state_vals.get(state, 0.5)
|
||||||
shape = 1.5 + val * 2.0 # higher progression -> longer dwell
|
for raw, canon in EVENT_CANONICAL_MAP.items():
|
||||||
|
if canon == state and raw in state_vals:
|
||||||
|
val = state_vals[raw]
|
||||||
|
break
|
||||||
|
shape = 1.5 + val * 2.0
|
||||||
scale = 0.8 + (1.0 - val) * 1.2
|
scale = 0.8 + (1.0 - val) * 1.2
|
||||||
params[state] = (shape, scale)
|
params[state] = (shape, scale)
|
||||||
return params
|
return params
|
||||||
@@ -434,7 +482,14 @@ class PHANTOMEnv(gym.Env):
|
|||||||
"elasticity": {
|
"elasticity": {
|
||||||
"price": init_prices,
|
"price": init_prices,
|
||||||
"demand": np.zeros((self.constraints.product_catalogue_size,), dtype=np.float32),
|
"demand": np.zeros((self.constraints.product_catalogue_size,), dtype=np.float32),
|
||||||
}
|
},
|
||||||
|
"market": {
|
||||||
|
"alpha_hat": np.array([self.constraints.agent_share], dtype=np.float32),
|
||||||
|
"revenue_rate": np.array([0.0], dtype=np.float32),
|
||||||
|
"conversion_rate": np.array([0.0], dtype=np.float32),
|
||||||
|
"price_volatility": np.array([0.0], dtype=np.float32),
|
||||||
|
},
|
||||||
|
"cost": self.commerce_platform.unit_cost.astype(np.float32),
|
||||||
}
|
}
|
||||||
return self.state, {}
|
return self.state, {}
|
||||||
|
|
||||||
@@ -459,6 +514,18 @@ class PHANTOMEnv(gym.Env):
|
|||||||
float(np.mean(np.abs((new_prices - self._prev_prices) / (self._prev_prices + 1e-6))))
|
float(np.mean(np.abs((new_prices - self._prev_prices) / (self._prev_prices + 1e-6))))
|
||||||
self._prev_prices = new_prices.copy()
|
self._prev_prices = new_prices.copy()
|
||||||
|
|
||||||
|
# update market observation features
|
||||||
|
total_demand = float(np.sum(demand_vector))
|
||||||
|
total_purchases = float(result.get("true_human_purchases", 0.0) + result.get("true_agent_purchases", 0.0))
|
||||||
|
conv_rate = total_purchases / max(total_demand, 1.0)
|
||||||
|
self.state["market"] = {
|
||||||
|
"alpha_hat": np.array([float(diagnostics.get("alpha_hat", self.commerce_platform.alpha_hat))], dtype=np.float32),
|
||||||
|
"revenue_rate": np.array([float(result.get("revenue_observed", 0.0))], dtype=np.float32),
|
||||||
|
"conversion_rate": np.array([float(np.clip(conv_rate, 0.0, 1.0))], dtype=np.float32),
|
||||||
|
"price_volatility": np.array([float(volatility)], dtype=np.float32),
|
||||||
|
}
|
||||||
|
self.state["cost"] = self.commerce_platform.unit_cost.astype(np.float32)
|
||||||
|
|
||||||
# extract metrics with safe defaults for incomplete simulation
|
# extract metrics with safe defaults for incomplete simulation
|
||||||
revenue_observed = float(result.get("revenue_observed", 0.0))
|
revenue_observed = float(result.get("revenue_observed", 0.0))
|
||||||
agent_loss = float(result.get("agent_loss", 0.0))
|
agent_loss = float(result.get("agent_loss", 0.0))
|
||||||
|
|||||||
Reference in New Issue
Block a user