chore: better test consistency before agnet

This commit is contained in:
2026-01-12 22:33:47 +01:00
parent 0d214a469f
commit 961302a21a
4 changed files with 89 additions and 3 deletions

View File

@@ -107,6 +107,36 @@ class SessionAwarePricer(PricingFunction):
return prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract elasticity, demand, and session features"""
if state_space is None or self.elasticity is None:
n = len(self.elasticity) if self.elasticity is not None else 0
return np.zeros((n, 5))
demand = np.asarray(state_space.demand)
n_products = len(demand)
# extract session features
velocity = 0.0
view_depth = 0.0
cart_to_view = 0.0
if not state_space.session_features.empty:
sf = state_space.session_features.iloc[0]
velocity = sf.get('interaction_velocity', 0.0)
view_depth = sf.get('product_view_depth', 0.0)
cart_to_view = sf.get('cart_to_view_ratio', 0.0)
# broadcast session features to all products
features = np.column_stack([
self.elasticity,
demand,
np.full(n_products, velocity),
np.full(n_products, view_depth),
np.full(n_products, cart_to_view)
])
return features
class ProductSpecificSessionPricer(PricingFunction):
"""
@@ -170,3 +200,12 @@ class ProductSpecificSessionPricer(PricingFunction):
prices = np.clip(base_prices, self.price_floor, self.price_ceil)
return prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract elasticity and demand features for product-specific pricing"""
if state_space is None or self.elasticity is None:
n = len(self.elasticity) if self.elasticity is not None else 0
return np.zeros((n, 2))
demand = np.asarray(state_space.demand)
return np.column_stack([self.elasticity, demand])