chore: refactored and broke down components (breaking

This commit is contained in:
2025-11-28 13:43:05 +01:00
parent f749bd749c
commit b38f2b0c66
15 changed files with 743 additions and 3 deletions

View File

@@ -189,7 +189,7 @@ def publish_results(**context):
# import registry and pricing modules
import sys
sys.path.insert(0, '/opt/airflow/procesing')
sys.path.insert(0, '/opt/airflow/procesing') # this is pretty janky
sys.path.insert(0, '/opt/airflow')
from lib.model_registry import ModelRegistry
@@ -243,8 +243,8 @@ def publish_results(**context):
with DAG(
'elasticity_pricing_pipeline',
default_args=default_args,
description='E2E pipeline: interactions demand elasticity pricing',
schedule_interval='*/5 * * * *', # every 5 minutes for real-time pricing
description='E2E pipeline: interactions -> demand -> elasticity -> pricing',
schedule_interval='*/15 * * * *', # every 15 minutes for real-time pricing
start_date=days_ago(1),
catchup=False,
max_active_runs=1,

View File

@@ -0,0 +1,34 @@
from typing import Any, Dict
import pandas as pd
from .providers.base import DataProvider
class PipelineContext:
    """
    Shared state for a single pipeline run: provider, settings, and a cache.

    Steps receive this object instead of reading module-level globals, so the
    provider and configuration can be injected (and replaced in tests).
    """

    def __init__(self,
                 provider: DataProvider,
                 store_mode: str,
                 window_size: str = '30s',
                 **config):
        # Per-run memo of fetched/derived values, keyed by name.
        self._cache: Dict[str, Any] = {}
        # Any extra keyword arguments become free-form step configuration.
        self.config = config
        self.window_size = window_size
        self.store_mode = store_mode
        self.provider = provider

    def get_cached(self, key: str, default=None):
        """Return the value cached under *key*, or *default* when absent."""
        if key in self._cache:
            return self._cache[key]
        return default

    def cache(self, key: str, value):
        """Store *value* under *key* and return that same value."""
        self._cache[key] = value
        return value

    @property
    def products(self) -> pd.DataFrame:
        """Product catalog for ``store_mode``; fetched from the provider on first access only."""
        if 'products' in self._cache:
            return self._cache['products']
        catalog = self.provider.fetch_products(self.store_mode)
        self._cache['products'] = catalog
        return self._cache['products']

View File

@@ -0,0 +1,5 @@
from .base import DataProvider
from .supabase import SupabaseProvider
from .backend import BackendAPIProvider
__all__ = ['DataProvider', 'SupabaseProvider', 'BackendAPIProvider']

View File

@@ -0,0 +1,19 @@
import os
import pandas as pd
import requests
from typing import List
from .base import DataProvider
class BackendAPIProvider(DataProvider):
    """
    Data provider that reads Kafka topic dumps through the backend HTTP API.

    NOTE(review): only ``fetch_kafka_topic`` is implemented here. If
    ``DataProvider`` declares further abstract methods, this class cannot be
    instantiated until they are implemented as well -- confirm against the
    base class.
    """

    # requests applies no timeout by default, so a stalled backend would block
    # the calling task indefinitely; always pass an explicit one.
    DEFAULT_TIMEOUT = 30.0

    def __init__(self, backend_url: str = None, timeout: float = DEFAULT_TIMEOUT):
        """
        Args:
            backend_url: Base URL of the backend. Falls back to the
                ``BACKEND_URL`` environment variable, then to
                ``http://localhost:5000``.
            timeout: Per-request timeout in seconds.
        """
        self.backend_url = backend_url or os.getenv("BACKEND_URL", "http://localhost:5000")
        self.timeout = timeout

    def fetch_kafka_topic(self, topic: str) -> pd.DataFrame:
        """
        Fetch the dump of *topic* from ``/api/kafka/dump``.

        Returns:
            A DataFrame built from the response's ``data`` field, or an empty
            DataFrame when the response is not flagged ``success`` or carries
            no data.

        Raises:
            requests.HTTPError: if the backend answers with an error status.
            requests.Timeout: if the backend does not answer within ``timeout``.
        """
        # Pass the topic via ``params`` so it is URL-encoded rather than
        # spliced into the query string verbatim.
        resp = requests.get(
            f"{self.backend_url}/api/kafka/dump",
            params={"topic": topic},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        payload = resp.json()
        if not payload.get('success') or not payload.get('data'):
            return pd.DataFrame()
        return pd.DataFrame(payload['data'])

View File

@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import List
import pandas as pd
class DataProvider(ABC):
    """Abstract data-access interface; concrete providers can be swapped for DI and testing."""

    @abstractmethod
    def fetch_products(self, store_mode: str) -> pd.DataFrame:
        """Return the product catalog for the given store mode."""

    @abstractmethod
    def fetch_experiments(self, experiment_ids: List[str]) -> pd.DataFrame:
        """Return experiment metadata for the given experiment IDs."""

    @abstractmethod
    def fetch_kafka_topic(self, topic: str) -> pd.DataFrame:
        """Return the contents of a Kafka topic (fetched via the backend API)."""

View File

@@ -0,0 +1,33 @@
import os
import pandas as pd
import requests
from typing import List
from supabase import create_client, Client
from .base import DataProvider
class SupabaseProvider(DataProvider):
    """
    Supabase-backed provider for the product and experiment tables.

    NOTE(review): this class defines no ``fetch_kafka_topic``; if
    ``DataProvider`` declares it abstract, instantiation will fail -- confirm
    against the base class.
    """

    def __init__(self,
                 supabase_url: str = None,
                 supabase_key: str = None,):
        # Explicit arguments win; otherwise fall back to the environment.
        if not supabase_url:
            supabase_url = os.getenv("NEXT_PUBLIC_SUPABASE_URL")
        if not supabase_key:
            supabase_key = os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY")
        self.supabase_url = supabase_url
        self.supabase_key = supabase_key
        self.supabase: Client = create_client(self.supabase_url, self.supabase_key)

    def fetch_products(self, store_mode: str) -> pd.DataFrame:
        """Read the ``<store_mode>_products`` table; empty DataFrame when it has no rows."""
        table = self.supabase.table(f'{store_mode}_products')
        resp = table.select(
            "id, room_type, date_index, metadata, availability"
        ).execute()
        if resp.data:
            return pd.DataFrame(resp.data)
        return pd.DataFrame()

    def fetch_experiments(self, experiment_ids: List[str]) -> pd.DataFrame:
        """Read experiments (with their embedded task) for the given IDs; empty when none match."""
        # No IDs means nothing to look up -- skip the round trip.
        if not experiment_ids:
            return pd.DataFrame()
        query = self.supabase.table('experiments').select(
            'id, subject_name, xp_human_only, xp_market_mode, xp_task_id, '
            'task:tasks(task_name, task_description, task_def_of_done)'
        )
        resp = query.in_('id', experiment_ids).execute()
        if resp.data:
            return pd.DataFrame(resp.data)
        return pd.DataFrame()

View File

@@ -0,0 +1,27 @@
from .base import BaseContextStep
from .fetch import FetchInteractionsStep, FetchPriceLogsStep, FetchExperimentsStep
from .join import JoinExperimentsStep
from .augment import CreatePriceBucketsStep, AugmentEventNamesStep
from .chunk import ChunkByTimeWindowStep
from .demand import ComputeDemandStep, ComputeDemandForChunksStep
from .elasticity import AggregatePriceLogsStep, ComputeElasticityStep
from .pricing import StateSpace, BuildStateSpaceStep, FitPricingFunctionStep, PredictPricesStep
__all__ = [
'BaseContextStep',
'FetchInteractionsStep',
'FetchPriceLogsStep',
'FetchExperimentsStep',
'JoinExperimentsStep',
'CreatePriceBucketsStep',
'AugmentEventNamesStep',
'ChunkByTimeWindowStep',
'ComputeDemandStep',
'ComputeDemandForChunksStep',
'AggregatePriceLogsStep',
'ComputeElasticityStep',
'StateSpace',
'BuildStateSpaceStep',
'FitPricingFunctionStep',
'PredictPricesStep',
]

View File

@@ -0,0 +1,53 @@
import numpy as np
import pandas as pd
from .base import BaseContextStep
class CreatePriceBucketsStep(BaseContextStep):
    """Add a ``price_bucket`` label column derived from ``metadata_price``."""

    def transform(self, df: pd.DataFrame):
        """
        Label rows by price quantile and return the same (mutated) frame.

        The number of quantiles comes from the ``n_price_buckets`` config key
        (default 5). Rows get ``PB_<k>`` labels from ``pd.qcut``; if ``qcut``
        raises ``ValueError`` the step falls back to ``P_<int(price)>`` labels.
        """
        # Nothing to bucket: still emit the column so later steps can rely on it.
        if df.empty or 'metadata_price' not in df.columns:
            df['price_bucket'] = ""
            return df

        bucket_count = self.context.config.get('n_price_buckets', 5)
        prices = df['metadata_price']

        if prices.notnull().sum() > 0:
            labels = [f"PB_{i+1}" for i in range(bucket_count)]
            try:
                buckets = pd.qcut(prices, q=bucket_count, labels=labels, duplicates='drop')
            except ValueError:
                # Fallback for insufficient unique values.
                # NOTE(review): presumably this also fires whenever
                # duplicates='drop' removes an edge, because the fixed label
                # list no longer matches the bin count -- confirm.
                buckets = prices.apply(
                    lambda x: f"P_{int(x)}" if pd.notnull(x) else ""
                )
        else:
            # Column exists but holds no usable price at all.
            buckets = pd.Series([""] * len(df), index=df.index)

        df['price_bucket'] = buckets
        return df
class AugmentEventNamesStep(BaseContextStep):
    """Append a ``_<productId>@<price_bucket>`` suffix to event names."""

    def transform(self, df: pd.DataFrame):
        """
        Add ``metadata_schema`` and suffix ``eventName`` with it, in place.

        Rows where both ``productId`` and ``price_bucket`` are non-null get the
        suffix ``_<productId>@<price_bucket>``; every other row gets an empty
        suffix. If either column is missing from the frame, no row is
        suffixed (previously this raised ``KeyError``).

        Returns:
            The same DataFrame object, mutated.
        """
        if df.empty:
            return df

        if 'productId' in df.columns and 'price_bucket' in df.columns:
            # NOTE(review): an empty-string bucket counts as "present" here and
            # yields a suffix like "_<id>@" -- confirm that is intended.
            has_both = df['productId'].notnull() & df['price_bucket'].notnull()
            suffix = "_" + df['productId'].astype(str) + "@" + df['price_bucket'].astype(str)
            df['metadata_schema'] = np.where(has_both, suffix, "")
        else:
            # np.where evaluates both branches eagerly, so the columns must
            # not be touched at all when one of them is absent.
            df['metadata_schema'] = ""

        df['eventName'] = df['eventName'] + df['metadata_schema']
        return df

View File

@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from sklearn.base import BaseEstimator, TransformerMixin
from ..context import PipelineContext
class BaseContextStep(BaseEstimator, TransformerMixin, ABC):
    """
    Common base class for pipeline steps.

    A step holds only a reference to the shared ``PipelineContext`` and
    implements a single transformation in ``transform``.
    """

    def __init__(self, context: PipelineContext):
        self.context = context

    def fit(self, X=None, y=None):
        """No-op fit: returns ``self`` without looking at the data."""
        return self

    @abstractmethod
    def transform(self, X):
        """Transform the input using the context; subclasses must implement this."""

    def get_params(self, deep=True):
        """sklearn compatibility: the context is the only parameter."""
        return dict(context=self.context)

    def set_params(self, **params):
        """sklearn compatibility: accepts ``context``; other keys are ignored."""
        for name, value in params.items():
            if name == 'context':
                self.context = value
        return self

View File

@@ -0,0 +1,34 @@
import pandas as pd
from .base import BaseContextStep
class ChunkByTimeWindowStep(BaseContextStep):
    """
    Split a dataframe into consecutive time windows.

    Returns a list of dicts, one per non-empty window, carrying
    ``window_start``, ``window_end``, ``window_idx`` and the window's rows
    under ``data``.
    """

    def transform(self, df: pd.DataFrame):
        if df.empty:
            return []

        window_size = self.context.window_size
        ts_col = self.context.config.get('ts_col', 'ts')

        # Work on a copy so the caller's frame keeps its dtype and order.
        frame = df.copy()
        if not pd.api.types.is_datetime64_any_dtype(frame[ts_col]):
            frame[ts_col] = pd.to_datetime(frame[ts_col])
        frame = frame.sort_values(ts_col)

        # Each row is assigned to the window whose start is its floored timestamp.
        frame['_window'] = frame[ts_col].dt.floor(window_size)

        return [
            {
                'window_start': start,
                'window_end': start + pd.Timedelta(window_size),
                'window_idx': position,
                'data': rows.drop(columns=['_window']),
            }
            for position, (start, rows) in enumerate(frame.groupby('_window'))
        ]

View File

@@ -0,0 +1,61 @@
import pandas as pd
from .base import BaseContextStep
class ComputeDemandStep(BaseContextStep):
    """
    Count interactions per product for one time window or one dataframe.

    Input: a chunk dict (rows under ``data``) OR a raw dataframe
    Output: demand dataframe with [productId, demand_score]; for a chunk dict
    that carries extra keys, a dict of those keys plus ``demand_vector``.
    """

    def transform(self, chunk):
        # Accept both the chunk dict and a bare dataframe.
        if isinstance(chunk, dict):
            interactions = chunk['data']
            window_meta = {key: value for key, value in chunk.items() if key != 'data'}
        else:
            interactions, window_meta = chunk, {}

        catalog_ids = self.context.products['id'].unique()

        # Optional equality filters; each applies only if configured and the
        # column is present.
        config = self.context.config
        wanted_session = config.get('session_filter')
        wanted_experiment = config.get('experiment_filter')
        for column, wanted in (('sessionId', wanted_session),
                               ('experimentId', wanted_experiment)):
            if wanted and column in interactions.columns:
                interactions = interactions[interactions[column] == wanted]

        with_product = interactions.dropna(subset=['productId'])

        if with_product.empty:
            # No usable interactions: every catalog product scores zero.
            demand_df = pd.DataFrame({
                'productId': catalog_ids,
                'demand_score': 0
            })
        else:
            # Plain occurrence count per product, expanded to the full catalog.
            counts = pd.crosstab(with_product['productId'], 'count')
            demand_df = counts.reindex(catalog_ids, fill_value=0).reset_index()
            demand_df.columns = ['productId', 'demand_score']

        # Keep the window metadata alongside the vector when there is any.
        if not window_meta:
            return demand_df
        return {**window_meta, 'demand_vector': demand_df}
class ComputeDemandForChunksStep(BaseContextStep):
    """Run ``ComputeDemandStep`` over every chunk in a list."""

    def transform(self, chunks: list):
        """Return one demand result per chunk, in input order (``[]`` for no chunks)."""
        if not chunks:
            return []
        # One step instance is shared across all chunks.
        per_chunk = ComputeDemandStep(self.context).transform
        return list(map(per_chunk, chunks))

View File

@@ -0,0 +1,253 @@
import numpy as np
import pandas as pd
from typing import Dict, List
from .base import BaseContextStep
class AggregatePriceLogsStep(BaseContextStep):
    """
    Aggregate raw price logs into one price vector per time window.

    Input: price_logs_df (needs the timestamp column, ``productId``, ``price``)
    Output: list of dicts with ``window_start``, ``window_end`` and
    ``price_vector`` (a frame with [productId, price])
    """

    def transform(self, price_logs_df: pd.DataFrame):
        """
        Build the per-window price vectors.

        Within each product the mean price per window is computed and gaps
        between that product's first and last window are forward-filled.
        Catalog products that have no row in a window are then given their
        last raw price logged strictly before the window start, if any;
        those carried rows are appended in catalog order.
        """
        if price_logs_df.empty:
            return []

        df = price_logs_df.copy()
        ts_col = self.context.config.get('ts_col', 'ts')
        window_size = self.context.window_size

        # Ensure a datetime column so resampling works.
        if not pd.api.types.is_datetime64_any_dtype(df[ts_col]):
            df[ts_col] = pd.to_datetime(df[ts_col])
        df = df.sort_values([ts_col, 'productId'])

        unique_products = self.context.products['id'].unique()

        # Mean price per product per window.
        df_indexed = df.set_index(ts_col)
        windowed = (
            df_indexed
            .groupby('productId')['price']
            .resample(window_size)
            .mean()
            .reset_index()
        )

        # Forward fill windows with no log entry (carry last known mean).
        windowed = windowed.sort_values([ts_col, 'productId'])
        windowed['price'] = windowed.groupby('productId')['price'].ffill()
        windowed = windowed.dropna(subset=['price'])

        # Split the raw, time-sorted history per product once, instead of
        # filtering the whole frame for every (window, product) pair.
        raw_history = {
            pid: rows['price'] for pid, rows in df_indexed.groupby('productId')
        }

        chunks = []
        for window_start, group in windowed.groupby(ts_col):
            price_vector = group[['productId', 'price']].copy()
            present = set(price_vector['productId'])

            # Collect carried-forward rows and concatenate once per window.
            carried = []
            for pid in unique_products:
                if pid in present or pid not in raw_history:
                    continue
                history = raw_history[pid]
                earlier = history[history.index < window_start]
                if not earlier.empty:
                    carried.append({'productId': pid, 'price': earlier.iloc[-1]})

            if carried:
                price_vector = pd.concat(
                    [price_vector, pd.DataFrame(carried)], ignore_index=True
                )

            if not price_vector.empty:
                chunks.append({
                    'window_start': window_start,
                    'window_end': window_start + pd.Timedelta(window_size),
                    'price_vector': price_vector
                })
        return chunks
class ComputeElasticityStep(BaseContextStep):
    """
    Estimate price elasticity of demand per product.

    Input: (demand_chunks, price_chunks)
    Output: elasticity_df [productId, elasticity, std_error, n_obs]

    Config keys read from the context: ``elasticity_method`` (``'point'`` or
    ``'arc'``, default ``'point'``) and ``min_observations`` (default 2).
    """

    _RESULT_COLUMNS = ['productId', 'elasticity', 'std_error', 'n_obs']

    def transform(self, chunk_tuple: tuple):
        """
        Compute one elasticity row per catalog product.

        Products with fewer than ``min_observations`` aligned windows, and
        catalog products never observed at all, get elasticity 0.0.

        Raises:
            ValueError: if ``elasticity_method`` is neither 'point' nor 'arc'
                (only raised once a product has enough valid observations).
        """
        demand_chunks, price_chunks = chunk_tuple
        method = self.context.config.get('elasticity_method', 'point')
        min_obs = self.context.config.get('min_observations', 2)
        all_product_ids = self.context.products['id'].unique()

        # Pair demand and price windows that share a window_start.
        aligned = self._align_chunks(demand_chunks, price_chunks)
        if not aligned:
            return self._zero_frame(all_product_ids)

        product_series = self._build_timeseries(aligned)

        elasticities = []
        for pid, series in product_series.items():
            if len(series) < min_obs:
                # Too few windows to estimate anything.
                estimate = {'value': 0.0, 'std_error': 0.0}
            else:
                estimate = self._compute_elasticity(series, method)
            elasticities.append({
                'productId': pid,
                'elasticity': estimate['value'],
                'std_error': estimate.get('std_error', 0.0),
                'n_obs': len(series)
            })

        # Catalog products that never appeared get a zero row.
        observed_pids = {row['productId'] for row in elasticities}
        missing_pids = [p for p in all_product_ids if p not in observed_pids]

        # Build from explicit columns: aligned windows can still merge to zero
        # rows, and a column-less empty frame would break the lookup below.
        frames = []
        if elasticities:
            frames.append(pd.DataFrame(elasticities, columns=self._RESULT_COLUMNS))
        if missing_pids:
            frames.append(self._zero_frame(missing_pids))
        if not frames:
            return pd.DataFrame(columns=self._RESULT_COLUMNS)
        return pd.concat(frames, ignore_index=True)

    @staticmethod
    def _zero_frame(product_ids):
        """Result rows with zero elasticity and zero observations for *product_ids*."""
        return pd.DataFrame({
            'productId': product_ids,
            'elasticity': 0.0,
            'std_error': 0.0,
            'n_obs': 0
        })

    def _align_chunks(self, demand_chunks: List[Dict], price_chunks: List[Dict]):
        """Pair each demand chunk with the price chunk sharing its window_start."""
        price_lookup = {c['window_start']: c for c in price_chunks}
        return [
            {
                'window_start': dc['window_start'],
                'window_end': dc['window_end'],
                'demand': dc['demand_vector'],
                'prices': price_lookup[dc['window_start']]['price_vector']
            }
            for dc in demand_chunks
            if dc['window_start'] in price_lookup
        ]

    def _build_timeseries(self, aligned: List[Dict]):
        """Build a list of {timestamp, price, quantity} observations per product."""
        series_by_product = {}
        for chunk in aligned:
            merged = chunk['demand'].merge(chunk['prices'], on='productId', how='inner')
            # Iterate columns rather than rows: iterrows() coerces each row to
            # one dtype, which turns integer product IDs into floats.
            for pid, price, quantity in zip(
                merged['productId'], merged['price'], merged['demand_score']
            ):
                series_by_product.setdefault(pid, []).append({
                    'timestamp': chunk['window_start'],
                    'price': price,
                    'quantity': quantity
                })
        return series_by_product

    def _compute_elasticity(self, series: List[Dict], method: str):
        """
        Compute point or arc elasticity for one product's observations.

        Observations with a non-positive price or quantity are dropped first;
        fewer than two remaining observations yield a zero estimate.

        Raises:
            ValueError: for an unknown *method*.
        """
        prices = np.array([s['price'] for s in series])
        quantities = np.array([s['quantity'] for s in series])

        # Logs and ratios need strictly positive values.
        valid = (prices > 0) & (quantities > 0)
        if valid.sum() < 2:
            return {'value': 0.0, 'std_error': 0.0}
        prices = prices[valid]
        quantities = quantities[valid]

        if method == 'point':
            return self._point_elasticity(prices, quantities)
        elif method == 'arc':
            return self._arc_elasticity(prices, quantities)
        else:
            raise ValueError(f"Unknown elasticity method: {method}")

    def _point_elasticity(self, prices: np.ndarray, quantities: np.ndarray):
        """Point elasticity via log-log regression: log(Q) = a + b*log(P), elasticity = b."""
        n_points = len(prices)
        if n_points < 2:
            return {'value': 0.0, 'std_error': 0.0}

        log_p = np.log(prices)
        log_q = np.log(quantities)

        # Constant price: the slope is undefined. Compare the extremes rather
        # than testing std == 0, which rounding in the mean can make non-zero.
        if log_p.max() == log_p.min():
            return {'value': 0.0, 'std_error': 0.0}

        dev_p = log_p - log_p.mean()
        dev_q = log_q - log_q.mean()

        # OLS slope = Sxy / Sxx. Both sums use the same normalisation; mixing
        # np.cov (ddof=1) with np.var (ddof=0) would scale it by n/(n-1).
        sxx = (dev_p ** 2).sum()
        sxy = (dev_p * dev_q).sum()
        b = sxy / sxx

        # Standard error of the slope needs at least one residual degree of freedom.
        if n_points > 2:
            residuals = dev_q - b * dev_p
            mse = (residuals ** 2).sum() / (n_points - 2)
            se_b = np.sqrt(mse / sxx)
        else:
            se_b = 0.0
        return {'value': b, 'std_error': se_b}

    def _arc_elasticity(self, prices: np.ndarray, quantities: np.ndarray):
        """Arc elasticity: mean of midpoint elasticities between consecutive observations."""
        elasticities = []
        for (p1, p2), (q1, q2) in zip(
            zip(prices[:-1], prices[1:]), zip(quantities[:-1], quantities[1:])
        ):
            p_avg = (p1 + p2) / 2
            q_avg = (q1 + q2) / 2
            if p_avg == 0 or q_avg == 0:
                continue
            delta_p = p2 - p1
            # No price move between the two periods: the ratio is undefined.
            if delta_p == 0:
                continue
            delta_q = q2 - q1
            elasticities.append((delta_q / q_avg) / (delta_p / p_avg))

        if not elasticities:
            return {'value': 0.0, 'std_error': 0.0}
        return {
            'value': np.mean(elasticities),
            'std_error': np.std(elasticities) / np.sqrt(len(elasticities))
        }

View File

@@ -0,0 +1,46 @@
import pandas as pd
from .base import BaseContextStep
class FetchInteractionsStep(BaseContextStep):
    """Load raw interaction events from the ``user-interactions`` Kafka topic."""

    def transform(self, X=None):
        """
        Fetch the events and flatten their ``metadata`` column.

        Nested metadata keys become ``metadata_<key>`` columns, rows without an
        ``eventName`` are dropped, and ``metadata_dateIndex`` (when present) is
        copied to a nullable-integer ``dateIndex`` column. *X* is ignored.
        """
        interactions = self.context.provider.fetch_kafka_topic('user-interactions')
        if interactions.empty:
            return interactions

        # Explode the metadata JSON column into flat, prefixed columns.
        if 'metadata' in interactions.columns:
            raw_metadata = interactions.pop('metadata')
            flattened = pd.json_normalize(raw_metadata, sep='.').add_prefix('metadata_')
            interactions = interactions.join(flattened)

        interactions = interactions.dropna(subset=['eventName'])

        # Remap dateIndex if present.
        if 'metadata_dateIndex' in interactions.columns:
            interactions['dateIndex'] = interactions['metadata_dateIndex'].astype('Int64')
        return interactions
class FetchPriceLogsStep(BaseContextStep):
    """Load price log records from the ``price-logs`` Kafka topic."""

    def transform(self, X=None):
        """Return the topic contents as delivered by the provider; *X* is ignored."""
        provider = self.context.provider
        price_logs = provider.fetch_kafka_topic('price-logs')
        return price_logs
class FetchExperimentsStep(BaseContextStep):
    """Look up experiment metadata for the experiments referenced by interactions."""

    def transform(self, interactions_df: pd.DataFrame):
        """
        Return experiment metadata for the distinct, non-null ``experimentId``
        values in *interactions_df*; an empty DataFrame when there are none.
        """
        has_column = (
            not interactions_df.empty
            and 'experimentId' in interactions_df.columns
        )
        if not has_column:
            return pd.DataFrame()

        experiment_ids = interactions_df['experimentId'].dropna().unique().tolist()
        if experiment_ids:
            return self.context.provider.fetch_experiments(experiment_ids)
        return pd.DataFrame()

View File

@@ -0,0 +1,34 @@
import pandas as pd
from .base import BaseContextStep
class JoinExperimentsStep(BaseContextStep):
    """Attach experiment metadata to interaction rows."""

    def transform(self, data: tuple):
        """
        Left-join experiments onto interactions by ``experimentId``.

        Args:
            data: (interactions_df, experiments_df)
        Returns:
            The merged interactions dataframe, or ``interactions_df`` itself
            when there are no experiments to join.
        """
        interactions_df, experiments_df = data
        if experiments_df.empty:
            return interactions_df

        # Flatten the nested task field, keeping rows lined up by index.
        if 'task' in experiments_df.columns:
            has_task = experiments_df['task'].notnull()
            if has_task.any():
                flat_tasks = pd.json_normalize(experiments_df['task'].dropna())
                flat_tasks.index = experiments_df[has_task].index
                without_task = experiments_df.drop('task', axis=1)
                experiments_df = without_task.join(flat_tasks, rsuffix='_task')

        # Rename to the join key and to exp_-prefixed column names.
        column_names = {
            'id': 'experimentId',
            'subject_name': 'exp_subject',
            'xp_human_only': 'exp_human_only',
            'xp_market_mode': 'exp_market_mode',
            'xp_task_id': 'exp_task_id'
        }
        experiments_df = experiments_df.rename(columns=column_names)

        return interactions_df.merge(experiments_df, on='experimentId', how='left')

View File

@@ -0,0 +1,89 @@
import numpy as np
import pandas as pd
from .base import BaseContextStep
from ..pricing import ElasticityBasedPricingFunction
class StateSpace:
    """Bundle of demand, prices and session features handed to pricing functions."""

    def __init__(self,
                 demand: np.ndarray,
                 prices: np.ndarray,
                 session_features: pd.DataFrame = None):
        # A missing feature table is represented by an empty frame, never None.
        if session_features is None:
            session_features = pd.DataFrame()
        self.session_features = session_features
        self.prices = prices
        self.demand = demand
class BuildStateSpaceStep(BaseContextStep):
    """
    Assemble a ``StateSpace`` from the product catalog and elasticities.

    Input: elasticity_df
    Output: StateSpace instance
    """

    def transform(self, elasticity_df: pd.DataFrame):
        """
        Build the state for every catalog product.

        Base prices come from each product's ``metadata`` dict (0 when the
        key or the dict is missing); elasticities default to 0.0 for products
        absent from *elasticity_df*.
        """
        def base_price_of(meta):
            # Non-dict metadata (e.g. null) carries no price.
            if isinstance(meta, dict):
                return meta.get('base_price', 0)
            return 0

        catalog = self.context.products.copy()
        if 'metadata' in catalog.columns:
            catalog['base_price'] = catalog['metadata'].apply(base_price_of)
        else:
            catalog['base_price'] = 0

        price_table = catalog[['id', 'base_price']].rename(columns={'id': 'productId'})
        merged = price_table.merge(
            elasticity_df[['productId', 'elasticity']],
            on='productId',
            how='left'
        )
        merged = merged.fillna({'elasticity': 0.0, 'base_price': 0.0})

        # NOTE(review): the ``demand`` slot is filled with elasticity values,
        # as in the original -- confirm this is what pricing functions expect.
        return StateSpace(
            demand=merged['elasticity'].values,
            prices=merged['base_price'].values,
            session_features=pd.DataFrame()
        )
class FitPricingFunctionStep(BaseContextStep):
    """
    Instantiate and fit a pricing function on elasticity data.

    Input: elasticity_df
    Output: fitted pricing function instance
    """

    def transform(self, elasticity_df: pd.DataFrame):
        """
        Build the pricing function and fit it.

        The class is taken from the ``pricing_function_class`` config key
        (default ``ElasticityBasedPricingFunction``) and constructed with the
        keyword arguments under ``pricing_function_params`` (default none).
        """
        # Use the class already imported at module level. A second,
        # function-local ``from pricing import ...`` resolves through sys.path
        # and can load the module again under a different name.
        config = self.context.config
        pricing_class = config.get('pricing_function_class', ElasticityBasedPricingFunction)
        pricing_params = config.get('pricing_function_params', {})

        pricer = pricing_class(**pricing_params)
        pricer.fit(elasticity_df)
        return pricer
class PredictPricesStep(BaseContextStep):
    """
    Produce predicted prices from a fitted pricing function.

    Input: (pricer, state_space)
    Output: prices_df [productId, predicted_price]
    """

    def transform(self, data: tuple):
        """Call ``pricer.transform`` for all catalog product IDs and tabulate the result."""
        pricer, state_space = data
        product_ids = self.context.products['id'].values
        return pd.DataFrame({
            'productId': product_ids,
            'predicted_price': pricer.transform(state_space, product_ids),
        })