mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
chore: better test consistency before agnet
This commit is contained in:
@@ -107,6 +107,36 @@ class SessionAwarePricer(PricingFunction):
|
||||
|
||||
return prices
|
||||
|
||||
def _get_features(self, state_space=None) -> np.ndarray:
|
||||
"""Extract elasticity, demand, and session features"""
|
||||
if state_space is None or self.elasticity is None:
|
||||
n = len(self.elasticity) if self.elasticity is not None else 0
|
||||
return np.zeros((n, 5))
|
||||
|
||||
demand = np.asarray(state_space.demand)
|
||||
n_products = len(demand)
|
||||
|
||||
# extract session features
|
||||
velocity = 0.0
|
||||
view_depth = 0.0
|
||||
cart_to_view = 0.0
|
||||
|
||||
if not state_space.session_features.empty:
|
||||
sf = state_space.session_features.iloc[0]
|
||||
velocity = sf.get('interaction_velocity', 0.0)
|
||||
view_depth = sf.get('product_view_depth', 0.0)
|
||||
cart_to_view = sf.get('cart_to_view_ratio', 0.0)
|
||||
|
||||
# broadcast session features to all products
|
||||
features = np.column_stack([
|
||||
self.elasticity,
|
||||
demand,
|
||||
np.full(n_products, velocity),
|
||||
np.full(n_products, view_depth),
|
||||
np.full(n_products, cart_to_view)
|
||||
])
|
||||
return features
|
||||
|
||||
|
||||
class ProductSpecificSessionPricer(PricingFunction):
|
||||
"""
|
||||
@@ -170,3 +200,12 @@ class ProductSpecificSessionPricer(PricingFunction):
|
||||
|
||||
prices = np.clip(base_prices, self.price_floor, self.price_ceil)
|
||||
return prices
|
||||
|
||||
def _get_features(self, state_space=None) -> np.ndarray:
|
||||
"""Extract elasticity and demand features for product-specific pricing"""
|
||||
if state_space is None or self.elasticity is None:
|
||||
n = len(self.elasticity) if self.elasticity is not None else 0
|
||||
return np.zeros((n, 2))
|
||||
|
||||
demand = np.asarray(state_space.demand)
|
||||
return np.column_stack([self.elasticity, demand])
|
||||
|
||||
Reference in New Issue
Block a user