mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: training and data refactors
This commit is contained in:
@@ -4,16 +4,17 @@ from pathlib import Path
|
||||
from typing import Dict, Type, Optional
|
||||
import pickle
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from environment import PHANTOMEnv, BusinessLogicConstraints
|
||||
from sim.rl.environment import PHANTOMEnv, BusinessLogicConstraints
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
|
||||
from sim.rl.engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
|
||||
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
BasePricingEngine = None # engines not required for basic usage
|
||||
print(e)
|
||||
|
||||
|
||||
"""
|
||||
@@ -36,27 +37,49 @@ class EngineTrainer:
|
||||
self.global_step = 0
|
||||
|
||||
def train(self, n_episodes: int, seed: int = 42):
|
||||
obs, _ = self.env.reset(seed=seed)
|
||||
prices = None
|
||||
for ep in range(n_episodes):
|
||||
prices = self.engine.compute_prices(prices, obs)
|
||||
obs, reward, done, _, info = self.env.step(prices)
|
||||
self.engine.update(obs, reward, done, info)
|
||||
obs, _ = self.env.reset(seed=seed + ep)
|
||||
self.engine.reset()
|
||||
done = False
|
||||
prev_prices = obs["elasticity"]["price"]
|
||||
episode_reward = 0.0
|
||||
last_info: Dict[str, float] = {}
|
||||
while not done:
|
||||
action_prices = self.engine.compute_prices(prev_prices, obs)
|
||||
obs, reward, done, _, info = self.env.step(action_prices)
|
||||
self.engine.update(obs, reward, done, info)
|
||||
episode_reward += reward
|
||||
prev_prices = obs["elasticity"]["price"]
|
||||
last_info = info
|
||||
if self.tb_writer:
|
||||
self.tb_writer.add_scalar("reward/step", reward, self.global_step)
|
||||
if "coi" in info:
|
||||
self.tb_writer.add_scalar("diagnostics/coi", info["coi"], self.global_step)
|
||||
if "alpha_hat" in info:
|
||||
self.tb_writer.add_scalar("diagnostics/alpha_hat", info["alpha_hat"], self.global_step)
|
||||
self.global_step += 1
|
||||
last_info = dict(last_info)
|
||||
last_info.update({"episode_reward": episode_reward, "episode": ep})
|
||||
self.episode_metrics.append(last_info)
|
||||
if self.tb_writer:
|
||||
self.tb_writer.add_scalar("reward/episode", episode_reward, ep)
|
||||
return self
|
||||
|
||||
def run_episode(self, seed: int = 42) -> Dict:
|
||||
"""run single evaluation episode and return metrics"""
|
||||
obs, _ = self.env.reset(seed=seed)
|
||||
self.engine.reset()
|
||||
total_reward, prices = 0.0, None
|
||||
total_reward = 0.0
|
||||
prev_prices = obs["elasticity"]["price"]
|
||||
ep_metrics = {'total_reward': 0.0}
|
||||
done = False
|
||||
while not done:
|
||||
prices = self.engine.compute_prices(prices, obs) if prices is not None else obs["elasticity"]["price"]
|
||||
obs, reward, done, _, info = self.env.step(prices)
|
||||
action_prices = self.engine.compute_prices(prev_prices, obs)
|
||||
obs, reward, done, _, info = self.env.step(action_prices)
|
||||
total_reward += reward
|
||||
for k, v in info.items():
|
||||
ep_metrics[k] = v
|
||||
prev_prices = obs["elasticity"]["price"]
|
||||
ep_metrics['total_reward'] = total_reward
|
||||
return ep_metrics
|
||||
|
||||
@@ -106,7 +129,7 @@ if __name__ == "__main__":
|
||||
logger.error("Engines not available, cannot run training")
|
||||
exit(1)
|
||||
|
||||
base_dir = Path("./runs")
|
||||
base_dir = Path("./sim/rl/runs")
|
||||
base_dir.mkdir(exist_ok=True)
|
||||
|
||||
engines = {
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import os, requests, py7zr
|
||||
import os
|
||||
import requests
|
||||
try:
|
||||
import py7zr # type: ignore
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
py7zr = None
|
||||
import pandas as pd
|
||||
from typing import Generator
|
||||
try:
|
||||
@@ -22,12 +27,16 @@ class YooChooseLoader(Loader):
|
||||
self.entries = list(self.data.keys())
|
||||
|
||||
def _setup(self):
|
||||
if py7zr is None:
|
||||
raise RuntimeError("py7zr is required to unpack YooChoose dataset. Install py7zr first.")
|
||||
os.makedirs(self.root, exist_ok=True)
|
||||
zip_path = f"{self.root}/temp.7z"
|
||||
with requests.get(self.URL, stream=True) as r:
|
||||
with open(zip_path, 'wb') as f:
|
||||
for chunk in r.iter_content(8192): f.write(chunk)
|
||||
with py7zr.SevenZipFile(zip_path, 'r') as z: z.extractall(self.root)
|
||||
for chunk in r.iter_content(8192):
|
||||
f.write(chunk)
|
||||
with py7zr.SevenZipFile(zip_path, 'r') as z:
|
||||
z.extractall(self.root)
|
||||
os.remove(zip_path)
|
||||
|
||||
def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel:
|
||||
|
||||
Reference in New Issue
Block a user