diff --git a/experiments/ml/arch.py b/experiments/ml/arch.py new file mode 100644 index 0000000..48b2246 --- /dev/null +++ b/experiments/ml/arch.py @@ -0,0 +1,17 @@ +# this should retrun a model with exposed methods fit and transform method in an sklearn style +from sklearn.base import BaseEstimator, TransformerMixin +from procesing.context import PipelineContext +from typing import Any + +TASK = 'classification' +LABELS = ['agent', 'human'] + +class BaseModel(BaseEstimator, TransformerMixin): + def __init__(self, context: PipelineContext): + self.context = context + + def fit(self, X=None, y=None): + return self + + def transform(self, X) -> Any: + pass diff --git a/experiments/ml/evals.py b/experiments/ml/evals.py new file mode 100644 index 0000000..55c2bf4 --- /dev/null +++ b/experiments/ml/evals.py @@ -0,0 +1,19 @@ +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score +from torch.utils.tensorboard import SummaryWriter +from logging import getLogger +logger = getLogger(__name__) + +def evaluate(perdicted_class, predicted_proba, true_class, writer: SummaryWriter, epoch: int): + accuracy = accuracy_score(true_class, perdicted_class) + precision = precision_score(true_class, perdicted_class) + recall = recall_score(true_class, perdicted_class) + f1 = f1_score(true_class, perdicted_class) + roc_auc = roc_auc_score(true_class, predicted_proba) + + writer.add_scalar('Eval/Accuracy', accuracy, epoch) + writer.add_scalar('Eval/Precision', precision, epoch) + writer.add_scalar('Eval/Recall', recall, epoch) + writer.add_scalar('Eval/F1_Score', f1, epoch) + writer.add_scalar('Eval/ROC_AUC', roc_auc, epoch) + + logger.info(f"Eval Metrics - Epoch {epoch}: Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}") diff --git a/experiments/ml/train.py b/experiments/ml/train.py new file mode 100644 index 0000000..cd658f2 --- /dev/null +++ b/experiments/ml/train.py @@ -0,0 +1,13 @@ +from torch.utils.tensorboard import SummaryWriter +from logging import getLogger +from evals import evaluate +logger = getLogger(__name__) + + +def train(): + pass + + + +if __name__ == "__main__": + train()