chore: training and data refactors

This commit is contained in:
2026-01-22 11:40:12 +01:00
parent dee6f573e3
commit c15bb1882e
2 changed files with 47 additions and 15 deletions

View File

@@ -4,16 +4,17 @@ from pathlib import Path
from typing import Dict, Type, Optional
import pickle
from torch.utils.tensorboard import SummaryWriter
from environment import PHANTOMEnv, BusinessLogicConstraints
from sim.rl.environment import PHANTOMEnv, BusinessLogicConstraints
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
try:
from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
from sim.rl.engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
except ImportError:
except ImportError as e:
BasePricingEngine = None # engines not required for basic usage
print(e)
"""
@@ -36,27 +37,49 @@ class EngineTrainer:
self.global_step = 0
def train(self, n_episodes: int, seed: int = 42):
obs, _ = self.env.reset(seed=seed)
prices = None
for ep in range(n_episodes):
prices = self.engine.compute_prices(prices, obs)
obs, reward, done, _, info = self.env.step(prices)
obs, _ = self.env.reset(seed=seed + ep)
self.engine.reset()
done = False
prev_prices = obs["elasticity"]["price"]
episode_reward = 0.0
last_info: Dict[str, float] = {}
while not done:
action_prices = self.engine.compute_prices(prev_prices, obs)
obs, reward, done, _, info = self.env.step(action_prices)
self.engine.update(obs, reward, done, info)
episode_reward += reward
prev_prices = obs["elasticity"]["price"]
last_info = info
if self.tb_writer:
self.tb_writer.add_scalar("reward/step", reward, self.global_step)
if "coi" in info:
self.tb_writer.add_scalar("diagnostics/coi", info["coi"], self.global_step)
if "alpha_hat" in info:
self.tb_writer.add_scalar("diagnostics/alpha_hat", info["alpha_hat"], self.global_step)
self.global_step += 1
last_info = dict(last_info)
last_info.update({"episode_reward": episode_reward, "episode": ep})
self.episode_metrics.append(last_info)
if self.tb_writer:
self.tb_writer.add_scalar("reward/episode", episode_reward, ep)
return self
def run_episode(self, seed: int = 42) -> Dict:
"""run single evaluation episode and return metrics"""
obs, _ = self.env.reset(seed=seed)
self.engine.reset()
total_reward, prices = 0.0, None
total_reward = 0.0
prev_prices = obs["elasticity"]["price"]
ep_metrics = {'total_reward': 0.0}
done = False
while not done:
prices = self.engine.compute_prices(prices, obs) if prices is not None else obs["elasticity"]["price"]
obs, reward, done, _, info = self.env.step(prices)
action_prices = self.engine.compute_prices(prev_prices, obs)
obs, reward, done, _, info = self.env.step(action_prices)
total_reward += reward
for k, v in info.items():
ep_metrics[k] = v
prev_prices = obs["elasticity"]["price"]
ep_metrics['total_reward'] = total_reward
return ep_metrics
@@ -106,7 +129,7 @@ if __name__ == "__main__":
logger.error("Engines not available, cannot run training")
exit(1)
base_dir = Path("./runs")
base_dir = Path("./sim/rl/runs")
base_dir.mkdir(exist_ok=True)
engines = {

View File

@@ -1,4 +1,9 @@
import os, requests, py7zr
import os
import requests
try:
import py7zr # type: ignore
except ImportError: # pragma: no cover - optional dependency
py7zr = None
import pandas as pd
from typing import Generator
try:
@@ -22,12 +27,16 @@ class YooChooseLoader(Loader):
self.entries = list(self.data.keys())
def _setup(self):
if py7zr is None:
raise RuntimeError("py7zr is required to unpack YooChoose dataset. Install py7zr first.")
os.makedirs(self.root, exist_ok=True)
zip_path = f"{self.root}/temp.7z"
with requests.get(self.URL, stream=True) as r:
with open(zip_path, 'wb') as f:
for chunk in r.iter_content(8192): f.write(chunk)
with py7zr.SevenZipFile(zip_path, 'r') as z: z.extractall(self.root)
for chunk in r.iter_content(8192):
f.write(chunk)
with py7zr.SevenZipFile(zip_path, 'r') as z:
z.extractall(self.root)
os.remove(zip_path)
def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel: