chore: training and data refactors

This commit is contained in:
2026-01-22 11:40:12 +01:00
parent dee6f573e3
commit c15bb1882e
2 changed files with 47 additions and 15 deletions

View File

@@ -4,16 +4,17 @@ from pathlib import Path
from typing import Dict, Type, Optional from typing import Dict, Type, Optional
import pickle import pickle
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from environment import PHANTOMEnv, BusinessLogicConstraints from sim.rl.environment import PHANTOMEnv, BusinessLogicConstraints
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine, from sim.rl.engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine) SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
except ImportError: except ImportError as e:
BasePricingEngine = None # engines not required for basic usage BasePricingEngine = None # engines not required for basic usage
print(e)
""" """
@@ -36,27 +37,49 @@ class EngineTrainer:
self.global_step = 0 self.global_step = 0
def train(self, n_episodes: int, seed: int = 42): def train(self, n_episodes: int, seed: int = 42):
obs, _ = self.env.reset(seed=seed)
prices = None
for ep in range(n_episodes): for ep in range(n_episodes):
prices = self.engine.compute_prices(prices, obs) obs, _ = self.env.reset(seed=seed + ep)
obs, reward, done, _, info = self.env.step(prices) self.engine.reset()
done = False
prev_prices = obs["elasticity"]["price"]
episode_reward = 0.0
last_info: Dict[str, float] = {}
while not done:
action_prices = self.engine.compute_prices(prev_prices, obs)
obs, reward, done, _, info = self.env.step(action_prices)
self.engine.update(obs, reward, done, info) self.engine.update(obs, reward, done, info)
episode_reward += reward
prev_prices = obs["elasticity"]["price"]
last_info = info
if self.tb_writer:
self.tb_writer.add_scalar("reward/step", reward, self.global_step)
if "coi" in info:
self.tb_writer.add_scalar("diagnostics/coi", info["coi"], self.global_step)
if "alpha_hat" in info:
self.tb_writer.add_scalar("diagnostics/alpha_hat", info["alpha_hat"], self.global_step)
self.global_step += 1
last_info = dict(last_info)
last_info.update({"episode_reward": episode_reward, "episode": ep})
self.episode_metrics.append(last_info)
if self.tb_writer:
self.tb_writer.add_scalar("reward/episode", episode_reward, ep)
return self return self
def run_episode(self, seed: int = 42) -> Dict: def run_episode(self, seed: int = 42) -> Dict:
"""run single evaluation episode and return metrics""" """run single evaluation episode and return metrics"""
obs, _ = self.env.reset(seed=seed) obs, _ = self.env.reset(seed=seed)
self.engine.reset() self.engine.reset()
total_reward, prices = 0.0, None total_reward = 0.0
prev_prices = obs["elasticity"]["price"]
ep_metrics = {'total_reward': 0.0} ep_metrics = {'total_reward': 0.0}
done = False done = False
while not done: while not done:
prices = self.engine.compute_prices(prices, obs) if prices is not None else obs["elasticity"]["price"] action_prices = self.engine.compute_prices(prev_prices, obs)
obs, reward, done, _, info = self.env.step(prices) obs, reward, done, _, info = self.env.step(action_prices)
total_reward += reward total_reward += reward
for k, v in info.items(): for k, v in info.items():
ep_metrics[k] = v ep_metrics[k] = v
prev_prices = obs["elasticity"]["price"]
ep_metrics['total_reward'] = total_reward ep_metrics['total_reward'] = total_reward
return ep_metrics return ep_metrics
@@ -106,7 +129,7 @@ if __name__ == "__main__":
logger.error("Engines not available, cannot run training") logger.error("Engines not available, cannot run training")
exit(1) exit(1)
base_dir = Path("./runs") base_dir = Path("./sim/rl/runs")
base_dir.mkdir(exist_ok=True) base_dir.mkdir(exist_ok=True)
engines = { engines = {

View File

@@ -1,4 +1,9 @@
import os, requests, py7zr import os
import requests
try:
import py7zr # type: ignore
except ImportError: # pragma: no cover - optional dependency
py7zr = None
import pandas as pd import pandas as pd
from typing import Generator from typing import Generator
try: try:
@@ -22,12 +27,16 @@ class YooChooseLoader(Loader):
self.entries = list(self.data.keys()) self.entries = list(self.data.keys())
def _setup(self): def _setup(self):
if py7zr is None:
raise RuntimeError("py7zr is required to unpack YooChoose dataset. Install py7zr first.")
os.makedirs(self.root, exist_ok=True) os.makedirs(self.root, exist_ok=True)
zip_path = f"{self.root}/temp.7z" zip_path = f"{self.root}/temp.7z"
with requests.get(self.URL, stream=True) as r: with requests.get(self.URL, stream=True) as r:
with open(zip_path, 'wb') as f: with open(zip_path, 'wb') as f:
for chunk in r.iter_content(8192): f.write(chunk) for chunk in r.iter_content(8192):
with py7zr.SevenZipFile(zip_path, 'r') as z: z.extractall(self.root) f.write(chunk)
with py7zr.SevenZipFile(zip_path, 'r') as z:
z.extractall(self.root)
os.remove(zip_path) os.remove(zip_path)
def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel: def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel: