mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
34 initial discriminator of interaction data (#38)
* feat: training pipeline + tensorboard * tesnorboard forgot * chore: ml basic boilerplate * feat: naive architecture as start * eval setup * chore: parquet exporting of data * chore: updating requirements necesary * feat: separating modules and adding training logs paths * Update experiments/ml/train.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix: new path for runs * fix: undoing ai slop code * chore: modules and reqs --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a1916c966c
commit
f2271e368e
@@ -1,4 +1,15 @@
|
|||||||
services:
|
services:
|
||||||
|
|
||||||
|
tensorboard:
|
||||||
|
image: tensorflow/tensorflow:latest
|
||||||
|
container_name: "PHANTOM-tensorboard"
|
||||||
|
ports:
|
||||||
|
- "6006:6006"
|
||||||
|
volumes:
|
||||||
|
- ./experiments/ml/runs:/logs
|
||||||
|
command: tensorboard --logdir=/logs --host=0.0.0.0 --port=6006
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
backend:
|
backend:
|
||||||
container_name: "PHANTOM-backend"
|
container_name: "PHANTOM-backend"
|
||||||
build:
|
build:
|
||||||
|
|||||||
115
experiments/airflow/dags/ml_training_pipeline.py
Normal file
115
experiments/airflow/dags/ml_training_pipeline.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from airflow import DAG, Dataset
|
||||||
|
from airflow.decorators import task
|
||||||
|
from airflow.utils.dates import days_ago
|
||||||
|
from datetime import timedelta
|
||||||
|
import pandas as pd
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
sys.path.insert(0, '/opt/airflow')
|
||||||
|
|
||||||
|
from procesing.context import PipelineContext
|
||||||
|
from procesing.providers import SupabaseProvider, BackendAPIProvider
|
||||||
|
from procesing.steps import (
|
||||||
|
FetchInteractionsStep,
|
||||||
|
ValidateDataStep,
|
||||||
|
ExtractSessionFeaturesStep,
|
||||||
|
JoinLabelsStep,
|
||||||
|
)
|
||||||
|
|
||||||
|
TRAINING_DATASET = Dataset('phantom://ml/training-data')
|
||||||
|
|
||||||
|
# Task-level defaults handed to the DAG via ``default_args``:
# no e-mail notifications, two retries spaced five minutes apart.
DEFAULT_ARGS = dict(
    owner='phantom-research',
    depends_on_past=False,
    email_on_failure=False,
    email_on_retry=False,
    retries=2,
    retry_delay=timedelta(minutes=5),
)
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeProvider(SupabaseProvider, BackendAPIProvider):
    """Provider combining ``SupabaseProvider`` and ``BackendAPIProvider``.

    Both parent initialisers are invoked explicitly (Supabase first) instead
    of going through ``super()``.
    """

    def __init__(self):
        for parent in (SupabaseProvider, BackendAPIProvider):
            parent.__init__(self)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_context(store_mode: str = 'hotel') -> PipelineContext:
    """Create a new ``PipelineContext`` backed by a fresh ``CompositeProvider``.

    Args:
        store_mode: forwarded unchanged to ``PipelineContext``.
    """
    provider = CompositeProvider()
    context = PipelineContext(provider=provider, store_mode=store_mode)
    return context
|
||||||
|
|
||||||
|
|
||||||
|
def _run_conf(context: dict) -> dict:
    """Return the conf payload of the current DAG run, or ``{}`` when absent."""
    dag_run = context.get('dag_run')
    if not dag_run:
        return {}
    # Guard against a run whose ``conf`` is None so callers can always ``.get``.
    return dag_run.conf or {}


def _context_for_run(context: dict) -> PipelineContext:
    """Build the pipeline context using the run's ``store_mode`` (default ``'hotel'``)."""
    return _get_context(_run_conf(context).get('store_mode', 'hotel'))


# NOTE(review): DataFrames travel between tasks as pickled bytes in XCom.
# ``pickle.loads`` must only ever see payloads written by this DAG, and this
# presumably requires an XCom backend/config that accepts raw bytes - confirm
# against the deployment before relying on it.
with DAG(
    'ml_training_pipeline',
    default_args=DEFAULT_ARGS,
    description='ML training data pipeline: fetch -> validate -> extract features -> label -> publish',
    schedule=None,
    start_date=days_ago(1),
    catchup=False,
    max_active_runs=1,
    tags=['ml', 'training', 'features', 'research'],
) as dag:

    @task
    def fetch_interactions(**kwargs) -> bytes:
        """Fetch raw interactions and return them as a pickled DataFrame."""
        ctx = _context_for_run(kwargs)
        df = FetchInteractionsStep(ctx).transform(None)
        logging.info(f"Fetched {len(df)} interactions, {df['sessionId'].nunique()} sessions")
        return pickle.dumps(df)

    @task
    def validate_data(raw_data: bytes, **kwargs) -> bytes:
        """Run ``ValidateDataStep`` on the fetched frame and log its cached report."""
        df = pickle.loads(raw_data)
        ctx = _context_for_run(kwargs)
        validated = ValidateDataStep(ctx).transform(df)
        report = ctx.get_cached('validation_report') or {}
        logging.info(f"Validation: {report.get('status')}, {report.get('sessions', 0)} sessions")
        return pickle.dumps(validated)

    @task
    def extract_session_features(validated_data: bytes, **kwargs) -> bytes:
        """Extract per-session features; an empty input yields an empty frame."""
        df = pickle.loads(validated_data)
        if df.empty:
            logging.warning("Empty input, skipping feature extraction")
            return pickle.dumps(pd.DataFrame())
        ctx = _context_for_run(kwargs)
        features = ExtractSessionFeaturesStep(ctx).transform(df)
        logging.info(f"Extracted {len(features.columns)} features for {len(features)} sessions")
        return pickle.dumps(features)

    @task
    def join_labels(features_data: bytes, **kwargs) -> bytes:
        """Attach labels to the feature frame; an empty input yields an empty frame."""
        features_df = pickle.loads(features_data)
        if features_df.empty:
            logging.warning("Empty features, skipping label join")
            return pickle.dumps(pd.DataFrame())
        ctx = _context_for_run(kwargs)
        labeled = JoinLabelsStep(ctx).transform(features_df)
        n_agents = labeled['is_agent'].sum() if 'is_agent' in labeled.columns else 0
        logging.info(f"Labeled {len(labeled)} sessions: {n_agents} agents")
        return pickle.dumps(labeled)

    @task(outlets=[TRAINING_DATASET])
    def publish_training_data(labeled_data: bytes, **kwargs) -> dict:
        """Summarise the labeled frame and mark ``TRAINING_DATASET`` as updated."""
        labeled_df = pickle.loads(labeled_data)
        if labeled_df.empty:
            return {'status': 'skipped', 'reason': 'empty_data'}
        # Identifier and label columns are not counted as features.
        non_feature_cols = ('sessionId', 'experimentId', 'is_agent')
        n_features = sum(1 for col in labeled_df.columns if col not in non_feature_cols)
        return {
            'status': 'success',
            'n_sessions': len(labeled_df),
            'n_features': n_features,
            'store_mode': _run_conf(kwargs).get('store_mode', 'hotel'),
            'timestamp': pd.Timestamp.now().isoformat(),
        }

    # Linear pipeline: each task consumes the previous task's XCom output.
    raw = fetch_interactions()
    validated = validate_data(raw)
    features = extract_session_features(validated)
    labeled = join_labels(features)
    publish_training_data(labeled)
|
||||||
11
experiments/ml/__init__.py
Normal file
11
experiments/ml/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from .evals import evaluate
|
||||||
|
from .arch import (
|
||||||
|
XGBoostAgentClassifier,
|
||||||
|
LightGBMAgentClassifier
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ =[
|
||||||
|
'evaluate',
|
||||||
|
'XGBoostAgentClassifier',
|
||||||
|
'LightGBMAgentClassifier'
|
||||||
|
]
|
||||||
122
experiments/ml/arch.py
Normal file
122
experiments/ml/arch.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
# sklearn compatible models for agent detection
|
||||||
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||||
|
from procesing.context import PipelineContext
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import xgboost as xgb
|
||||||
|
import lightgbm as lgb
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
TASK = 'classification'
|
||||||
|
LABELS = ['human', 'agent']
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgentClassifier(ClassifierMixin, BaseEstimator, ABC):
    """Base class for tree-based agent detection classifiers with common logic.

    Subclasses provide the concrete booster through ``_build_model`` and the
    validation-aware fit through ``_fit_with_eval``. The mixin is listed before
    ``BaseEstimator`` as scikit-learn recommends, so the mixin's behaviour
    takes precedence in the MRO.
    """

    def __init__(self, context: Optional[PipelineContext] = None, n_estimators: int = 200,
                 max_depth: int = 6, learning_rate: float = 0.05,
                 early_stopping_rounds: int = 20):
        self.context = context
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.early_stopping_rounds = early_stopping_rounds
        # Populated by fit(); None means "not fitted yet".
        self.model_ = None
        self.feature_names_ = None

    def _to_array(self, X):
        """Convert pandas structures to numpy arrays; pass anything else through."""
        return X.values if isinstance(X, (pd.DataFrame, pd.Series)) else X

    def _compute_pos_weight(self, y_arr):
        """Return n_negative / n_positive for class imbalance handling (1.0 if no positives)."""
        # asarray so plain Python lists compare element-wise as well.
        y_arr = np.asarray(y_arr)
        n_neg, n_pos = (y_arr == 0).sum(), (y_arr == 1).sum()
        return n_neg / n_pos if n_pos > 0 else 1.0

    def _prepare_eval_set(self, eval_set):
        """Convert the first ``(X_val, y_val)`` pair of eval_set to numpy arrays.

        Returns None when eval_set is empty or None. Only the first pair is used.
        """
        if not eval_set:
            return None
        X_val, y_val = eval_set[0]
        return [(self._to_array(X_val), self._to_array(y_val))]

    def _require_fitted(self):
        """Return the fitted underlying model or raise ``NotFittedError``."""
        if self.model_ is None:
            # NotFittedError subclasses both ValueError and AttributeError, so
            # callers that caught the previous AttributeError keep working.
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                f"{type(self).__name__} is not fitted yet; call fit() first"
            )
        return self.model_

    @abstractmethod
    def _build_model(self, scale_pos: float):
        """Build the underlying model instance (must be implemented by subclasses)"""

    @abstractmethod
    def _fit_with_eval(self, X_arr, y_arr, eval_arr):
        """Fit model with evaluation set (must be implemented by subclasses)"""

    def fit(self, X, y, eval_set=None):
        """Fit the underlying model.

        Args:
            X: feature matrix (DataFrame or array-like).
            y: binary labels, 1 marking the positive (agent) class.
            eval_set: optional list whose first item is ``(X_val, y_val)``;
                when given, fitting is delegated to ``_fit_with_eval``.

        Returns:
            self
        """
        X_arr = self._to_array(X)
        y_arr = np.asarray(self._to_array(y))

        # Reset on every fit so names from an earlier DataFrame fit do not linger.
        self.feature_names_ = X.columns.tolist() if isinstance(X, pd.DataFrame) else None
        self.classes_ = np.unique(y_arr)

        scale_pos = self._compute_pos_weight(y_arr)
        self.model_ = self._build_model(scale_pos)

        eval_arr = self._prepare_eval_set(eval_set)
        if eval_arr:
            self._fit_with_eval(X_arr, y_arr, eval_arr)
        else:
            self.model_.fit(X_arr, y_arr)

        return self

    def predict(self, X):
        """Return the underlying model's class predictions for X."""
        return self._require_fitted().predict(self._to_array(X))

    def predict_proba(self, X):
        """Return the underlying model's class probabilities for X."""
        return self._require_fitted().predict_proba(self._to_array(X))

    @property
    def feature_importances_(self):
        """Feature importances of the fitted model, or None before fitting."""
        return self.model_.feature_importances_ if self.model_ is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
class XGBoostAgentClassifier(BaseAgentClassifier):
    """XGBoost binary classifier for agent detection with class imbalance handling"""

    def _build_model(self, scale_pos: float):
        """Build the XGBoost model with ``scale_pos`` as ``scale_pos_weight``.

        Early stopping is intentionally NOT configured here: XGBoost refuses to
        train with early stopping enabled but no validation data, which would
        break ``fit()`` calls made without an ``eval_set``. It is switched on
        in ``_fit_with_eval`` instead.
        """
        return xgb.XGBClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            learning_rate=self.learning_rate,
            scale_pos_weight=scale_pos,
            eval_metric='auc',
            random_state=42,
            tree_method='hist',
            enable_categorical=False,
        )

    def _fit_with_eval(self, X_arr, y_arr, eval_arr):
        """Fit with early stopping monitored on ``eval_arr``."""
        self.model_.set_params(early_stopping_rounds=self.early_stopping_rounds)
        self.model_.fit(X_arr, y_arr, eval_set=eval_arr, verbose=False)
|
||||||
|
|
||||||
|
|
||||||
|
class LightGBMAgentClassifier(BaseAgentClassifier):
    """LightGBM flavour of the agent detector, weighted for class imbalance."""

    def _build_model(self, scale_pos: float):
        """Instantiate the LightGBM model with ``scale_pos`` as ``scale_pos_weight``."""
        params = dict(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            learning_rate=self.learning_rate,
            scale_pos_weight=scale_pos,
            metric='auc',
            random_state=42,
            verbosity=-1,
        )
        return lgb.LGBMClassifier(**params)

    def _fit_with_eval(self, X_arr, y_arr, eval_arr):
        """Fit with an early-stopping callback monitored on ``eval_arr``."""
        stopper = lgb.early_stopping(self.early_stopping_rounds, verbose=False)
        self.model_.fit(X_arr, y_arr, eval_set=eval_arr, callbacks=[stopper])
|
||||||
103
experiments/ml/evals.py
Normal file
103
experiments/ml/evals.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
|
||||||
|
f1_score, roc_auc_score, confusion_matrix, roc_curve)
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from logging import getLogger
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def log_feature_importance(writer, model, feature_names, epoch):
    """Visualize and log feature importance to TensorBoard.

    Logs one scalar per top feature under ``FeatureImportance/<name>`` and a
    horizontal bar chart under ``FeatureImportance/Chart``. At most the 20
    most important features are reported.

    Args:
        writer: TensorBoard writer exposing ``add_scalar`` and ``add_image``.
        model: object that may expose ``feature_importances_``.
        feature_names: names aligned with the importance vector.
        epoch: global step used for every logged item.

    Returns:
        None. Nothing is logged when the model has no importances, they are
        None, or they are empty.
    """
    importance = getattr(model, 'feature_importances_', None)
    if importance is None:
        return
    importance = np.asarray(importance)
    if importance.size == 0:
        return

    # Best-effort logging: pad missing names with positional placeholders
    # rather than failing with IndexError on a short name list.
    names = list(feature_names) if feature_names is not None else []
    names.extend(f'f{i}' for i in range(len(names), importance.size))

    indices = np.argsort(importance)[::-1][:20]  # top 20
    top_features = [names[i] for i in indices]
    top_importance = importance[indices]

    for feat, imp in zip(top_features, top_importance):
        writer.add_scalar(f'FeatureImportance/{feat}', imp, epoch)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        positions = range(len(top_features))
        ax.barh(positions, top_importance, align='center')
        ax.set_yticks(positions)
        ax.set_yticklabels(top_features)
        ax.invert_yaxis()  # most important feature at the top
        ax.set_xlabel('Importance')
        ax.set_title(f'Top 20 Feature Importance (Epoch {epoch})')
        ax.grid(axis='x', alpha=0.3)
        fig.tight_layout()

        # Render to an in-memory PNG and hand the pixels to TensorBoard.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img_arr = np.array(Image.open(buf))
        writer.add_image('FeatureImportance/Chart', img_arr, epoch, dataformats='HWC')
    finally:
        # Close this specific figure even if rendering fails.
        plt.close(fig)
|
||||||
|
|
||||||
|
def evaluate(perdicted_class, predicted_proba, true_class, writer: SummaryWriter, epoch: int):
    """Compute binary classification metrics and log them to TensorBoard.

    Scalars are written under ``Eval/*`` and a combined confusion-matrix /
    ROC figure under ``Eval/Metrics``. Labels are expected to be 0 (human)
    and 1 (agent).

    Args:
        perdicted_class: predicted labels (parameter name kept for
            backward compatibility with existing callers).
        predicted_proba: predicted score/probability of the positive class.
        true_class: ground-truth labels.
        writer: TensorBoard writer.
        epoch: global step used for every logged item.
    """
    accuracy = accuracy_score(true_class, perdicted_class)
    precision = precision_score(true_class, perdicted_class, zero_division=0)
    recall = recall_score(true_class, perdicted_class, zero_division=0)
    f1 = f1_score(true_class, perdicted_class, zero_division=0)

    # ROC AUC is undefined when the ground truth holds a single class;
    # report NaN instead of letting roc_auc_score raise.
    auc_defined = len(np.unique(true_class)) > 1
    if auc_defined:
        roc_auc = roc_auc_score(true_class, predicted_proba)
    else:
        logger.warning("Only one class present in true labels; ROC AUC is undefined")
        roc_auc = float('nan')

    writer.add_scalar('Eval/Accuracy', accuracy, epoch)
    writer.add_scalar('Eval/Precision', precision, epoch)
    writer.add_scalar('Eval/Recall', recall, epoch)
    writer.add_scalar('Eval/F1_Score', f1, epoch)
    if auc_defined:
        writer.add_scalar('Eval/ROC_AUC', roc_auc, epoch)

    # confusion matrix; explicit labels keep it 2x2 even if a class is missing
    cm = confusion_matrix(true_class, perdicted_class, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    writer.add_scalar('Eval/TrueNeg', tn, epoch)
    writer.add_scalar('Eval/FalsePos', fp, epoch)
    writer.add_scalar('Eval/FalseNeg', fn, epoch)
    writer.add_scalar('Eval/TruePos', tp, epoch)

    # specificity and sensitivity
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    sensitivity = recall  # same as recall/TPR
    writer.add_scalar('Eval/Specificity', specificity, epoch)
    writer.add_scalar('Eval/Sensitivity', sensitivity, epoch)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        ax1.matshow(cm, cmap='Blues', alpha=0.7)
        for i in range(2):
            for j in range(2):
                ax1.text(j, i, str(cm[i, j]), ha='center', va='center', fontsize=14)
        ax1.set_xlabel('Predicted')
        ax1.set_ylabel('True')
        ax1.set_title(f'Confusion Matrix (Epoch {epoch})')
        ax1.set_xticks([0, 1])
        ax1.set_yticks([0, 1])
        ax1.set_xticklabels(['Human', 'Agent'])
        ax1.set_yticklabels(['Human', 'Agent'])

        # ROC curve (only meaningful when both classes are present)
        if auc_defined:
            fpr, tpr, _ = roc_curve(true_class, predicted_proba)
            ax2.plot(fpr, tpr, label=f'AUC={roc_auc:.3f}', linewidth=2)
        ax2.plot([0, 1], [0, 1], 'k--', label='Random')
        ax2.set_xlabel('False Positive Rate')
        ax2.set_ylabel('True Positive Rate')
        ax2.set_title('ROC Curve')
        ax2.legend()
        ax2.grid(alpha=0.3)

        # Render to an in-memory PNG and hand the pixels to TensorBoard.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img_arr = np.array(Image.open(buf))
        writer.add_image('Eval/Metrics', img_arr, epoch, dataformats='HWC')
    finally:
        # Close this specific figure even if rendering fails.
        plt.close(fig)

    logger.info(f"Eval {epoch}: Acc={accuracy:.4f} Prec={precision:.4f} Rec={recall:.4f} F1={f1:.4f} AUC={roc_auc:.4f}")
|
||||||
6
experiments/ml/requirements.txt
Normal file
6
experiments/ml/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
torch
|
||||||
|
tensorboard
|
||||||
|
fastparquet
|
||||||
|
pyarrow
|
||||||
|
xgboost
|
||||||
|
lightgbm
|
||||||
137
experiments/ml/train.py
Normal file
137
experiments/ml/train.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import joblib
|
||||||
|
from datetime import datetime
|
||||||
|
from ml.evals import evaluate, log_feature_importance
|
||||||
|
from ml.arch import XGBoostAgentClassifier, LightGBMAgentClassifier, LABELS
|
||||||
|
|
||||||
|
logger = getLogger(__name__)

# Identifier, label and experiment-metadata columns that must never be used
# as model features. 'browser_family' is excluded raw and re-added one-hot.
FEATURE_COLS_EXCLUDE = ['sessionId', 'experimentId', 'is_agent', 'xp_human_only', 'xp_market_mode', 'browser_family']
RUNS_DIR = Path('ml/runs')
CHECKPOINTS_DIR = Path('ml/checkpoints')


def prepare_data(df):
    """
    Prepare feature matrix and labels from raw dataframe

    Rows without an ``is_agent`` label are dropped, every column not listed
    in FEATURE_COLS_EXCLUDE becomes a feature, ``browser_family`` (if present)
    is one-hot encoded with the first category dropped, and missing feature
    values are filled with 0. The input frame is not modified.

    Returns: (X, y, feature_cols), or (None, None, None) when no labeled
    rows remain.
    """
    # drop rows with missing labels
    n_before = len(df)
    df = df[df['is_agent'].notna()].copy()
    n_dropped = n_before - len(df)
    if n_dropped > 0:
        logger.warning(f"Dropped {n_dropped} sessions with missing labels")

    if df.empty:
        logger.error("No labeled data available")
        return None, None, None

    feature_cols = [c for c in df.columns if c not in FEATURE_COLS_EXCLUDE]

    # handle categorical browser_family via one-hot encoding
    if 'browser_family' in df.columns:
        # dtype=float: recent pandas emits bool dummies, which would turn the
        # mixed float/bool frame into an object array once converted to numpy.
        browser_dummies = pd.get_dummies(
            df['browser_family'], prefix='browser', drop_first=True, dtype=float
        )
        df = pd.concat([df, browser_dummies], axis=1)
        feature_cols.extend(browser_dummies.columns.tolist())

    X = df[feature_cols].fillna(0)
    y = df['is_agent'].astype(int)

    return X, y, feature_cols
|
||||||
|
|
||||||
|
|
||||||
|
def train(data_path=None, model_type='xgboost', test_size=0.2, random_state=42,
          n_estimators=200, max_depth=6, learning_rate=0.05):
    """
    Train agent detection classifier

    Args:
        data_path: path to labeled feature matrix; a ``.csv`` suffix is read
            as CSV, anything else as parquet
        model_type: 'xgboost' or 'lightgbm'
        test_size: fraction for test split
        random_state: seed for reproducibility
        n_estimators: number of boosting rounds
        max_depth: maximum tree depth
        learning_rate: boosting learning rate

    Returns:
        (model, feature_cols) on success, None when the inputs are invalid
        (the reason is logged).
    """
    # Validate cheap inputs first so no run directory or TensorBoard writer
    # is created for a run that cannot start.
    if data_path is None:
        logger.error("data_path required")
        return

    model_classes = {
        'xgboost': XGBoostAgentClassifier,
        'lightgbm': LightGBMAgentClassifier,
    }
    model_cls = model_classes.get(model_type)
    if model_cls is None:
        logger.error(f"Unknown model type: {model_type}")
        return

    # load data
    if str(data_path).lower().endswith('.csv'):
        df = pd.read_csv(data_path)
    else:
        df = pd.read_parquet(data_path)
    logger.info(f"Loaded {len(df)} sessions from {data_path}")

    # prepare features and labels
    if 'is_agent' not in df.columns:
        logger.error("Missing is_agent column")
        return

    X, y, feature_cols = prepare_data(df)
    if X is None:
        return

    # class distribution
    n_agents = y.sum()
    n_humans = (y == 0).sum()
    logger.info(f"Class distribution: {n_humans} humans, {n_agents} agents" + (f" (ratio {n_humans / n_agents:.2f})" if n_agents > 0 else ""))

    # train/test split with stratification
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    logger.info(f"Train: {len(X_train)}, Test: {len(X_test)}")

    model = model_cls(
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
    )

    # parents=True: the 'ml/' parent may not exist relative to the current cwd
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)

    run_name = f"{model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=str(RUNS_DIR / run_name))
    logger.info(f"Starting training run: {run_name}")

    try:
        # train with eval set for early stopping
        # NOTE(review): the test split doubles as the early-stopping set, so
        # the metrics below are optimistic; consider a separate validation split.
        model.fit(X_train, y_train, eval_set=[(X_test, y_test)])
        logger.info("Training complete")

        # evaluate on test set
        y_pred = model.predict(X_test)
        y_prob = model.predict_proba(X_test)[:, 1]
        evaluate(y_pred, y_prob, y_test, writer, epoch=0)

        # log feature importance
        log_feature_importance(writer, model, feature_cols, epoch=0)

        # save model
        model_path = CHECKPOINTS_DIR / f"{run_name}.pkl"
        joblib.dump({'model': model, 'feature_cols': feature_cols, 'run_name': run_name}, model_path)
        logger.info(f"Model saved to {model_path}")
    finally:
        # Flush and release the event file even when training or logging fails.
        writer.close()

    return model, feature_cols
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI: train.py <data_path> [xgboost|lightgbm]
    import logging
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <data_path> [xgboost|lightgbm]")

    # Without a configured handler the INFO-level progress messages are dropped.
    logging.basicConfig(level=logging.INFO)

    data_path = sys.argv[1]
    model_type = sys.argv[2] if len(sys.argv) > 2 else 'xgboost'
    train(data_path, model_type=model_type)
|
||||||
@@ -170,3 +170,5 @@ if __name__ == '__main__':
|
|||||||
print(f"Feature matrix: {features.shape}")
|
print(f"Feature matrix: {features.shape}")
|
||||||
print(features.head())
|
print(features.head())
|
||||||
print(features.info())
|
print(features.info())
|
||||||
|
|
||||||
|
features.to_parquet("features.parquet")
|
||||||
|
|||||||
Reference in New Issue
Block a user