mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
chore: better test consistency before agnet
This commit is contained in:
@@ -65,6 +65,11 @@ class StaticPricer(PricingFunction):
|
||||
raise ValueError("Must call fit() or provide base_prices in constructor")
|
||||
return self.base_prices.copy()
|
||||
|
||||
def _get_features(self, state_space=None) -> np.ndarray:
|
||||
"""Static pricer uses no features, returns empty array"""
|
||||
n = len(self.base_prices) if self.base_prices is not None else 0
|
||||
return np.zeros((n, 0))
|
||||
|
||||
|
||||
class RandomPricer(PricingFunction):
|
||||
"""Random pricing within bounds (for baseline comparison)"""
|
||||
@@ -87,6 +92,11 @@ class RandomPricer(PricingFunction):
|
||||
self.n_products = len(state_space.demand)
|
||||
return self.rng.uniform(self.price_min, self.price_max, size=self.n_products)
|
||||
|
||||
def _get_features(self, state_space=None) -> np.ndarray:
|
||||
"""Random pricer uses no features"""
|
||||
n = self.n_products if self.n_products else 0
|
||||
return np.zeros((n, 0))
|
||||
|
||||
|
||||
class SimpleSurgePricer(PricingFunction):
|
||||
"""
|
||||
@@ -133,3 +143,16 @@ class SimpleSurgePricer(PricingFunction):
|
||||
new_prices[low_mask] *= self.discount_multiplier
|
||||
|
||||
return new_prices
|
||||
|
||||
def _get_features(self, state_space=None) -> np.ndarray:
|
||||
"""Extract demand and base price features for each product"""
|
||||
if state_space is None:
|
||||
n = len(self.base_prices) if self.base_prices is not None else 0
|
||||
return np.zeros((n, 2))
|
||||
|
||||
demand = np.asarray(state_space.demand) if hasattr(state_space, 'demand') else np.array([0])
|
||||
base = np.asarray(state_space.prices) if hasattr(state_space, 'prices') else self.base_prices
|
||||
if base is None:
|
||||
base = np.ones(len(demand)) * 99.99
|
||||
|
||||
return np.column_stack([demand, base])
|
||||
|
||||
Reference in New Issue
Block a user