From c15bb1882e2e7ab34c978c3f470beab56f9ddab1 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Thu, 22 Jan 2026 11:40:12 +0100 Subject: [PATCH] chore: training and data refactors --- sim/rl/train.py | 47 ++++++++++++++++++++++++++++---------- sim/strong_learner/data.py | 15 +++++++++--- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/sim/rl/train.py b/sim/rl/train.py index 01e6809..1d21f24 100644 --- a/sim/rl/train.py +++ b/sim/rl/train.py @@ -4,16 +4,17 @@ from pathlib import Path from typing import Dict, Type, Optional import pickle from torch.utils.tensorboard import SummaryWriter -from environment import PHANTOMEnv, BusinessLogicConstraints +from sim.rl.environment import PHANTOMEnv, BusinessLogicConstraints logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') logger = logging.getLogger(__name__) try: - from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine, + from sim.rl.engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine, SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine) -except ImportError: +except ImportError as e: BasePricingEngine = None # engines not required for basic usage + print(e) """ @@ -36,27 +37,49 @@ class EngineTrainer: self.global_step = 0 def train(self, n_episodes: int, seed: int = 42): - obs, _ = self.env.reset(seed=seed) - prices = None for ep in range(n_episodes): - prices = self.engine.compute_prices(prices, obs) - obs, reward, done, _, info = self.env.step(prices) - self.engine.update(obs, reward, done, info) + obs, _ = self.env.reset(seed=seed + ep) + self.engine.reset() + done = False + prev_prices = obs["elasticity"]["price"] + episode_reward = 0.0 + last_info: Dict[str, float] = {} + while not done: + action_prices = self.engine.compute_prices(prev_prices, obs) + obs, reward, done, _, info = self.env.step(action_prices) + self.engine.update(obs, reward, done, info) + episode_reward += reward + prev_prices = obs["elasticity"]["price"] + last_info = info + if self.tb_writer: + self.tb_writer.add_scalar("reward/step", reward, self.global_step) + if "coi" in info: + self.tb_writer.add_scalar("diagnostics/coi", info["coi"], self.global_step) + if "alpha_hat" in info: + self.tb_writer.add_scalar("diagnostics/alpha_hat", info["alpha_hat"], self.global_step) + self.global_step += 1 + last_info = dict(last_info) + last_info.update({"episode_reward": episode_reward, "episode": ep}) + self.episode_metrics.append(last_info) + if self.tb_writer: + self.tb_writer.add_scalar("reward/episode", episode_reward, ep) return self def run_episode(self, seed: int = 42) -> Dict: """run single evaluation episode and return metrics""" obs, _ = self.env.reset(seed=seed) self.engine.reset() - total_reward, prices = 0.0, None + total_reward = 0.0 + prev_prices = obs["elasticity"]["price"] ep_metrics = {'total_reward': 0.0} done = False while not done: - prices = self.engine.compute_prices(prices, obs) if prices is not None else obs["elasticity"]["price"] - obs, reward, done, _, info = self.env.step(prices) + action_prices = self.engine.compute_prices(prev_prices, obs) + obs, reward, done, _, info = self.env.step(action_prices) total_reward += reward for k, v in info.items(): ep_metrics[k] = v + prev_prices = obs["elasticity"]["price"] ep_metrics['total_reward'] = total_reward return ep_metrics @@ -106,7 +129,7 @@ if __name__ == "__main__": logger.error("Engines not available, cannot run training") exit(1) - base_dir = Path("./runs") + base_dir = Path("./sim/rl/runs") base_dir.mkdir(exist_ok=True) engines = { diff --git a/sim/strong_learner/data.py b/sim/strong_learner/data.py index 80129aa..e22c7db 100644 --- a/sim/strong_learner/data.py +++ b/sim/strong_learner/data.py @@ -1,4 +1,9 @@ -import os, requests, py7zr +import os +import requests +try: + import py7zr # type: ignore +except ImportError: # pragma: no cover - optional dependency + py7zr = None import pandas as pd from typing import Generator try: @@ -22,12 +27,16 @@ class YooChooseLoader(Loader): self.entries = list(self.data.keys()) def _setup(self): + if py7zr is None: + raise RuntimeError("py7zr is required to unpack YooChoose dataset. Install py7zr first.") os.makedirs(self.root, exist_ok=True) zip_path = f"{self.root}/temp.7z" with requests.get(self.URL, stream=True) as r: with open(zip_path, 'wb') as f: - for chunk in r.iter_content(8192): f.write(chunk) - with py7zr.SevenZipFile(zip_path, 'r') as z: z.extractall(self.root) + for chunk in r.iter_content(8192): + f.write(chunk) + with py7zr.SevenZipFile(zip_path, 'r') as z: + z.extractall(self.root) os.remove(zip_path) def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel: