feat: separating modules and adding training logs paths

This commit is contained in:
2025-12-12 12:45:51 +01:00
parent 0119408897
commit 48cf50db32
2 changed files with 160 additions and 4 deletions

View File

@@ -9,6 +9,38 @@ from PIL import Image
logger = getLogger(__name__)
def log_feature_importance(writer, model, feature_names, epoch):
"""Visualize and log feature importance to TensorBoard"""
if not hasattr(model, 'feature_importances_') or model.feature_importances_ is None:
return
importance = model.feature_importances_
indices = np.argsort(importance)[::-1][:20] # top 20
top_features = [feature_names[i] for i in indices]
top_importance = importance[indices]
for i, (feat, imp) in enumerate(zip(top_features, top_importance)):
writer.add_scalar(f'FeatureImportance/{feat}', imp, epoch)
fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(range(len(top_features)), top_importance, align='center')
ax.set_yticks(range(len(top_features)))
ax.set_yticklabels(top_features)
ax.invert_yaxis()
ax.set_xlabel('Importance')
ax.set_title(f'Top 20 Feature Importance (Epoch {epoch})')
ax.grid(axis='x', alpha=0.3)
buf = io.BytesIO()
plt.tight_layout()
plt.savefig(buf, format='png', dpi=100)
buf.seek(0)
img = Image.open(buf)
img_arr = np.array(img)
writer.add_image('FeatureImportance/Chart', img_arr, epoch, dataformats='HWC')
plt.close()
def evaluate(perdicted_class, predicted_proba, true_class, writer: SummaryWriter, epoch: int):
accuracy = accuracy_score(true_class, perdicted_class)
precision = precision_score(true_class, perdicted_class, zero_division=0)