mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
chore: ml basic boilerplate
This commit is contained in:
17
experiments/ml/arch.py
Normal file
17
experiments/ml/arch.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# this should retrun a model with exposed methods fit and transform method in an sklearn style
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
from procesing.context import PipelineContext
|
||||
from typing import Any
|
||||
|
||||
TASK = 'classification'
|
||||
LABELS = ['agent', 'human']
|
||||
|
||||
class BaseModel(BaseEstimator, TransformerMixin):
|
||||
def __init__(self, context: PipelineContext):
|
||||
self.context = context
|
||||
|
||||
def fit(self, X=None, y=None):
|
||||
return self
|
||||
|
||||
def transform(self, X) -> Any:
|
||||
pass
|
||||
19
experiments/ml/evals.py
Normal file
19
experiments/ml/evals.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from logging import getLogger
|
||||
logger = getLogger(__name__)
|
||||
|
||||
def evaluate(perdicted_class, predicted_proba, true_class, writer: SummaryWriter, epoch: int):
|
||||
accuracy = accuracy_score(true_class, perdicted_class)
|
||||
precision = precision_score(true_class, perdicted_class)
|
||||
recall = recall_score(true_class, perdicted_class)
|
||||
f1 = f1_score(true_class, perdicted_class)
|
||||
roc_auc = roc_auc_score(true_class, predicted_proba)
|
||||
|
||||
writer.add_scalar('Eval/Accuracy', accuracy, epoch)
|
||||
writer.add_scalar('Eval/Precision', precision, epoch)
|
||||
writer.add_scalar('Eval/Recall', recall, epoch)
|
||||
writer.add_scalar('Eval/F1_Score', f1, epoch)
|
||||
writer.add_scalar('Eval/ROC_AUC', roc_auc, epoch)
|
||||
|
||||
logger.info(f"Eval Metrics - Epoch {epoch}: Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}")
|
||||
13
experiments/ml/train.py
Normal file
13
experiments/ml/train.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from logging import getLogger
|
||||
from evals import evaluate
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def train():
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Reference in New Issue
Block a user