mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: training pipeline + tensorboard
This commit is contained in:
115
experiments/airflow/dags/ml_training_pipeline.py
Normal file
115
experiments/airflow/dags/ml_training_pipeline.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from airflow import DAG, Dataset
|
||||
from airflow.decorators import task
|
||||
from airflow.utils.dates import days_ago
|
||||
from datetime import timedelta
|
||||
import pandas as pd
|
||||
import logging
|
||||
import sys
|
||||
import pickle
|
||||
|
||||
sys.path.insert(0, '/opt/airflow')
|
||||
|
||||
from procesing.context import PipelineContext
|
||||
from procesing.providers import SupabaseProvider, BackendAPIProvider
|
||||
from procesing.steps import (
|
||||
FetchInteractionsStep,
|
||||
ValidateDataStep,
|
||||
ExtractSessionFeaturesStep,
|
||||
JoinLabelsStep,
|
||||
)
|
||||
|
||||
TRAINING_DATASET = Dataset('phantom://ml/training-data')
|
||||
|
||||
DEFAULT_ARGS = {
|
||||
'owner': 'phantom-research',
|
||||
'depends_on_past': False,
|
||||
'email_on_failure': False,
|
||||
'email_on_retry': False,
|
||||
'retries': 2,
|
||||
'retry_delay': timedelta(minutes=5),
|
||||
}
|
||||
|
||||
|
||||
class CompositeProvider(SupabaseProvider, BackendAPIProvider):
|
||||
def __init__(self):
|
||||
SupabaseProvider.__init__(self)
|
||||
BackendAPIProvider.__init__(self)
|
||||
|
||||
|
||||
def _get_context(store_mode: str = 'hotel') -> PipelineContext:
|
||||
return PipelineContext(provider=CompositeProvider(), store_mode=store_mode)
|
||||
|
||||
|
||||
with DAG(
|
||||
'ml_training_pipeline',
|
||||
default_args=DEFAULT_ARGS,
|
||||
description='ML training data pipeline: fetch -> validate -> extract features -> label -> publish',
|
||||
schedule=None,
|
||||
start_date=days_ago(1),
|
||||
catchup=False,
|
||||
max_active_runs=1,
|
||||
tags=['ml', 'training', 'features', 'research'],
|
||||
) as dag:
|
||||
|
||||
@task
|
||||
def fetch_interactions(**kwargs) -> bytes:
|
||||
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||
ctx = _get_context(dag_conf.get('store_mode', 'hotel'))
|
||||
df = FetchInteractionsStep(ctx).transform(None)
|
||||
logging.info(f"Fetched {len(df)} interactions, {df['sessionId'].nunique()} sessions")
|
||||
return pickle.dumps(df)
|
||||
|
||||
@task
|
||||
def validate_data(raw_data: bytes, **kwargs) -> bytes:
|
||||
df = pickle.loads(raw_data)
|
||||
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||
ctx = _get_context(dag_conf.get('store_mode', 'hotel'))
|
||||
validated = ValidateDataStep(ctx).transform(df)
|
||||
report = ctx.get_cached('validation_report') or {}
|
||||
logging.info(f"Validation: {report.get('status')}, {report.get('sessions', 0)} sessions")
|
||||
return pickle.dumps(validated)
|
||||
|
||||
@task
|
||||
def extract_session_features(validated_data: bytes, **kwargs) -> bytes:
|
||||
df = pickle.loads(validated_data)
|
||||
if df.empty:
|
||||
logging.warning("Empty input, skipping feature extraction")
|
||||
return pickle.dumps(pd.DataFrame())
|
||||
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||
ctx = _get_context(dag_conf.get('store_mode', 'hotel'))
|
||||
features = ExtractSessionFeaturesStep(ctx).transform(df)
|
||||
logging.info(f"Extracted {len(features.columns)} features for {len(features)} sessions")
|
||||
return pickle.dumps(features)
|
||||
|
||||
@task
|
||||
def join_labels(features_data: bytes, **kwargs) -> bytes:
|
||||
features_df = pickle.loads(features_data)
|
||||
if features_df.empty:
|
||||
logging.warning("Empty features, skipping label join")
|
||||
return pickle.dumps(pd.DataFrame())
|
||||
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||
ctx = _get_context(dag_conf.get('store_mode', 'hotel'))
|
||||
labeled = JoinLabelsStep(ctx).transform(features_df)
|
||||
n_agents = labeled['is_agent'].sum() if 'is_agent' in labeled.columns else 0
|
||||
logging.info(f"Labeled {len(labeled)} sessions: {n_agents} agents")
|
||||
return pickle.dumps(labeled)
|
||||
|
||||
@task(outlets=[TRAINING_DATASET])
|
||||
def publish_training_data(labeled_data: bytes, **kwargs) -> dict:
|
||||
labeled_df = pickle.loads(labeled_data)
|
||||
if labeled_df.empty:
|
||||
return {'status': 'skipped', 'reason': 'empty_data'}
|
||||
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||
return {
|
||||
'status': 'success',
|
||||
'n_sessions': len(labeled_df),
|
||||
'n_features': len([c for c in labeled_df.columns if c not in ['sessionId', 'experimentId', 'is_agent']]),
|
||||
'store_mode': dag_conf.get('store_mode', 'hotel'),
|
||||
'timestamp': pd.Timestamp.now().isoformat(),
|
||||
}
|
||||
|
||||
raw = fetch_interactions()
|
||||
validated = validate_data(raw)
|
||||
features = extract_session_features(validated)
|
||||
labeled = join_labels(features)
|
||||
publish_training_data(labeled)
|
||||
Reference in New Issue
Block a user