mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: rough migration of environment configuration
This commit is contained in:
@@ -76,8 +76,7 @@ class WildPricingEngine(BasePricingEngine):
|
||||
|
||||
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
||||
self.step_count += 1
|
||||
# extract demand signal (from env observation) as proxy for sales
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||
demand = _extract_demand(observation, self.c.product_catalogue_size)
|
||||
return self._update_from_demand(current_prices, demand)
|
||||
|
||||
def _update_from_demand(self, prices: np.ndarray, sold: np.ndarray) -> np.ndarray:
|
||||
@@ -141,7 +140,7 @@ class SimpleDemandEngine(BasePricingEngine):
|
||||
|
||||
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
||||
self.step_count += 1
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||
demand = _extract_demand(observation, self.c.product_catalogue_size)
|
||||
if self.prev_demand is None:
|
||||
self.prev_demand = demand.copy()
|
||||
return current_prices.copy()
|
||||
@@ -207,7 +206,7 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
||||
lo = current_prices * 0.7
|
||||
hi = current_prices * 1.3
|
||||
self.price_grid = np.linspace(lo, hi, self.n_price_levels).T
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||
demand = _extract_demand(observation, self.c.product_catalogue_size)
|
||||
# update beliefs based on last action
|
||||
if self.last_actions is not None:
|
||||
for i in range(self.c.product_catalogue_size):
|
||||
@@ -226,3 +225,14 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
||||
new_prices[i] = self.price_grid[i, actions[i]]
|
||||
self.last_actions = actions
|
||||
return np.clip(new_prices, self.c.system_min_price, self.c.system_max_price).astype(np.float32)
|
||||
|
||||
|
||||
def _extract_demand(observation: Dict[str, Any], n: int) -> np.ndarray:
|
||||
if "elasticity" in observation and isinstance(observation["elasticity"], dict):
|
||||
d = observation["elasticity"].get("demand")
|
||||
if d is not None:
|
||||
return np.asarray(d, dtype=np.float32)
|
||||
d = observation.get("demand")
|
||||
if d is not None:
|
||||
return np.asarray(d, dtype=np.float32)
|
||||
return np.zeros(n, dtype=np.float32)
|
||||
|
||||
Reference in New Issue
Block a user