chore: enables cross comm pickling with fully e2e pipeline compilation

This commit is contained in:
2025-11-28 14:05:39 +01:00
parent 505c4fcd42
commit 33c20ec715
2 changed files with 22 additions and 28 deletions

View File

@@ -96,6 +96,7 @@ services:
- AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@postgres/airflow - AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@postgres/airflow
- AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY:-fb4E5zWb8hh7WKN7tXUkWP0r5nTcN1nKZGh1h0N3x6Q=} - AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY:-fb4E5zWb8hh7WKN7tXUkWP0r5nTcN1nKZGh1h0N3x6Q=}
- AIRFLOW__CORE__LOAD_EXAMPLES=false - AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- _AIRFLOW_DB_MIGRATE=true - _AIRFLOW_DB_MIGRATE=true
- _AIRFLOW_WWW_USER_CREATE=true - _AIRFLOW_WWW_USER_CREATE=true
- _AIRFLOW_WWW_USER_USERNAME=admin - _AIRFLOW_WWW_USER_USERNAME=admin
@@ -126,6 +127,7 @@ services:
- AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY:-fb4E5zWb8hh7WKN7tXUkWP0r5nTcN1nKZGh1h0N3x6Q=} - AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY:-fb4E5zWb8hh7WKN7tXUkWP0r5nTcN1nKZGh1h0N3x6Q=}
- AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true - AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true
- AIRFLOW__CORE__LOAD_EXAMPLES=false - AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- AIRFLOW__WEBSERVER__EXPOSE_CONFIG=true - AIRFLOW__WEBSERVER__EXPOSE_CONFIG=true
- KAFKA_HOST=kafka - KAFKA_HOST=kafka
- KAFKA_PORT=29092 - KAFKA_PORT=29092
@@ -167,6 +169,7 @@ services:
- AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY:-fb4E5zWb8hh7WKN7tXUkWP0r5nTcN1nKZGh1h0N3x6Q=} - AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY:-fb4E5zWb8hh7WKN7tXUkWP0r5nTcN1nKZGh1h0N3x6Q=}
- AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true - AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true
- AIRFLOW__CORE__LOAD_EXAMPLES=false - AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- KAFKA_HOST=kafka - KAFKA_HOST=kafka
- KAFKA_PORT=29092 - KAFKA_PORT=29092
- BACKEND_URL=http://backend:5000 - BACKEND_URL=http://backend:5000

View File

@@ -8,12 +8,12 @@ import sys
import pickle import pickle
import io import io
# add procesing module to path (mounted at /opt/airflow/procesing in container) # add parent dir to path so procesing package can be imported
sys.path.insert(0, '/opt/airflow/procesing') sys.path.insert(0, '/opt/airflow')
from context import PipelineContext from procesing.context import PipelineContext
from providers import SupabaseProvider, BackendAPIProvider from procesing.providers import SupabaseProvider, BackendAPIProvider
from steps import ( from procesing.steps import (
FetchInteractionsStep, FetchInteractionsStep,
FetchPriceLogsStep, FetchPriceLogsStep,
CreatePriceBucketsStep, CreatePriceBucketsStep,
@@ -63,7 +63,7 @@ def fetch_interactions(**kwargs):
step = FetchInteractionsStep(context) step = FetchInteractionsStep(context)
df = step.transform(None) df = step.transform(None)
kwargs['ti'].xcom_push(key='interactions_raw', value=df.to_json()) kwargs['ti'].xcom_push(key='interactions_raw', value=pickle.dumps(df))
logging.info(f"Fetched {len(df)} interaction records") logging.info(f"Fetched {len(df)} interaction records")
return len(df) return len(df)
@@ -73,43 +73,40 @@ def fetch_price_logs(**kwargs):
step = FetchPriceLogsStep(context) step = FetchPriceLogsStep(context)
df = step.transform(None) df = step.transform(None)
kwargs['ti'].xcom_push(key='price_logs_raw', value=df.to_json()) kwargs['ti'].xcom_push(key='price_logs_raw', value=pickle.dumps(df))
logging.info(f"Fetched {len(df)} price records") logging.info(f"Fetched {len(df)} price records")
return len(df) return len(df)
def create_price_buckets(**kwargs): def create_price_buckets(**kwargs):
"""Task: Create price buckets for interactions""" """Task: Create price buckets for interactions"""
ti = kwargs['ti'] ti = kwargs['ti']
interactions_json = ti.xcom_pull(key='interactions_raw') df = pickle.loads(ti.xcom_pull(key='interactions_raw'))
df = pd.read_json(io.StringIO(interactions_json))
context = get_context(**kwargs) context = get_context(**kwargs)
step = CreatePriceBucketsStep(context) step = CreatePriceBucketsStep(context)
df = step.transform(df) df = step.transform(df)
ti.xcom_push(key='interactions_bucketed', value=df.to_json()) ti.xcom_push(key='interactions_bucketed', value=pickle.dumps(df))
logging.info(f"Created price buckets for {len(df)} interactions") logging.info(f"Created price buckets for {len(df)} interactions")
return len(df) return len(df)
def augment_event_names(**kwargs): def augment_event_names(**kwargs):
"""Task: Augment event names with product and price schema""" """Task: Augment event names with product and price schema"""
ti = kwargs['ti'] ti = kwargs['ti']
interactions_json = ti.xcom_pull(key='interactions_bucketed') df = pickle.loads(ti.xcom_pull(key='interactions_bucketed'))
df = pd.read_json(io.StringIO(interactions_json))
context = get_context(**kwargs) context = get_context(**kwargs)
step = AugmentEventNamesStep(context) step = AugmentEventNamesStep(context)
df = step.transform(df) df = step.transform(df)
ti.xcom_push(key='interactions_final', value=df.to_json()) ti.xcom_push(key='interactions_final', value=pickle.dumps(df))
logging.info(f"Augmented event names for {len(df)} interactions") logging.info(f"Augmented event names for {len(df)} interactions")
return len(df) return len(df)
def chunk_interactions(**kwargs): def chunk_interactions(**kwargs):
"""Task: Chunk interactions into time windows""" """Task: Chunk interactions into time windows"""
ti = kwargs['ti'] ti = kwargs['ti']
interactions_json = ti.xcom_pull(key='interactions_final') df = pickle.loads(ti.xcom_pull(key='interactions_final'))
df = pd.read_json(io.StringIO(interactions_json))
context = get_context(**kwargs) context = get_context(**kwargs)
step = ChunkByTimeWindowStep(context) step = ChunkByTimeWindowStep(context)
@@ -135,8 +132,7 @@ def compute_demand(**kwargs):
def aggregate_price_logs(**kwargs): def aggregate_price_logs(**kwargs):
"""Task: Aggregate price logs into time windows (VECTORIZED)""" """Task: Aggregate price logs into time windows (VECTORIZED)"""
ti = kwargs['ti'] ti = kwargs['ti']
price_logs_json = ti.xcom_pull(key='price_logs_raw') df = pickle.loads(ti.xcom_pull(key='price_logs_raw'))
df = pd.read_json(io.StringIO(price_logs_json))
context = get_context(**kwargs) context = get_context(**kwargs)
step = AggregatePriceLogsStep(context) step = AggregatePriceLogsStep(context)
@@ -156,7 +152,7 @@ def compute_elasticity(**kwargs):
step = ComputeElasticityStep(context) step = ComputeElasticityStep(context)
elasticity_df = step.transform((demand_chunks, price_chunks)) elasticity_df = step.transform((demand_chunks, price_chunks))
ti.xcom_push(key='elasticity_results', value=elasticity_df.to_json()) ti.xcom_push(key='elasticity_results', value=pickle.dumps(elasticity_df))
logging.info(f"Computed elasticity for {len(elasticity_df)} products") logging.info(f"Computed elasticity for {len(elasticity_df)} products")
return { return {
@@ -168,8 +164,7 @@ def compute_elasticity(**kwargs):
def build_state_space(**kwargs): def build_state_space(**kwargs):
"""Task: Build state space from elasticity""" """Task: Build state space from elasticity"""
ti = kwargs['ti'] ti = kwargs['ti']
elasticity_json = ti.xcom_pull(key='elasticity_results') elasticity_df = pickle.loads(ti.xcom_pull(key='elasticity_results'))
elasticity_df = pd.read_json(io.StringIO(elasticity_json))
context = get_context(**kwargs) context = get_context(**kwargs)
step = BuildStateSpaceStep(context) step = BuildStateSpaceStep(context)
@@ -182,8 +177,7 @@ def build_state_space(**kwargs):
def fit_pricing_function(**kwargs): def fit_pricing_function(**kwargs):
"""Task: Fit pricing function using elasticity""" """Task: Fit pricing function using elasticity"""
ti = kwargs['ti'] ti = kwargs['ti']
elasticity_json = ti.xcom_pull(key='elasticity_results') elasticity_df = pickle.loads(ti.xcom_pull(key='elasticity_results'))
elasticity_df = pd.read_json(io.StringIO(elasticity_json))
context = get_context(**kwargs) context = get_context(**kwargs)
step = FitPricingFunctionStep(context) step = FitPricingFunctionStep(context)
@@ -203,18 +197,15 @@ def predict_prices(**kwargs):
step = PredictPricesStep(context) step = PredictPricesStep(context)
prices_df = step.transform((pricer, state_space)) prices_df = step.transform((pricer, state_space))
ti.xcom_push(key='predicted_prices', value=prices_df.to_json()) ti.xcom_push(key='predicted_prices', value=pickle.dumps(prices_df))
logging.info(f"Predicted prices for {len(prices_df)} products") logging.info(f"Predicted prices for {len(prices_df)} products")
return len(prices_df) return len(prices_df)
def publish_results(**kwargs): def publish_results(**kwargs):
"""Task: Publish elasticity and pricing results to model registry""" """Task: Publish elasticity and pricing results to model registry"""
ti = kwargs['ti'] ti = kwargs['ti']
elasticity_json = ti.xcom_pull(key='elasticity_results') elasticity_df = pickle.loads(ti.xcom_pull(key='elasticity_results'))
prices_json = ti.xcom_pull(key='predicted_prices') prices_df = pickle.loads(ti.xcom_pull(key='predicted_prices'))
elasticity_df = pd.read_json(io.StringIO(elasticity_json))
prices_df = pd.read_json(io.StringIO(prices_json))
sys.path.insert(0, '/opt/airflow') sys.path.insert(0, '/opt/airflow')
from lib.model_registry import ModelRegistry from lib.model_registry import ModelRegistry