34 initial discriminator of interaction data (#38)

* feat: training pipeline + tensorboard

* tesnorboard forgot

* chore: ml basic boilerplate

* feat: naive architecture as start

* eval setup

* chore: parquet exporting of data

* chore: updating requirements necesary

* feat: separating modules and adding training logs paths

* Update experiments/ml/train.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fix: new path for runs

* fix: undoing ai slop code

* chore: modules and reqs

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Alves Rösel
2025-12-14 18:58:42 +01:00
committed by GitHub
parent a1916c966c
commit f2271e368e
8 changed files with 507 additions and 0 deletions

103
experiments/ml/evals.py Normal file
View File

@@ -0,0 +1,103 @@
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
f1_score, roc_auc_score, confusion_matrix, roc_curve)
from torch.utils.tensorboard import SummaryWriter
from logging import getLogger
import numpy as np
import matplotlib.pyplot as plt
import io
from PIL import Image
logger = getLogger(__name__)
def log_feature_importance(writer, model, feature_names, epoch):
"""Visualize and log feature importance to TensorBoard"""
if not hasattr(model, 'feature_importances_') or model.feature_importances_ is None:
return
importance = model.feature_importances_
indices = np.argsort(importance)[::-1][:20] # top 20
top_features = [feature_names[i] for i in indices]
top_importance = importance[indices]
for i, (feat, imp) in enumerate(zip(top_features, top_importance)):
writer.add_scalar(f'FeatureImportance/{feat}', imp, epoch)
fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(range(len(top_features)), top_importance, align='center')
ax.set_yticks(range(len(top_features)))
ax.set_yticklabels(top_features)
ax.invert_yaxis()
ax.set_xlabel('Importance')
ax.set_title(f'Top 20 Feature Importance (Epoch {epoch})')
ax.grid(axis='x', alpha=0.3)
buf = io.BytesIO()
plt.tight_layout()
plt.savefig(buf, format='png', dpi=100)
buf.seek(0)
img = Image.open(buf)
img_arr = np.array(img)
writer.add_image('FeatureImportance/Chart', img_arr, epoch, dataformats='HWC')
plt.close()
def evaluate(perdicted_class, predicted_proba, true_class, writer: SummaryWriter, epoch: int):
accuracy = accuracy_score(true_class, perdicted_class)
precision = precision_score(true_class, perdicted_class, zero_division=0)
recall = recall_score(true_class, perdicted_class, zero_division=0)
f1 = f1_score(true_class, perdicted_class, zero_division=0)
roc_auc = roc_auc_score(true_class, predicted_proba)
writer.add_scalar('Eval/Accuracy', accuracy, epoch)
writer.add_scalar('Eval/Precision', precision, epoch)
writer.add_scalar('Eval/Recall', recall, epoch)
writer.add_scalar('Eval/F1_Score', f1, epoch)
writer.add_scalar('Eval/ROC_AUC', roc_auc, epoch)
# confusion matrix
cm = confusion_matrix(true_class, perdicted_class)
tn, fp, fn, tp = cm.ravel()
writer.add_scalar('Eval/TrueNeg', tn, epoch)
writer.add_scalar('Eval/FalsePos', fp, epoch)
writer.add_scalar('Eval/FalseNeg', fn, epoch)
writer.add_scalar('Eval/TruePos', tp, epoch)
# specificity and sensitivity
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
sensitivity = recall # same as recall/TPR
writer.add_scalar('Eval/Specificity', specificity, epoch)
writer.add_scalar('Eval/Sensitivity', sensitivity, epoch)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.matshow(cm, cmap='Blues', alpha=0.7)
for i in range(2):
for j in range(2):
ax1.text(j, i, str(cm[i, j]), ha='center', va='center', fontsize=14)
ax1.set_xlabel('Predicted')
ax1.set_ylabel('True')
ax1.set_title(f'Confusion Matrix (Epoch {epoch})')
ax1.set_xticks([0, 1])
ax1.set_yticks([0, 1])
ax1.set_xticklabels(['Human', 'Agent'])
ax1.set_yticklabels(['Human', 'Agent'])
# ROC curve
fpr, tpr, _ = roc_curve(true_class, predicted_proba)
ax2.plot(fpr, tpr, label=f'AUC={roc_auc:.3f}', linewidth=2)
ax2.plot([0, 1], [0, 1], 'k--', label='Random')
ax2.set_xlabel('False Positive Rate')
ax2.set_ylabel('True Positive Rate')
ax2.set_title('ROC Curve')
ax2.legend()
ax2.grid(alpha=0.3)
buf = io.BytesIO()
plt.tight_layout()
plt.savefig(buf, format='png', dpi=100)
buf.seek(0)
img = Image.open(buf)
img_arr = np.array(img)
writer.add_image('Eval/Metrics', img_arr, epoch, dataformats='HWC')
plt.close()
logger.info(f"Eval {epoch}: Acc={accuracy:.4f} Prec={precision:.4f} Rec={recall:.4f} F1={f1:.4f} AUC={roc_auc:.4f}")