chore: rough migration of environment configuration

This commit is contained in:
2026-01-26 14:12:41 +01:00
parent cd6c3d6006
commit fa2aca8b13
2 changed files with 216 additions and 677 deletions

View File

@@ -76,8 +76,7 @@ class WildPricingEngine(BasePricingEngine):
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
self.step_count += 1
# extract demand signal (from env observation) as proxy for sales
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
demand = _extract_demand(observation, self.c.product_catalogue_size)
return self._update_from_demand(current_prices, demand)
def _update_from_demand(self, prices: np.ndarray, sold: np.ndarray) -> np.ndarray:
@@ -141,7 +140,7 @@ class SimpleDemandEngine(BasePricingEngine):
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
self.step_count += 1
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
demand = _extract_demand(observation, self.c.product_catalogue_size)
if self.prev_demand is None:
self.prev_demand = demand.copy()
return current_prices.copy()
@@ -207,7 +206,7 @@ class ThompsonSamplingEngine(BasePricingEngine):
lo = current_prices * 0.7
hi = current_prices * 1.3
self.price_grid = np.linspace(lo, hi, self.n_price_levels).T
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
demand = _extract_demand(observation, self.c.product_catalogue_size)
# update beliefs based on last action
if self.last_actions is not None:
for i in range(self.c.product_catalogue_size):
@@ -226,3 +225,14 @@ class ThompsonSamplingEngine(BasePricingEngine):
new_prices[i] = self.price_grid[i, actions[i]]
self.last_actions = actions
return np.clip(new_prices, self.c.system_min_price, self.c.system_max_price).astype(np.float32)
def _extract_demand(observation: Dict[str, Any], n: int) -> np.ndarray:
if "elasticity" in observation and isinstance(observation["elasticity"], dict):
d = observation["elasticity"].get("demand")
if d is not None:
return np.asarray(d, dtype=np.float32)
d = observation.get("demand")
if d is not None:
return np.asarray(d, dtype=np.float32)
return np.zeros(n, dtype=np.float32)