mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
feat: consistent failure case
This commit is contained in:
@@ -2,11 +2,6 @@
|
||||
|
||||
Trains pricing policies using stable-baselines3 with TensorBoard logging.
|
||||
Tracks COI erosion, alpha estimation error, and economic KPIs per thesis formulation.
|
||||
|
||||
Usage:
|
||||
python -m lab.case.thesis.train --algo ppo --alpha 0.3 --steps 100000
|
||||
python -m lab.case.thesis.train --algo adaptive --sweep # run alpha sweep
|
||||
tensorboard --logdir lab/case/thesis/runs
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
@@ -41,9 +36,9 @@ class EpisodeMetrics:
|
||||
reward: float = 0.0
|
||||
revenue: float = 0.0
|
||||
profit: float = 0.0
|
||||
coi_erosion: float = 0.0 # theorem 1: order statistic erosion
|
||||
coi_leakage: float = 0.0 # per-step leakage penalty
|
||||
alpha_error: float = 0.0 # |α - α̂|
|
||||
coi_erosion: float = 0.0
|
||||
coi_leakage: float = 0.0
|
||||
alpha_error: float = 0.0
|
||||
avg_margin: float = 0.0
|
||||
n_agents: int = 0
|
||||
steps: int = 0
|
||||
@@ -213,6 +208,7 @@ def train(cfg: ExperimentConfig) -> Dict[str, Any]:
|
||||
if algo_cls is None:
|
||||
raise ValueError(f"unknown algo: {cfg.algo}")
|
||||
common = dict(verbose=1, seed=cfg.seed, tensorboard_log=str(log_path), device="auto")
|
||||
# TODO: setup hyper parameter passing to train different variations (no free lunch)
|
||||
if cfg.algo.lower() == "ppo":
|
||||
model = PPO("MlpPolicy", train_env, learning_rate=3e-4, n_steps=2048,
|
||||
batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95,
|
||||
|
||||
Reference in New Issue
Block a user