mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
feat: separating modules and adding training logs paths
This commit is contained in:
@@ -9,6 +9,38 @@ from PIL import Image
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def log_feature_importance(writer, model, feature_names, epoch):
|
||||
"""Visualize and log feature importance to TensorBoard"""
|
||||
if not hasattr(model, 'feature_importances_') or model.feature_importances_ is None:
|
||||
return
|
||||
|
||||
importance = model.feature_importances_
|
||||
indices = np.argsort(importance)[::-1][:20] # top 20
|
||||
top_features = [feature_names[i] for i in indices]
|
||||
top_importance = importance[indices]
|
||||
|
||||
for i, (feat, imp) in enumerate(zip(top_features, top_importance)):
|
||||
writer.add_scalar(f'FeatureImportance/{feat}', imp, epoch)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
ax.barh(range(len(top_features)), top_importance, align='center')
|
||||
ax.set_yticks(range(len(top_features)))
|
||||
ax.set_yticklabels(top_features)
|
||||
ax.invert_yaxis()
|
||||
ax.set_xlabel('Importance')
|
||||
ax.set_title(f'Top 20 Feature Importance (Epoch {epoch})')
|
||||
ax.grid(axis='x', alpha=0.3)
|
||||
|
||||
buf = io.BytesIO()
|
||||
plt.tight_layout()
|
||||
plt.savefig(buf, format='png', dpi=100)
|
||||
buf.seek(0)
|
||||
img = Image.open(buf)
|
||||
img_arr = np.array(img)
|
||||
writer.add_image('FeatureImportance/Chart', img_arr, epoch, dataformats='HWC')
|
||||
plt.close()
|
||||
|
||||
def evaluate(perdicted_class, predicted_proba, true_class, writer: SummaryWriter, epoch: int):
|
||||
accuracy = accuracy_score(true_class, perdicted_class)
|
||||
precision = precision_score(true_class, perdicted_class, zero_division=0)
|
||||
|
||||
Reference in New Issue
Block a user