Catchup airline (#31)

* chore: update provider and pricing snitch with agnostic system

* cloning pipelines per mode instance

* updating airline hero section

* fix: must keep airflow secretkey

* fix: fixture update to hotel not shop

* chore: refactored to factory design pattern of pipelines

* chore: clean up definition of composite class of providers
This commit is contained in:
Daniel Alves Rösel
2025-12-11 21:56:12 +01:00
committed by GitHub
parent d45b344264
commit ef98141ca8
10 changed files with 384 additions and 55 deletions

View File

@@ -123,6 +123,7 @@ services:
- AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- AIRFLOW__WEBSERVER__EXPOSE_CONFIG=true
- AIRFLOW__WEBSERVER__SECRET_KEY=${AIRFLOW_SECRET_KEY}
- KAFKA_HOST=kafka
- KAFKA_PORT=29092
- BACKEND_URL=http://backend:5000
@@ -158,6 +159,7 @@ services:
- AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true
- AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- AIRFLOW__WEBSERVER__SECRET_KEY=${AIRFLOW_SECRET_KEY}
- KAFKA_HOST=kafka
- KAFKA_PORT=29092
- BACKEND_URL=http://backend:5000

View File

@@ -0,0 +1,210 @@
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from datetime import timedelta
import pandas as pd
import logging
import sys
import pickle
sys.path.insert(0, '/opt/airflow')
from procesing.context import PipelineContext
from procesing.providers import SupabaseProvider, BackendAPIProvider
from procesing.steps import (
FetchInteractionsStep,
FetchPriceLogsStep,
ComputeDemandStep,
AggregatePriceLogsStep,
JoinProductFeaturesStep,
)
from procesing.pricers.simple import SimpleSurgePricer
DEFAULT_ARGS = {
'owner': 'phantom-research',
'depends_on_past': False,
'email_on_failure': False,
'email_on_retry': False,
'retries': 2,
'retry_delay': timedelta(minutes=5),
}
class CompositeProvider(SupabaseProvider, BackendAPIProvider):
def __init__(self):
SupabaseProvider.__init__(self)
BackendAPIProvider.__init__(self)
def _get_provider():
return CompositeProvider()
def _make_task_callables(store_mode: str):
"""Generate task callables bound to a specific store_mode."""
def get_context(**kwargs):
return PipelineContext(provider=_get_provider(), store_mode=store_mode)
def fetch_interactions(**kwargs):
ctx = get_context(**kwargs)
df = FetchInteractionsStep(ctx).transform(None)
kwargs['ti'].xcom_push(key='interactions_raw', value=pickle.dumps(df))
logging.info(f"[{store_mode}] Fetched {len(df)} interaction records")
return len(df)
def fetch_price_logs(**kwargs):
ctx = get_context(**kwargs)
df = FetchPriceLogsStep(ctx).transform(None)
kwargs['ti'].xcom_push(key='price_logs_raw', value=pickle.dumps(df))
logging.info(f"[{store_mode}] Fetched {len(df)} price records")
return len(df)
def compute_demand(**kwargs):
ti = kwargs['ti']
df = pickle.loads(ti.xcom_pull(key='interactions_raw'))
ctx = get_context(**kwargs)
demand_df = ComputeDemandStep(ctx).transform(df)
ti.xcom_push(key='demand_data', value=pickle.dumps(demand_df))
logging.info(f"[{store_mode}] Computed demand for {len(demand_df)} products")
return len(demand_df)
def aggregate_price_logs(**kwargs):
ti = kwargs['ti']
df = pickle.loads(ti.xcom_pull(key='price_logs_raw'))
ctx = get_context(**kwargs)
price_df = AggregatePriceLogsStep(ctx).transform(df)
ti.xcom_push(key='price_data', value=pickle.dumps(price_df))
logging.info(f"[{store_mode}] Aggregated price logs for {len(price_df)} products")
return len(price_df)
def join_product_features(**kwargs):
ti = kwargs['ti']
demand_df = pickle.loads(ti.xcom_pull(key='demand_data'))
price_df = pickle.loads(ti.xcom_pull(key='price_data'))
ctx = get_context(**kwargs)
joined_df = JoinProductFeaturesStep(ctx).transform((demand_df, price_df))
ti.xcom_push(key='product_features', value=pickle.dumps(joined_df))
logging.info(f"[{store_mode}] Joined features for {len(joined_df)} products")
return len(joined_df)
def apply_surge_pricing(**kwargs):
ti = kwargs['ti']
product_features = pickle.loads(ti.xcom_pull(key='product_features'))
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
data = product_features.rename(columns={'demand_score': 'demand'})
surge_pricer = SimpleSurgePricer(
high_threshold=dag_conf.get('high_threshold', 10),
low_threshold=dag_conf.get('low_threshold', 2),
surge_multiplier=dag_conf.get('surge_multiplier', 1.2),
discount_multiplier=dag_conf.get('discount_multiplier', 0.9)
)
surge_pricer.fit(data)
data['optimal_price'] = surge_pricer.predict()
prices_df = data[['productId', 'price', 'base_price', 'optimal_price', 'demand']].rename(columns={
'price': 'current_price', 'demand': 'demand_score'
})
ti.xcom_push(key='predicted_prices', value=pickle.dumps(prices_df))
logging.info(f"[{store_mode}] Applied surge pricing for {len(prices_df)} products")
return len(prices_df)
def publish_results(**kwargs):
ti = kwargs['ti']
prices_df = pickle.loads(ti.xcom_pull(key='predicted_prices'))
from lib.model_registry import ModelRegistry
registry = ModelRegistry()
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
metadata = {
'timestamp': pd.Timestamp.now().isoformat(),
'store_mode': store_mode,
'dag_run_id': kwargs['dag_run'].run_id if kwargs.get('dag_run') else 'manual',
'pricing_method': 'surge',
'high_threshold': dag_conf.get('high_threshold', 10),
'low_threshold': dag_conf.get('low_threshold', 2),
'surge_multiplier': dag_conf.get('surge_multiplier', 1.2),
'discount_multiplier': dag_conf.get('discount_multiplier', 0.9)
}
registry.publish_prices(prices_df, model_name=f'{store_mode}_latest', metadata=metadata)
logging.info(f"[{store_mode}] Published surge pricing for {len(prices_df)} products")
return {
'n_products': len(prices_df),
'registry_status': 'success',
'store_mode': store_mode,
'mean_demand': float(prices_df['demand_score'].mean()) if 'demand_score' in prices_df.columns else None
}
return {
'fetch_interactions': fetch_interactions,
'fetch_price_logs': fetch_price_logs,
'compute_demand': compute_demand,
'aggregate_price_logs': aggregate_price_logs,
'join_product_features': join_product_features,
'apply_surge_pricing': apply_surge_pricing,
'publish_results': publish_results,
}
def create_surge_pricing_dag(store_mode: str) -> DAG:
"""Factory: generates a surge pricing DAG for a given store_mode."""
callables = _make_task_callables(store_mode)
dag = DAG(
f'surge_pricing_{store_mode}',
default_args=DEFAULT_ARGS,
description=f'Surge pricing pipeline for {store_mode} store mode',
schedule_interval='*/15 * * * *',
start_date=days_ago(1),
catchup=False,
max_active_runs=1,
tags=['pricing', 'surge', 'research', store_mode],
)
with dag:
t_fetch_interactions = PythonOperator(
task_id='fetch_interactions',
python_callable=callables['fetch_interactions'],
provide_context=True,
)
t_fetch_price_logs = PythonOperator(
task_id='fetch_price_logs',
python_callable=callables['fetch_price_logs'],
provide_context=True,
)
t_compute_demand = PythonOperator(
task_id='compute_demand',
python_callable=callables['compute_demand'],
provide_context=True,
)
t_aggregate_prices = PythonOperator(
task_id='aggregate_price_logs',
python_callable=callables['aggregate_price_logs'],
provide_context=True,
)
t_join_features = PythonOperator(
task_id='join_product_features',
python_callable=callables['join_product_features'],
provide_context=True,
)
t_surge_pricing = PythonOperator(
task_id='apply_surge_pricing',
python_callable=callables['apply_surge_pricing'],
provide_context=True,
)
t_publish = PythonOperator(
task_id='publish_results',
python_callable=callables['publish_results'],
provide_context=True,
)
t_fetch_interactions >> t_compute_demand
t_fetch_price_logs >> t_aggregate_prices
[t_compute_demand, t_aggregate_prices] >> t_join_features >> t_surge_pricing >> t_publish
return dag
# instantiate DAGs for Airflow to discover
dag_airline = create_surge_pricing_dag('airline')
dag_hotel = create_surge_pricing_dag('hotel')

View File

@@ -131,7 +131,7 @@ if __name__ == '__main__':
# example run
context = PipelineContext(
provider=HistoricalProvider(),
store_mode='hotel',
store_mode='airline',
)
product_features, prices = full_pipeline(context)

View File

@@ -18,10 +18,17 @@ class SupabaseProvider(DataProvider):
self.supabase: Client = create_client(self.supabase_url, self.supabase_key)
def fetch_products(self, store_mode: str) -> pd.DataFrame:
resp = self.supabase.table(f'{store_mode}_products').select(
"id, room_type, date_index, metadata, availability"
).execute()
return pd.DataFrame(resp.data) if resp.data else pd.DataFrame()
# hotel uses room_type, airline uses flight_type; select all and normalize
resp = self.supabase.table(f'{store_mode}_products').select("*").execute()
if not resp.data:
return pd.DataFrame()
df = pd.DataFrame(resp.data)
# normalize type column: hotel has room_type, airline has flight_type
if 'room_type' in df.columns:
df['product_type'] = df['room_type']
elif 'flight_type' in df.columns:
df['product_type'] = df['flight_type']
return df
def fetch_experiments(self, experiment_ids: List[str]) -> pd.DataFrame:
if not experiment_ids:

View File

@@ -2,7 +2,7 @@ import pandas as pd
from procesing.steps.base import BaseContextStep
class FetchInteractionsStep(BaseContextStep):
"""Fetch raw interaction data from Kafka topic with optional time filtering"""
"""Fetch raw interaction data from Kafka topic with optional time and store_mode filtering"""
def __init__(self, context, lookback: str = None):
super().__init__(context)
@@ -24,6 +24,10 @@ class FetchInteractionsStep(BaseContextStep):
# drop all where page has /admin/
df = df[~df['page'].str.contains('/admin/', na=False)]
# filter by store_mode from context
if 'storeMode' in df.columns:
df = df[df['storeMode'] == self.context.store_mode]
# Remap dateIndex if present
if 'metadata_dateIndex' in df.columns:
df['dateIndex'] = df['metadata_dateIndex'].astype('Int64')
@@ -38,7 +42,7 @@ class FetchInteractionsStep(BaseContextStep):
class FetchPriceLogsStep(BaseContextStep):
"""Fetch price log data from Kafka topic with optional time filtering"""
"""Fetch price log data from Kafka topic with optional time and store_mode filtering"""
def __init__(self, context, lookback: str = None):
super().__init__(context)
@@ -50,6 +54,10 @@ class FetchPriceLogsStep(BaseContextStep):
if df.empty:
return df
# filter by store_mode from context
if 'storeMode' in df.columns:
df = df[df['storeMode'] == self.context.store_mode]
# Apply time filtering if lookback specified
if self.lookback and 'ts' in df.columns:
df['ts'] = pd.to_datetime(df['ts'])

View File

@@ -144,7 +144,7 @@ def mock_price_logs_raw_kafka():
'price': 162.47,
'sessionId': 'd423ce8a-77aa-4c9a-94d4-d1adddcc3472',
'experimentId': '53aefd07-f66a-4d7f-ba8b-7ea1fc562d35',
'storeMode': 'shop',
'storeMode': 'hotel',
'ts': '2025-11-25T21:05:57.967Z'
}
}
@@ -157,7 +157,7 @@ def mock_price_logs_raw_kafka():
'price': 743.49,
'sessionId': 'd423ce8a-77aa-4c9a-94d4-d1adddcc3472',
'experimentId': '53aefd07-f66a-4d7f-ba8b-7ea1fc562d35',
'storeMode': 'shop',
'storeMode': 'hotel',
'ts': '2025-11-25T21:05:57.993Z'
}
}
@@ -170,7 +170,7 @@ def mock_price_logs_raw_kafka():
'price': 163.87,
'sessionId': 'd423ce8a-77aa-4c9a-94d4-d1adddcc3472',
'experimentId': '53aefd07-f66a-4d7f-ba8b-7ea1fc562d35',
'storeMode': 'shop',
'storeMode': 'hotel',
'ts': '2025-11-25T21:05:58.009Z'
}
}
@@ -183,7 +183,7 @@ def mock_price_logs_raw_kafka():
'price': 397.46,
'sessionId': 'd423ce8a-77aa-4c9a-94d4-d1adddcc3472',
'experimentId': '53aefd07-f66a-4d7f-ba8b-7ea1fc562d35',
'storeMode': 'shop',
'storeMode': 'hotel',
'ts': '2025-11-25T21:05:58.049Z'
}
}
@@ -196,7 +196,7 @@ def mock_price_logs_raw_kafka():
'price': 401.66,
'sessionId': 'd423ce8a-77aa-4c9a-94d4-d1adddcc3472',
'experimentId': '53aefd07-f66a-4d7f-ba8b-7ea1fc562d35',
'storeMode': 'shop',
'storeMode': 'hotel',
'ts': '2025-11-25T21:06:08.864Z'
}
}
@@ -222,7 +222,7 @@ def mock_experiments():
'created_at': pd.to_datetime(['2025-11-25T20:00:00Z', '2025-11-26T10:00:00Z']),
'subject_name': ['Session A', 'Session B'],
'xp_human_only': [True, False],
'xp_market_mode': ['hotel', 'shop'],
'xp_market_mode': ['hotel', 'airline'],
'xp_task_id': [None, None]
})

View File

@@ -2,10 +2,20 @@
import { useState, FormEvent } from 'react';
import { useRouter } from 'next/navigation';
import { Button, Label, Input, DateInput, RadioGroup, Dropdown, DropdownCounter } from '@/components/ui';
import { Button, Label, DateInput, Dropdown, DropdownCounter, SelectDropdown, SelectOption } from '@/components/ui';
import { dateToDaysFromToday } from '@/lib/airline-utils';
type TripType = 'roundtrip' | 'oneway' | 'multicity';
const CITIES: SelectOption[] = [
{ value: 'JFK', label: 'New York (JFK)', sublabel: 'John F. Kennedy International' },
{ value: 'LAX', label: 'Los Angeles (LAX)', sublabel: 'Los Angeles International' },
{ value: 'ORD', label: 'Chicago (ORD)', sublabel: "O'Hare International" },
{ value: 'MIA', label: 'Miami (MIA)', sublabel: 'Miami International' },
{ value: 'SFO', label: 'San Francisco (SFO)', sublabel: 'San Francisco International' },
{ value: 'SEA', label: 'Seattle (SEA)', sublabel: 'Seattle-Tacoma International' },
{ value: 'ATL', label: 'Atlanta (ATL)', sublabel: 'Hartsfield-Jackson International' },
{ value: 'DFW', label: 'Dallas (DFW)', sublabel: 'Dallas/Fort Worth International' },
];
const PlaneIcon = () => (
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
@@ -22,11 +32,9 @@ const LocationIcon = () => (
export default function AirlineHero() {
const router = useRouter();
const [tripType, setTripType] = useState<TripType>('roundtrip');
const [origin, setOrigin] = useState('');
const [destination, setDestination] = useState('');
const [departDate, setDepartDate] = useState('');
const [returnDate, setReturnDate] = useState('');
const [passengers, setPassengers] = useState({ adults: 1, children: 0, infants: 0 });
const handleSearch = (e: FormEvent) => {
@@ -40,8 +48,6 @@ export default function AirlineHero() {
if (origin) params.set('origin', origin);
if (destination) params.set('destination', destination);
if (tripType !== 'roundtrip') params.set('tripType', tripType);
if (returnDate && tripType === 'roundtrip') params.set('returnDate', returnDate);
params.set('adults', passengers.adults.toString());
params.set('children', passengers.children.toString());
@@ -66,28 +72,15 @@ export default function AirlineHero() {
<div className="search-form">
<form onSubmit={handleSearch}>
<div className="mb-6">
<RadioGroup
name="tripType"
value={tripType}
onChange={setTripType}
options={[
{ value: 'roundtrip', label: 'Round-trip' },
{ value: 'oneway', label: 'One-way' },
{ value: 'multicity', label: 'Multi-city' },
]}
/>
</div>
<div className="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-4 gap-4">
<div className="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-4">
<div>
<Label htmlFor="origin">From</Label>
<Input
type="text"
<SelectDropdown
id="origin"
value={origin}
onChange={(e) => setOrigin(e.target.value)}
placeholder="Airport or city"
onChange={setOrigin}
options={CITIES}
placeholder="Select origin"
icon={<PlaneIcon />}
required
/>
@@ -95,12 +88,12 @@ export default function AirlineHero() {
<div>
<Label htmlFor="destination">To</Label>
<Input
type="text"
<SelectDropdown
id="destination"
value={destination}
onChange={(e) => setDestination(e.target.value)}
placeholder="Airport or city"
onChange={setDestination}
options={CITIES}
placeholder="Select destination"
icon={<LocationIcon />}
required
/>
@@ -115,20 +108,6 @@ export default function AirlineHero() {
required
/>
</div>
<div>
<Label htmlFor="returnDate">Return</Label>
{tripType === 'roundtrip' ? (
<DateInput
id="returnDate"
value={returnDate}
onChange={(e) => setReturnDate(e.target.value)}
required
/>
) : (
<DateInput id="returnDate" disabled />
)}
</div>
</div>
<div className="grid grid-cols-4 sm:grid-cols-3 lg:grid-cols-4 gap-4 mt-4">

View File

@@ -0,0 +1,119 @@
'use client';
import { useState, useRef, useEffect, ReactNode } from 'react';
export interface SelectOption {
value: string;
label: string;
sublabel?: string;
}
interface SelectDropdownProps {
value: string;
onChange: (value: string) => void;
options: SelectOption[];
placeholder?: string;
icon?: ReactNode;
required?: boolean;
id?: string;
}
export default function SelectDropdown({
value,
onChange,
options,
placeholder = 'Select...',
icon,
required,
id,
}: SelectDropdownProps) {
const [open, setOpen] = useState(false);
const [filter, setFilter] = useState('');
const ref = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
useEffect(() => {
const handleClick = (e: MouseEvent) => {
if (ref.current && !ref.current.contains(e.target as Node)) {
setOpen(false);
setFilter('');
}
};
document.addEventListener('mousedown', handleClick);
return () => document.removeEventListener('mousedown', handleClick);
}, []);
const selectedOption = options.find((o) => o.value === value);
const filtered = options.filter(
(o) =>
o.label.toLowerCase().includes(filter.toLowerCase()) ||
o.value.toLowerCase().includes(filter.toLowerCase()) ||
o.sublabel?.toLowerCase().includes(filter.toLowerCase())
);
const handleSelect = (opt: SelectOption) => {
onChange(opt.value);
setOpen(false);
setFilter('');
};
return (
<div className="relative" ref={ref}>
<div
className="input-field flex items-center gap-2 cursor-pointer box-border"
onClick={() => {
setOpen(true);
setTimeout(() => inputRef.current?.focus(), 0);
}}
>
{icon && <span className="text-[var(--text-secondary)]">{icon}</span>}
{open ? (
<input
ref={inputRef}
type="text"
id={id}
value={filter}
onChange={(e) => setFilter(e.target.value)}
placeholder={placeholder}
className="flex-1 bg-transparent outline-none text-sm text-[var(--text-primary)]"
/>
) : (
<span className={`flex-1 text-sm ${value ? 'text-[var(--text-primary)]' : 'text-[var(--text-secondary)]'}`}>
{selectedOption ? selectedOption.label : placeholder}
</span>
)}
<svg
className={`w-4 h-4 text-[var(--text-secondary)] transition-transform ${open ? 'rotate-180' : ''}`}
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M19 9l-7 7-7-7" />
</svg>
</div>
{open && (
<div className="absolute z-20 mt-1 w-full bg-[var(--bg-primary)] border-2 border-[var(--accent-primary)] rounded-md shadow-lg max-h-60 overflow-y-auto">
{filtered.length === 0 ? (
<div className="px-4 py-3 text-sm text-[var(--text-secondary)]">No results</div>
) : (
filtered.map((opt) => (
<div
key={opt.value}
onClick={() => handleSelect(opt)}
className={`px-4 py-2 cursor-pointer transition-colors hover:bg-[var(--accent-primary-light)] ${
opt.value === value ? 'bg-[var(--accent-primary-light)]' : ''
}`}
>
<div className="text-sm font-medium text-[var(--text-primary)]">{opt.label}</div>
{opt.sublabel && <div className="text-xs text-[var(--text-secondary)]">{opt.sublabel}</div>}
</div>
))
)}
</div>
)}
{required && !value && (
<input type="text" required className="sr-only" tabIndex={-1} value="" onChange={() => {}} />
)}
</div>
);
}

View File

@@ -5,3 +5,5 @@ export { default as DateInput } from './DateInput';
export { default as RadioGroup } from './RadioGroup';
export { default as Dropdown, DropdownCounter } from './Dropdown';
export { default as Navigation } from './Navigation';
export { default as SelectDropdown } from './SelectDropdown';
export type { SelectOption } from './SelectDropdown';

View File

@@ -278,6 +278,8 @@
padding: 12px;
transition: border-color 0.2s ease;
width: 100%;
min-height: 48px;
box-sizing: border-box;
}
[data-mode="airline"] .input-field:focus {