feat: contaminator and training

This commit is contained in:
2026-01-21 19:12:56 +01:00
parent 2ed200f870
commit dee6f573e3
2 changed files with 100 additions and 76 deletions

View File

@@ -3,15 +3,17 @@ import logging
from pathlib import Path
from typing import Dict, Type, Optional
import pickle
from torch import neg_
from torch.utils.tensorboard import SummaryWriter
from environment import PHANTOMEnv, FastTrainingConstraints, BusinessLogicConstraints
from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
from environment import PHANTOMEnv, BusinessLogicConstraints
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
try:
from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
except ImportError:
BasePricingEngine = None # engines not required for basic usage
"""
@@ -26,8 +28,7 @@ CURRENT SOLUTION BELOW does not implement correct learning or updates.
class EngineTrainer:
"""wrapper to run pricing engines through episodes and collect metrics"""
def __init__(self, engine: BasePricingEngine, env: PHANTOMEnv,
tb_writer: Optional[SummaryWriter] = None):
def __init__(self, engine, env: PHANTOMEnv, tb_writer: Optional[SummaryWriter] = None):
self.engine = engine
self.env = env
self.episode_metrics = []
@@ -35,7 +36,6 @@ class EngineTrainer:
self.global_step = 0
def train(self, n_episodes: int, seed: int = 42):
obs, _ = self.env.reset(seed=seed)
prices = None
for ep in range(n_episodes):
@@ -44,12 +44,21 @@ class EngineTrainer:
self.engine.update(obs, reward, done, info)
return self
return self.episode_metrics
def run_episode(self, seed: int = 42) -> Dict:
"""run single evaluation episode and return metrics"""
obs, _ = self.env.reset(seed=seed)
self.engine.reset()
total_reward, prices = 0.0, None
ep_metrics = {'total_reward': 0.0}
done = False
while not done:
prices = self.engine.compute_prices(prices, obs) if prices is not None else obs["elasticity"]["price"]
obs, reward, done, _, info = self.env.step(prices)
total_reward += reward
for k, v in info.items():
ep_metrics[k] = v
ep_metrics['total_reward'] = total_reward
return ep_metrics
def evaluate(self, n_episodes: int = 10, seed: int = 100) -> Dict:
"""evaluate trained engine"""
@@ -57,17 +66,16 @@ class EngineTrainer:
'agent_loss', 'ux_volatility', 'look_to_book']}
for ep in range(n_episodes):
metrics = self.run_episode(seed=seed + ep)
for k in results: results[k].append(metrics[k])
for k in results:
results[k].append(metrics.get(k, 0.0))
return {k: (np.mean(v), np.std(v)) for k, v in results.items()}
def make_env(fast: bool = True):
constraints = FastTrainingConstraints() if fast else BusinessLogicConstraints()
return PHANTOMEnv(constraints=constraints)
def make_env():
return PHANTOMEnv(constraints=BusinessLogicConstraints())
def train_engine(engine_cls: Type[BasePricingEngine], env: PHANTOMEnv,
n_episodes: int, seed: int = 42,
def train_engine(engine_cls, env: PHANTOMEnv, n_episodes: int, seed: int = 42,
tb_writer: Optional[SummaryWriter] = None) -> EngineTrainer:
constraints = env.constraints
engine = engine_cls(constraints=constraints, seed=seed)
@@ -80,15 +88,11 @@ def save_trainer(trainer: EngineTrainer, path: Path):
"""save engine state and metrics"""
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, 'wb') as f:
pickle.dump({
'engine': trainer.engine,
'metrics': trainer.episode_metrics
}, f)
pickle.dump({'engine': trainer.engine, 'metrics': trainer.episode_metrics}, f)
logger.info(f"Saved trainer to {path}")
def load_trainer(path: Path, env: PHANTOMEnv,
tb_writer: Optional[SummaryWriter] = None) -> EngineTrainer:
def load_trainer(path: Path, env: PHANTOMEnv, tb_writer: Optional[SummaryWriter] = None) -> EngineTrainer:
"""load saved engine"""
with open(path, 'rb') as f:
data = pickle.load(f)
@@ -98,45 +102,44 @@ def load_trainer(path: Path, env: PHANTOMEnv,
if __name__ == "__main__":
if BasePricingEngine is None:
logger.error("Engines not available, cannot run training")
exit(1)
base_dir = Path("./runs")
base_dir.mkdir(exist_ok=True)
engines = {
"Wild": WildPricingEngine,
"Static": StaticPricingEngine,
# "SimpleDemand": SimpleDemandEngine,
"RandomWalk": RandomWalkEngine,
"ThompsonSampling": ThompsonSamplingEngine,
}
defenses = [False, True]
n_train_episodes = 50
n_eval_episodes = 10
seed = 42
fast_mode = True
logger.info(f"Training config: {n_train_episodes} episodes per engine, fast_mode={fast_mode}")
logger.info(f"Training config: {n_train_episodes} episodes per engine")
trained_trainers = {}
for engine_name, engine_cls in engines.items():
for use_defense in defenses:
defense_label = "defense_on" if use_defense else "defense_off"
run_name = f"{engine_name}_{defense_label}"
log_dir = base_dir / run_name
log_dir.mkdir(parents=True, exist_ok=True)
run_name = engine_name
log_dir = base_dir / run_name
log_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Training {engine_name} with defense={use_defense}")
logger.info(f"Log directory: {log_dir}")
logger.info(f"Training {engine_name}")
logger.info(f"Log directory: {log_dir}")
env = make_env(fast=fast_mode)
tb_writer = SummaryWriter(log_dir=str(log_dir))
trainer = train_engine(engine_cls, env, n_train_episodes, seed, tb_writer=tb_writer)
tb_writer.close()
env = make_env()
tb_writer = SummaryWriter(log_dir=str(log_dir))
trainer = train_engine(engine_cls, env, n_train_episodes, seed, tb_writer=tb_writer)
tb_writer.close()
save_path = log_dir / "trainer.pkl"
save_trainer(trainer, save_path)
save_path = log_dir / "trainer.pkl"
save_trainer(trainer, save_path)
trained_trainers[run_name] = (trainer, env)
trained_trainers[run_name] = (trainer, env)
logger.info("Starting evaluation")