chore: better test consistency before agnet

This commit is contained in:
2026-01-12 22:33:47 +01:00
parent 0d214a469f
commit 961302a21a
4 changed files with 89 additions and 3 deletions

View File

@@ -65,6 +65,11 @@ class StaticPricer(PricingFunction):
raise ValueError("Must call fit() or provide base_prices in constructor")
return self.base_prices.copy()
def _get_features(self, state_space=None) -> np.ndarray:
"""Static pricer uses no features, returns empty array"""
n = len(self.base_prices) if self.base_prices is not None else 0
return np.zeros((n, 0))
class RandomPricer(PricingFunction):
"""Random pricing within bounds (for baseline comparison)"""
@@ -87,6 +92,11 @@ class RandomPricer(PricingFunction):
self.n_products = len(state_space.demand)
return self.rng.uniform(self.price_min, self.price_max, size=self.n_products)
def _get_features(self, state_space=None) -> np.ndarray:
"""Random pricer uses no features"""
n = self.n_products if self.n_products else 0
return np.zeros((n, 0))
class SimpleSurgePricer(PricingFunction):
"""
@@ -133,3 +143,16 @@ class SimpleSurgePricer(PricingFunction):
new_prices[low_mask] *= self.discount_multiplier
return new_prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract demand and base price features for each product"""
if state_space is None:
n = len(self.base_prices) if self.base_prices is not None else 0
return np.zeros((n, 2))
demand = np.asarray(state_space.demand) if hasattr(state_space, 'demand') else np.array([0])
base = np.asarray(state_space.prices) if hasattr(state_space, 'prices') else self.base_prices
if base is None:
base = np.ones(len(demand)) * 99.99
return np.column_stack([demand, base])