# PHANTOM/experiments/airflow/dags/surge_pricing_pipeline.py
# (repository file-viewer metadata: 257 lines, 8.4 KiB, Python)

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from datetime import timedelta
import pandas as pd
import logging
import sys
import pickle
import io
# Put /opt/airflow on sys.path so the local `procesing` package (that is its actual spelling) can be imported
sys.path.insert(0, '/opt/airflow')
from procesing.context import PipelineContext
from procesing.providers import SupabaseProvider, BackendAPIProvider
from procesing.steps import (
FetchInteractionsStep,
FetchPriceLogsStep,
ComputeDemandStep,
AggregatePriceLogsStep,
JoinProductFeaturesStep,
)
# Defaults handed to the DAG below (and from there to every task): no
# dependency on previous runs, no e-mail on failure or retry, and two
# retries spaced five minutes apart.
default_args = dict(
    owner='phantom-research',
    depends_on_past=False,
    email_on_failure=False,
    email_on_retry=False,
    retries=2,
    retry_delay=timedelta(minutes=5),
)
def get_provider():
    """Return a new provider that combines Supabase and backend-API access.

    A ``CompositeProvider`` class deriving from both ``SupabaseProvider`` and
    ``BackendAPIProvider`` is defined on every call, and a fresh instance of it
    is returned.
    """
    # TODO: Replace these multiple-inheritance composites (declared in several
    # places across the codebase) with one global provider singleton.
    class CompositeProvider(SupabaseProvider, BackendAPIProvider):
        def __init__(self):
            # Each base initializer is invoked explicitly, Supabase first,
            # rather than through a cooperative super() chain.
            for base in (SupabaseProvider, BackendAPIProvider):
                base.__init__(self)

    return CompositeProvider()
def get_context(**kwargs):
    """Create the ``PipelineContext`` used by a task run.

    Args:
        **kwargs: Airflow task context; only ``dag_run`` is read.

    Returns:
        A ``PipelineContext`` with a fresh provider from ``get_provider()`` and
        ``store_mode`` taken from ``dag_run.conf`` (default ``'hotel'``). When
        there is no ``dag_run`` an empty conf is assumed.
    """
    dag_run = kwargs.get('dag_run')
    conf = dag_run.conf if dag_run else {}
    provider = get_provider()
    return PipelineContext(
        provider=provider,
        store_mode=conf.get('store_mode', 'hotel'),
    )
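# Atomic task functions: the fetch/compute/aggregate/join tasks each wrap one sklearn-style pipeline step (via its transform()); the surge-pricing and publish tasks contain their own logic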
def fetch_interactions(**kwargs):
    """Airflow task: load raw interaction rows and hand them to XCom.

    Runs ``FetchInteractionsStep`` (with no input frame) against a freshly
    built pipeline context, pushes the pickled DataFrame under the XCom key
    ``interactions_raw``, and returns the number of rows fetched.
    """
    # NOTE(review): the previous docstring said this reads from Kafka; the real
    # source is decided inside FetchInteractionsStep / the provider — confirm.
    # NOTE(review): pickle.dumps() yields bytes; storing bytes in XCom assumes
    # an XCom backend/config that accepts them (e.g. enable_xcom_pickling).
    interactions = FetchInteractionsStep(get_context(**kwargs)).transform(None)
    row_count = len(interactions)
    kwargs['ti'].xcom_push(key='interactions_raw', value=pickle.dumps(interactions))
    logging.info(f"Fetched {row_count} interaction records")
    return row_count
def fetch_price_logs(**kwargs):
    """Airflow task: load raw price-log rows and hand them to XCom.

    Runs ``FetchPriceLogsStep`` (with no input frame) against a freshly built
    pipeline context, pushes the pickled DataFrame under the XCom key
    ``price_logs_raw``, and returns the number of rows fetched.
    """
    # NOTE(review): the previous docstring said this reads from Kafka; the real
    # source is decided inside FetchPriceLogsStep / the provider — confirm.
    # NOTE(review): pickle.dumps() yields bytes; storing bytes in XCom assumes
    # an XCom backend/config that accepts them (e.g. enable_xcom_pickling).
    step = FetchPriceLogsStep(get_context(**kwargs))
    price_logs = step.transform(None)
    record_count = len(price_logs)
    kwargs['ti'].xcom_push(key='price_logs_raw', value=pickle.dumps(price_logs))
    logging.info(f"Fetched {record_count} price records")
    return record_count
def compute_demand(**kwargs):
    """Airflow task: derive demand scores from the raw interactions.

    Unpickles the ``interactions_raw`` XCom value (pushed by
    ``fetch_interactions``), runs ``ComputeDemandStep`` on it, pushes the
    pickled result under ``demand_data`` and returns its row count.
    """
    task_instance = kwargs['ti']
    # No task_ids filter is passed, so the value is located by key alone.
    # Unpickling is only acceptable because the payload comes from this DAG.
    interactions = pickle.loads(task_instance.xcom_pull(key='interactions_raw'))
    demand = ComputeDemandStep(get_context(**kwargs)).transform(interactions)
    task_instance.xcom_push(key='demand_data', value=pickle.dumps(demand))
    product_count = len(demand)
    logging.info(f"Computed demand for {product_count} products")
    return product_count
def aggregate_price_logs(**kwargs):
    """Airflow task: aggregate the raw price logs.

    Unpickles the ``price_logs_raw`` XCom value (pushed by
    ``fetch_price_logs``), runs ``AggregatePriceLogsStep`` on it, pushes the
    pickled result under ``price_data`` and returns its row count.
    """
    task_instance = kwargs['ti']
    # No task_ids filter is passed, so the value is located by key alone.
    raw_logs = pickle.loads(task_instance.xcom_pull(key='price_logs_raw'))
    aggregated = AggregatePriceLogsStep(get_context(**kwargs)).transform(raw_logs)
    task_instance.xcom_push(key='price_data', value=pickle.dumps(aggregated))
    product_count = len(aggregated)
    logging.info(f"Aggregated price logs for {product_count} products")
    return product_count
def join_product_features(**kwargs):
    """Airflow task: join demand scores with aggregated prices.

    Unpickles the ``demand_data`` and ``price_data`` XCom values, passes them
    as a ``(demand_df, price_df)`` tuple to ``JoinProductFeaturesStep``,
    pushes the pickled joined frame under ``product_features`` and returns
    its row count.
    """
    task_instance = kwargs['ti']
    # Pulled and unpickled in this order: demand first, then prices.
    demand_df, price_df = (
        pickle.loads(task_instance.xcom_pull(key=xcom_key))
        for xcom_key in ('demand_data', 'price_data')
    )
    joiner = JoinProductFeaturesStep(get_context(**kwargs))
    joined = joiner.transform((demand_df, price_df))
    task_instance.xcom_push(key='product_features', value=pickle.dumps(joined))
    joined_count = len(joined)
    logging.info(f"Joined features for {joined_count} products")
    return joined_count
def _first_metadata_by_id(products):
    """Map each ``id`` in ``products`` to the ``metadata`` of its first row."""
    lookup = {}
    for product_id, meta in zip(products['id'], products['metadata']):
        lookup.setdefault(product_id, meta)  # first occurrence wins
    return lookup


def _resolve_base_price(meta, fallback):
    """Return ``meta['base_price']`` as a float, or ``fallback`` if unusable.

    ``fallback`` is returned when ``meta`` is not a dict, has no
    ``base_price`` key, holds ``None``, or holds a value ``float()`` rejects.
    Without this, a ``null`` or string base price in the metadata would make
    the multiplication below raise ``TypeError`` and fail the task.
    """
    if not isinstance(meta, dict):
        return fallback
    raw = meta.get('base_price')
    if raw is None:
        return fallback
    try:
        return float(raw)
    except (TypeError, ValueError):
        logging.warning(f"Ignoring non-numeric base_price {raw!r}; using observed price")
        return fallback


def apply_surge_pricing(**kwargs):
    """Task: Apply surge pricing rules to generate optimal prices.

    Reads the pickled ``product_features`` DataFrame from XCom and, per
    ``productId``, averages ``demand_score`` and ``price``. It then prices
    each product as:

    * demand >= ``high_threshold``  -> ``base_price * surge_multiplier``
    * demand <= ``low_threshold``   -> ``base_price * discount_multiplier``
    * otherwise                     -> ``base_price``

    ``base_price`` is ``metadata['base_price']`` from the first row of
    ``context.products`` whose ``id`` matches, when that metadata is a dict
    holding a numeric value. Otherwise it is the product's mean observed price.

    Thresholds and multipliers come from ``dag_run.conf`` (defaults 10, 2,
    1.2, 0.9). NOTE: ``publish_results`` records these same defaults in its
    metadata; keep them in sync.

    Pushes a pickled DataFrame with columns productId, current_price,
    base_price, optimal_price and demand_score under the XCom key
    ``predicted_prices``, and returns its row count.
    """
    ti = kwargs['ti']
    # Unpickling is only acceptable because the payload comes from this DAG.
    product_features = pickle.loads(ti.xcom_pull(key='product_features'))
    dag_run = kwargs.get('dag_run')
    # Guard against a run whose conf is None; treat it like an empty conf.
    dag_conf = (dag_run.conf or {}) if dag_run else {}
    high_threshold = dag_conf.get('high_threshold', 10)
    low_threshold = dag_conf.get('low_threshold', 2)
    surge_multiplier = dag_conf.get('surge_multiplier', 1.2)
    discount_multiplier = dag_conf.get('discount_multiplier', 0.9)
    context = get_context(**kwargs)
    products = context.products

    # Partition once instead of boolean-filtering the whole frame per product.
    # sort=False keeps first-appearance order (like .unique()); groupby drops
    # missing ids, which a per-id equality filter could never match anyway.
    groups = product_features.groupby('productId', sort=False)
    # Only read the product catalogue when there is something to price.
    metadata_by_id = _first_metadata_by_id(products) if groups.ngroups else {}

    results = []
    for pid, prod_data in groups:
        if prod_data.empty:  # defensive: e.g. unobserved categorical ids
            continue
        demand = prod_data['demand_score'].mean()
        current_price = prod_data['price'].mean()
        base_price = _resolve_base_price(metadata_by_id.get(pid), current_price)
        if demand >= high_threshold:
            optimal_price = base_price * surge_multiplier
        elif demand <= low_threshold:
            optimal_price = base_price * discount_multiplier
        else:
            optimal_price = base_price
        results.append({
            'productId': pid,
            'current_price': current_price,
            'base_price': base_price,
            'optimal_price': optimal_price,
            'demand_score': demand,
        })

    prices_df = pd.DataFrame(results)
    ti.xcom_push(key='predicted_prices', value=pickle.dumps(prices_df))
    logging.info(f"Applied surge pricing for {len(prices_df)} products")
    return len(prices_df)
def publish_results(**kwargs):
    """Task: Publish surge pricing results to the model registry.

    Unpickles ``predicted_prices`` from XCom and publishes it through
    ``ModelRegistry.publish_prices`` under the model name ``'latest'``. The
    metadata published with it holds a timestamp, store mode, run id and the
    pricing parameters (read from ``dag_run.conf`` with the same defaults
    ``apply_surge_pricing`` uses).

    Returns:
        dict with ``n_products``, ``registry_status`` (always ``'success'``
        when reached) and ``mean_demand``. ``mean_demand`` is ``None`` if
        there is no ``demand_score`` column.
    """
    ti = kwargs['ti']
    prices_df = pickle.loads(ti.xcom_pull(key='predicted_prices'))
    # Prepends /opt/airflow again (module import already did once) so the
    # `lib` package resolves from there; each call adds another entry.
    sys.path.insert(0, '/opt/airflow')
    from lib.model_registry import ModelRegistry
    registry = ModelRegistry()

    dag_run = kwargs.get('dag_run')
    dag_conf = dag_run.conf if dag_run else {}
    pricing_params = {
        name: dag_conf.get(name, default)
        for name, default in (
            ('high_threshold', 10),
            ('low_threshold', 2),
            ('surge_multiplier', 1.2),
            ('discount_multiplier', 0.9),
        )
    }
    # Key order is kept stable for whatever serializes this metadata.
    # NOTE(review): pd.Timestamp.now() is naive local time — confirm workers run in UTC.
    metadata = {
        'timestamp': pd.Timestamp.now().isoformat(),
        'store_mode': dag_conf.get('store_mode', 'hotel'),
        'dag_run_id': dag_run.run_id if dag_run else 'manual',
        'pricing_method': 'surge',
        **pricing_params,
    }
    registry.publish_prices(prices_df, model_name='latest', metadata=metadata)

    product_count = len(prices_df)
    logging.info(f"Published surge pricing for {product_count} products")
    mean_demand = None
    if 'demand_score' in prices_df.columns:
        mean_demand = float(prices_df['demand_score'].mean())
    return {
        'n_products': product_count,
        'registry_status': 'success',
        'mean_demand': mean_demand,
    }
# DAG definition: runs every 15 minutes, at most one active run, no backfill.
with DAG(
    'surge_pricing_pipeline',
    default_args=default_args,
    description='Simple surge pricing pipeline: demand aggregation + rule-based pricing',
    # NOTE(review): `schedule_interval` and `days_ago()` are deprecated in later
    # Airflow 2.x releases and removed in Airflow 3; migrate to `schedule=` and
    # a fixed timezone-aware start_date once the target version is confirmed.
    schedule_interval='*/15 * * * *',
    start_date=days_ago(1),
    catchup=False,
    max_active_runs=1,
    tags=['pricing', 'surge', 'research', 'simplified'],
) as dag:
    # Airflow 2 passes the task context to python_callable automatically.
    # `provide_context=True` is deprecated there (it only warns and is
    # discarded) and no longer accepted in Airflow 3, so it is omitted.

    # parallel data fetching
    t_fetch_interactions = PythonOperator(
        task_id='fetch_interactions',
        python_callable=fetch_interactions,
    )
    t_fetch_price_logs = PythonOperator(
        task_id='fetch_price_logs',
        python_callable=fetch_price_logs,
    )
    # compute demand from interactions
    t_compute_demand = PythonOperator(
        task_id='compute_demand',
        python_callable=compute_demand,
    )
    # aggregate price logs
    t_aggregate_prices = PythonOperator(
        task_id='aggregate_price_logs',
        python_callable=aggregate_price_logs,
    )
    # join demand and prices
    t_join_features = PythonOperator(
        task_id='join_product_features',
        python_callable=join_product_features,
    )
    # apply surge pricing
    t_surge_pricing = PythonOperator(
        task_id='apply_surge_pricing',
        python_callable=apply_surge_pricing,
    )
    # publish to registry
    t_publish = PythonOperator(
        task_id='publish_results',
        python_callable=publish_results,
    )

    # dependency graph: parallel fetch -> process -> join -> surge -> publish
    t_fetch_interactions >> t_compute_demand
    t_fetch_price_logs >> t_aggregate_prices
    [t_compute_demand, t_aggregate_prices] >> t_join_features >> t_surge_pricing >> t_publish