mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: training and data refactors
This commit is contained in:
@@ -4,16 +4,17 @@ from pathlib import Path
|
|||||||
from typing import Dict, Type, Optional
|
from typing import Dict, Type, Optional
|
||||||
import pickle
|
import pickle
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from environment import PHANTOMEnv, BusinessLogicConstraints
|
from sim.rl.environment import PHANTOMEnv, BusinessLogicConstraints
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
|
from sim.rl.engine import (BasePricingEngine, WildPricingEngine, StaticPricingEngine,
|
||||||
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
|
SimpleDemandEngine, RandomWalkEngine, ThompsonSamplingEngine)
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
BasePricingEngine = None # engines not required for basic usage
|
BasePricingEngine = None # engines not required for basic usage
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -36,27 +37,49 @@ class EngineTrainer:
|
|||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
|
|
||||||
def train(self, n_episodes: int, seed: int = 42):
|
def train(self, n_episodes: int, seed: int = 42):
|
||||||
obs, _ = self.env.reset(seed=seed)
|
|
||||||
prices = None
|
|
||||||
for ep in range(n_episodes):
|
for ep in range(n_episodes):
|
||||||
prices = self.engine.compute_prices(prices, obs)
|
obs, _ = self.env.reset(seed=seed + ep)
|
||||||
obs, reward, done, _, info = self.env.step(prices)
|
self.engine.reset()
|
||||||
self.engine.update(obs, reward, done, info)
|
done = False
|
||||||
|
prev_prices = obs["elasticity"]["price"]
|
||||||
|
episode_reward = 0.0
|
||||||
|
last_info: Dict[str, float] = {}
|
||||||
|
while not done:
|
||||||
|
action_prices = self.engine.compute_prices(prev_prices, obs)
|
||||||
|
obs, reward, done, _, info = self.env.step(action_prices)
|
||||||
|
self.engine.update(obs, reward, done, info)
|
||||||
|
episode_reward += reward
|
||||||
|
prev_prices = obs["elasticity"]["price"]
|
||||||
|
last_info = info
|
||||||
|
if self.tb_writer:
|
||||||
|
self.tb_writer.add_scalar("reward/step", reward, self.global_step)
|
||||||
|
if "coi" in info:
|
||||||
|
self.tb_writer.add_scalar("diagnostics/coi", info["coi"], self.global_step)
|
||||||
|
if "alpha_hat" in info:
|
||||||
|
self.tb_writer.add_scalar("diagnostics/alpha_hat", info["alpha_hat"], self.global_step)
|
||||||
|
self.global_step += 1
|
||||||
|
last_info = dict(last_info)
|
||||||
|
last_info.update({"episode_reward": episode_reward, "episode": ep})
|
||||||
|
self.episode_metrics.append(last_info)
|
||||||
|
if self.tb_writer:
|
||||||
|
self.tb_writer.add_scalar("reward/episode", episode_reward, ep)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def run_episode(self, seed: int = 42) -> Dict:
|
def run_episode(self, seed: int = 42) -> Dict:
|
||||||
"""run single evaluation episode and return metrics"""
|
"""run single evaluation episode and return metrics"""
|
||||||
obs, _ = self.env.reset(seed=seed)
|
obs, _ = self.env.reset(seed=seed)
|
||||||
self.engine.reset()
|
self.engine.reset()
|
||||||
total_reward, prices = 0.0, None
|
total_reward = 0.0
|
||||||
|
prev_prices = obs["elasticity"]["price"]
|
||||||
ep_metrics = {'total_reward': 0.0}
|
ep_metrics = {'total_reward': 0.0}
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
prices = self.engine.compute_prices(prices, obs) if prices is not None else obs["elasticity"]["price"]
|
action_prices = self.engine.compute_prices(prev_prices, obs)
|
||||||
obs, reward, done, _, info = self.env.step(prices)
|
obs, reward, done, _, info = self.env.step(action_prices)
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
for k, v in info.items():
|
for k, v in info.items():
|
||||||
ep_metrics[k] = v
|
ep_metrics[k] = v
|
||||||
|
prev_prices = obs["elasticity"]["price"]
|
||||||
ep_metrics['total_reward'] = total_reward
|
ep_metrics['total_reward'] = total_reward
|
||||||
return ep_metrics
|
return ep_metrics
|
||||||
|
|
||||||
@@ -106,7 +129,7 @@ if __name__ == "__main__":
|
|||||||
logger.error("Engines not available, cannot run training")
|
logger.error("Engines not available, cannot run training")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
base_dir = Path("./runs")
|
base_dir = Path("./sim/rl/runs")
|
||||||
base_dir.mkdir(exist_ok=True)
|
base_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
engines = {
|
engines = {
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
import os, requests, py7zr
|
import os
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
import py7zr # type: ignore
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
py7zr = None
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
try:
|
try:
|
||||||
@@ -22,12 +27,16 @@ class YooChooseLoader(Loader):
|
|||||||
self.entries = list(self.data.keys())
|
self.entries = list(self.data.keys())
|
||||||
|
|
||||||
def _setup(self):
|
def _setup(self):
|
||||||
|
if py7zr is None:
|
||||||
|
raise RuntimeError("py7zr is required to unpack YooChoose dataset. Install py7zr first.")
|
||||||
os.makedirs(self.root, exist_ok=True)
|
os.makedirs(self.root, exist_ok=True)
|
||||||
zip_path = f"{self.root}/temp.7z"
|
zip_path = f"{self.root}/temp.7z"
|
||||||
with requests.get(self.URL, stream=True) as r:
|
with requests.get(self.URL, stream=True) as r:
|
||||||
with open(zip_path, 'wb') as f:
|
with open(zip_path, 'wb') as f:
|
||||||
for chunk in r.iter_content(8192): f.write(chunk)
|
for chunk in r.iter_content(8192):
|
||||||
with py7zr.SevenZipFile(zip_path, 'r') as z: z.extractall(self.root)
|
f.write(chunk)
|
||||||
|
with py7zr.SevenZipFile(zip_path, 'r') as z:
|
||||||
|
z.extractall(self.root)
|
||||||
os.remove(zip_path)
|
os.remove(zip_path)
|
||||||
|
|
||||||
def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel:
|
def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel:
|
||||||
|
|||||||
Reference in New Issue
Block a user