feature: e2e intro pipline surge pricing

This commit is contained in:
2025-12-06 16:30:28 +01:00
parent 503c5e182d
commit e6a5b95875
6 changed files with 41 additions and 110 deletions

View File

@@ -20,6 +20,7 @@ from procesing.steps import (
AggregatePriceLogsStep, AggregatePriceLogsStep,
JoinProductFeaturesStep, JoinProductFeaturesStep,
) )
from procesing.pricers.simple import SimpleSurgePricer
default_args = { default_args = {
'owner': 'phantom-research', 'owner': 'phantom-research',
@@ -75,6 +76,8 @@ def compute_demand(**kwargs):
context = get_context(**kwargs) context = get_context(**kwargs)
step = ComputeDemandStep(context) step = ComputeDemandStep(context)
demand_df = step.transform(df) demand_df = step.transform(df)
# TODO: clear the xcom
ti.xcom_push(key='demand_data', value=pickle.dumps(demand_df)) ti.xcom_push(key='demand_data', value=pickle.dumps(demand_df))
logging.info(f"Computed demand for {len(demand_df)} products") logging.info(f"Computed demand for {len(demand_df)} products")
@@ -113,46 +116,24 @@ def apply_surge_pricing(**kwargs):
product_features = pickle.loads(ti.xcom_pull(key='product_features')) product_features = pickle.loads(ti.xcom_pull(key='product_features'))
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {} dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
high_threshold = dag_conf.get('high_threshold', 10)
low_threshold = dag_conf.get('low_threshold', 2)
surge_multiplier = dag_conf.get('surge_multiplier', 1.2)
discount_multiplier = dag_conf.get('discount_multiplier', 0.9)
context = get_context(**kwargs) # rename demand_score to demand for pricer compatibility
products = context.products data = product_features.rename(columns={'demand_score': 'demand'})
results = []
for pid in product_features['productId'].unique(): surge_pricer = SimpleSurgePricer(
prod_data = product_features[product_features['productId'] == pid] high_threshold=dag_conf.get('high_threshold', 10),
if prod_data.empty: low_threshold=dag_conf.get('low_threshold', 2),
continue surge_multiplier=dag_conf.get('surge_multiplier', 1.2),
discount_multiplier=dag_conf.get('discount_multiplier', 0.9)
)
surge_pricer.fit(data)
data['optimal_price'] = surge_pricer.predict()
demand = prod_data["demand_score"].mean() prices_df = data[['productId', 'price', 'base_price', 'optimal_price', 'demand']].rename(columns={
current_price = prod_data["price"].mean() 'price': 'current_price',
'demand': 'demand_score'
})
prod_meta = products[products['id'] == pid]
if not prod_meta.empty:
meta = prod_meta.iloc[0]['metadata']
base_price = meta.get('base_price', current_price) if isinstance(meta, dict) else current_price
else:
base_price = current_price
if demand >= high_threshold:
optimal_price = base_price * surge_multiplier
elif demand <= low_threshold:
optimal_price = base_price * discount_multiplier
else:
optimal_price = base_price
results.append({
'productId': pid,
'current_price': current_price,
'base_price': base_price,
'optimal_price': optimal_price,
'demand_score': demand
})
prices_df = pd.DataFrame(results)
ti.xcom_push(key='predicted_prices', value=pickle.dumps(prices_df)) ti.xcom_push(key='predicted_prices', value=pickle.dumps(prices_df))
logging.info(f"Applied surge pricing for {len(prices_df)} products") logging.info(f"Applied surge pricing for {len(prices_df)} products")
return len(prices_df) return len(prices_df)

View File

@@ -18,6 +18,7 @@ from procesing.steps import (
ComputeDemandStep, ComputeDemandStep,
JoinProductFeaturesStep JoinProductFeaturesStep
) )
from procesing.pricers import SimpleSurgePricer
def interaction_extraction_pipeline(context: PipelineContext): def interaction_extraction_pipeline(context: PipelineContext):
"""Pipeline for extracting and augmenting interaction data""" """Pipeline for extracting and augmenting interaction data"""
@@ -57,65 +58,14 @@ def pricing_pipeline(context: "PipelineContext",
low_threshold: int = 2, low_threshold: int = 2,
surge_multiplier: float = 1.2, surge_multiplier: float = 1.2,
discount_multiplier: float = 0.9) -> pd.DataFrame: discount_multiplier: float = 0.9) -> pd.DataFrame:
"""
Generate product-level optimal prices using simple surge pricing rules.
Replaces complex Bayesian curve fitting with threshold-based adjustments.
Args:
context: Pipeline context
data: DataFrame with [productId, demand_score, price]
high_threshold: Demand threshold for surge pricing (default 10)
low_threshold: Demand threshold for discounts (default 2)
surge_multiplier: Price multiplier for high demand (default 1.2 = +20%)
discount_multiplier: Price multiplier for low demand (default 0.9 = -10%)
Returns:
DataFrame with [productId, current_price, optimal_price, demand_score]
"""
if data.empty or 'productId' not in data.columns: if data.empty or 'productId' not in data.columns:
return pd.DataFrame() return pd.DataFrame()
products = context.products surge_pricer = SimpleSurgePricer()
results = [] surge_pricer.fit(data)
data['optimal_price'] = surge_pricer.predict()
for pid in data['productId'].unique(): return data
prod_data = data[data['productId'] == pid]
if prod_data.empty:
continue
demand = prod_data["demand_score"].mean()
current_price = prod_data["price"].mean()
# get base price from metadata or use current price
prod_meta = products[products['id'] == pid]
if not prod_meta.empty:
meta = prod_meta.iloc[0]['metadata']
base_price = meta.get('base_price', current_price) if isinstance(meta, dict) else current_price
else:
base_price = current_price
# apply surge rules
if demand >= high_threshold:
optimal_price = base_price * surge_multiplier
elif demand <= low_threshold:
optimal_price = base_price * discount_multiplier
else:
optimal_price = base_price
results.append({
'productId': pid,
'current_price': current_price,
'base_price': base_price,
'optimal_price': optimal_price,
'demand_score': demand
})
return pd.DataFrame(results)
def full_pipeline(context: PipelineContext, def full_pipeline(context: PipelineContext,
@@ -172,10 +122,6 @@ if __name__ == '__main__':
interactions_file = "messages(2).json" interactions_file = "messages(2).json"
prices_file = "messages(3).json" prices_file = "messages(3).json"
if topic == "interactions":
data = pd.read_json(path + interactions_file)
elif topic == "price_logs":
pd.read_json(path + prices_file)
data = pd.read_json(path + (interactions_file if topic == "user-interactions" else prices_file)) data = pd.read_json(path + (interactions_file if topic == "user-interactions" else prices_file))
data = [r['payload'] for r in data['value'].to_list()] data = [r['payload'] for r in data['value'].to_list()]
data = pd.DataFrame(data) data = pd.DataFrame(data)

View File

@@ -1,6 +1,6 @@
from procesing.pricers.base import PricingFunction from procesing.pricers.base import PricingFunction
from procesing.pricers.elasticity import ElasticityBasedPricer from procesing.pricers.elasticity import ElasticityBasedPricer
from procesing.pricers.simple import StaticPricer, RandomPricer from procesing.pricers.simple import StaticPricer, RandomPricer, SimpleSurgePricer
from procesing.pricers.session_aware import SessionAwarePricer, ProductSpecificSessionPricer from procesing.pricers.session_aware import SessionAwarePricer, ProductSpecificSessionPricer
__all__ = [ __all__ = [
@@ -8,6 +8,7 @@ __all__ = [
'ElasticityBasedPricer', 'ElasticityBasedPricer',
'StaticPricer', 'StaticPricer',
'RandomPricer', 'RandomPricer',
'SimpleSurgePricer',
'SessionAwarePricer', 'SessionAwarePricer',
'ProductSpecificSessionPricer' 'ProductSpecificSessionPricer'
] ]

View File

@@ -25,7 +25,7 @@ class PricingFunction(ABC):
""" """
@abstractmethod @abstractmethod
def fit(self, historical_data: pd.DataFrame, **kwargs): def fit(self, *kwargs):
""" """
Offline training on historical data. Offline training on historical data.
@@ -36,7 +36,7 @@ class PricingFunction(ABC):
pass pass
@abstractmethod @abstractmethod
def predict(self, state_space) -> np.ndarray: def predict(self, *kwargs) -> np.ndarray:
""" """
Generate optimal prices given current state. Generate optimal prices given current state.

View File

@@ -67,24 +67,19 @@ class SimpleSurgePricer(PricingFunction):
self.surge_multiplier = surge_multiplier self.surge_multiplier = surge_multiplier
self.discount_multiplier = discount_multiplier self.discount_multiplier = discount_multiplier
def fit(self, historical_data: pd.DataFrame): def fit(self, market_data : pd.DataFrame):
"""Extract base prices from product catalog or historical averages""" """Extract base prices from product catalog or historical averages"""
if 'base_price' in historical_data.columns: self.base_prices = market_data['base_price'].to_numpy() if 'base_price' in market_data.columns else market_data['price'].values
self.base_prices = historical_data['base_price'].values self.demand_history = market_data['demand'].to_numpy() if 'demand' in market_data.columns else np.zeros_like(self.base_prices)
elif 'price' in historical_data.columns:
self.base_prices = historical_data.groupby('productId')['price'].mean().values
else:
raise ValueError("historical_data must contain 'base_price' or 'price'")
return self
def predict(self, state_space) -> np.ndarray: def predict(self) -> np.ndarray:
""" """
Adjust prices based on current demand using surge rules. Adjust prices based on current demand using surge rules.
state_space.demand: demand counts per product state_space.demand: demand counts per product
state_space.prices: current prices (fallback if base_prices not set) state_space.prices: current prices (fallback if base_prices not set)
""" """
current_prices = self.base_prices if self.base_prices is not None else state_space.prices current_prices = self.base_prices if self.base_prices is not None else np.ones_like(demand_vector) * 99.99
demand = state_space.demand demand = self.demand_history if self.demand_history is not None else np.zeros_like(current_prices)
new_prices = current_prices.copy() new_prices = current_prices.copy()
high_mask = demand >= self.high_threshold high_mask = demand >= self.high_threshold

View File

@@ -45,6 +45,14 @@ class JoinProductFeaturesStep(BaseContextStep):
""" """
demand_df, price_df = data demand_df, price_df = data
# get base prices from products if available
products = self.context.products
products['base_price'] = products.apply(
lambda row: float(row['metadata'].get('base_price', 0.0)) if isinstance(row['metadata'], dict) else 0,
axis=1
)
products = products[['id', 'base_price']].rename(columns={'id': 'productId'})
if price_df.empty: if price_df.empty:
return demand_df return demand_df
return demand_df.merge(price_df, on='productId', how='left') return demand_df.merge(price_df, on='productId', how='left').merge(products, on='productId', how='left')