mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
Update experiments/ml/train.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
48cf50db32
commit
18bd11c09f
@@ -83,7 +83,7 @@ def train(data_path=None, model_type='xgboost', test_size=0.2, random_state=42,
|
|||||||
# class distribution
|
# class distribution
|
||||||
n_agents = y.sum()
|
n_agents = y.sum()
|
||||||
n_humans = (y == 0).sum()
|
n_humans = (y == 0).sum()
|
||||||
logger.info(f"Class distribution: {n_humans} humans, {n_agents} agents (ratio {n_humans/n_agents:.2f})")
|
logger.info(f"Class distribution: {n_humans} humans, {n_agents} agents" + (f" (ratio {n_humans / n_agents:.2f})" if n_agents > 0 else ""))
|
||||||
|
|
||||||
# train/test split with stratification
|
# train/test split with stratification
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
|
|||||||
Reference in New Issue
Block a user