chore: training and data refactors

This commit is contained in:
2026-01-22 11:40:12 +01:00
parent dee6f573e3
commit c15bb1882e
2 changed files with 47 additions and 15 deletions

View File

@@ -4,16 +4,17 @@ from pathlib import Path
from typing import Dict, Type, Optional
import pickle
from torch.utils.tensorboard import SummaryWriter
from environment import PHANTOMEnv, BusinessLogicConstraints
from sim.rl.environment import PHANTOMEnv, BusinessLogicConstraints
# Configure root logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

# The pricing engines are an optional dependency: if the import fails, the
# module still loads and ``BasePricingEngine`` is left as ``None`` so callers
# can test for availability.
try:
    from sim.rl.engine import (
        BasePricingEngine,
        WildPricingEngine,
        StaticPricingEngine,
        SimpleDemandEngine,
        RandomWalkEngine,
        ThompsonSamplingEngine,
    )
except ImportError as e:
    BasePricingEngine = None  # engines not required for basic usage
    # Report through the configured logger rather than print() so the
    # failure reason carries a timestamp and level like every other message.
    logger.warning("Pricing engines unavailable: %s", e)
"""
@@ -36,27 +37,49 @@ class EngineTrainer:
self.global_step = 0
def train(self, n_episodes: int, seed: int = 42):
    """Train the engine by running ``n_episodes`` complete episodes.

    Every episode resets the environment with ``seed + ep`` and resets the
    engine, then repeatedly asks the engine for prices, steps the
    environment and feeds the transition back via ``engine.update`` until
    the episode ends. One metrics dict per episode (the last step's ``info``
    plus ``episode_reward`` and ``episode``) is appended to
    ``self.episode_metrics``; per-step and per-episode scalars are written
    to ``self.tb_writer`` when one is set.

    Args:
        n_episodes: Number of episodes to run.
        seed: Base seed; episode ``ep`` resets the environment with
            ``seed + ep``.

    Returns:
        ``self``, so calls can be chained.
    """
    for ep in range(n_episodes):
        obs, _ = self.env.reset(seed=seed + ep)
        self.engine.reset()
        # The first action of an episode is computed from the prices
        # reported by the freshly reset environment.
        prev_prices = obs["elasticity"]["price"]
        episode_reward = 0.0
        last_info: Dict[str, float] = {}
        done = False
        while not done:
            action_prices = self.engine.compute_prices(prev_prices, obs)
            obs, reward, terminated, truncated, info = self.env.step(action_prices)
            # End the episode on either flag; looking only at the third
            # value would spin forever on an episode that is cut short.
            # NOTE(review): assumes a Gymnasium-style 5-tuple from step().
            done = bool(terminated or truncated)
            self.engine.update(obs, reward, done, info)
            episode_reward += reward
            prev_prices = obs["elasticity"]["price"]
            last_info = info
            if self.tb_writer:
                self.tb_writer.add_scalar("reward/step", reward, self.global_step)
                if "coi" in info:
                    self.tb_writer.add_scalar("diagnostics/coi", info["coi"], self.global_step)
                if "alpha_hat" in info:
                    self.tb_writer.add_scalar("diagnostics/alpha_hat", info["alpha_hat"], self.global_step)
            self.global_step += 1
        # Copy before annotating so the environment's own info dict is
        # not mutated.
        last_info = dict(last_info)
        last_info.update({"episode_reward": episode_reward, "episode": ep})
        self.episode_metrics.append(last_info)
        if self.tb_writer:
            self.tb_writer.add_scalar("reward/episode", episode_reward, ep)
    return self
def run_episode(self, seed: int = 42) -> Dict:
    """Run a single evaluation episode and return its metrics.

    The environment is reset with ``seed`` and the engine is reset, then
    the engine's prices are applied step by step until the episode ends.
    ``engine.update`` is not called here, so the engine receives no
    feedback from these steps.

    Args:
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict holding the ``info`` entries reported by the environment
        (later steps overwrite earlier ones) plus ``'total_reward'``, the
        sum of all step rewards.
    """
    obs, _ = self.env.reset(seed=seed)
    self.engine.reset()
    total_reward = 0.0
    # The first action is computed from the prices reported at reset.
    prev_prices = obs["elasticity"]["price"]
    ep_metrics = {'total_reward': 0.0}
    done = False
    while not done:
        action_prices = self.engine.compute_prices(prev_prices, obs)
        obs, reward, terminated, truncated, info = self.env.step(action_prices)
        # Stop on either flag so a truncated episode cannot loop forever.
        # NOTE(review): assumes a Gymnasium-style 5-tuple from step().
        done = bool(terminated or truncated)
        total_reward += reward
        ep_metrics.update(info)
        prev_prices = obs["elasticity"]["price"]
    # Written last so it wins over any same-named key coming from info.
    ep_metrics['total_reward'] = total_reward
    return ep_metrics
@@ -106,7 +129,7 @@ if __name__ == "__main__":
logger.error("Engines not available, cannot run training")
exit(1)
base_dir = Path("./runs")
base_dir = Path("./sim/rl/runs")
base_dir.mkdir(exist_ok=True)
engines = {