mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactoring training spc setup and benchmarking
This commit is contained in:
@@ -1,38 +1,39 @@
|
||||
from .demand import estimate_demand, estimate_weighted_demand, generate_demand_for_actor
|
||||
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
|
||||
from .render import DashboardRenderer, style_axis
|
||||
from .wrappers import EconomicMetricsWrapper
|
||||
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
|
||||
from .providers import (
|
||||
ProviderBenchmark,
|
||||
ProviderResult,
|
||||
BenchmarkConfig,
|
||||
RandomBaseline,
|
||||
SurgeBaseline,
|
||||
)
|
||||
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
|
||||
from .discrete import EventQTable
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"estimate_demand",
|
||||
"estimate_weighted_demand",
|
||||
"generate_demand_for_actor",
|
||||
"sample_behavior",
|
||||
"get_transition_models",
|
||||
"trajectory_to_events",
|
||||
"DashboardRenderer",
|
||||
"style_axis",
|
||||
"EconomicMetricsWrapper",
|
||||
"MetricsCallback",
|
||||
"EvalMetricsCallback",
|
||||
"CheckpointArtifactCallback",
|
||||
"ProviderBenchmark",
|
||||
"ProviderResult",
|
||||
"BenchmarkConfig",
|
||||
"RandomBaseline",
|
||||
"SurgeBaseline",
|
||||
"compute_uplift_coi",
|
||||
"extract_purchases",
|
||||
"compute_agent_probability",
|
||||
"EventQTable",
|
||||
]
|
||||
from importlib import import_module
|
||||
|
||||
_EXPORTS: dict[str, tuple[str, str]] = {
|
||||
"estimate_demand": (".demand", "estimate_demand"),
|
||||
"estimate_weighted_demand": (".demand", "estimate_weighted_demand"),
|
||||
"generate_demand_for_actor": (".demand", "generate_demand_for_actor"),
|
||||
"sample_behavior": (".behavior", "sample_behavior"),
|
||||
"get_transition_models": (".behavior", "get_transition_models"),
|
||||
"trajectory_to_events": (".behavior", "trajectory_to_events"),
|
||||
"DashboardRenderer": (".render", "DashboardRenderer"),
|
||||
"style_axis": (".render", "style_axis"),
|
||||
"EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
|
||||
"MetricsCallback": (".callbacks", "MetricsCallback"),
|
||||
"EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
|
||||
"CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"),
|
||||
"ProviderBenchmark": (".providers", "ProviderBenchmark"),
|
||||
"ProviderResult": (".providers", "ProviderResult"),
|
||||
"BenchmarkConfig": (".providers", "BenchmarkConfig"),
|
||||
"RandomBaseline": (".providers", "RandomBaseline"),
|
||||
"SurgeBaseline": (".providers", "SurgeBaseline"),
|
||||
"compute_uplift_coi": (".coi", "compute_uplift_coi"),
|
||||
"extract_purchases": (".coi", "extract_purchases"),
|
||||
"compute_agent_probability": (".coi", "compute_agent_probability"),
|
||||
"EventQTable": (".discrete", "EventQTable"),
|
||||
}
|
||||
|
||||
__all__ = sorted(_EXPORTS)
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name not in _EXPORTS:
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
module_name, attr_name = _EXPORTS[name]
|
||||
module = import_module(module_name, package=__name__)
|
||||
value = getattr(module, attr_name)
|
||||
globals()[name] = value
|
||||
return value
|
||||
|
||||
@@ -38,19 +38,19 @@ class MetricsCallback(BaseCallback):
|
||||
t = self.num_timesteps
|
||||
|
||||
payload = {
|
||||
"economics/revenue": econ["revenue"],
|
||||
"economics/margin": econ["margin"],
|
||||
"coi/level": econ["coi_level"],
|
||||
"economics/regret": econ["regret"],
|
||||
"train/revenue_step": econ["revenue"],
|
||||
"train/margin_step": econ["margin"],
|
||||
"train/coi_level": econ["coi_level"],
|
||||
"train/regret_step": econ["regret"],
|
||||
}
|
||||
if "coi_mix" in econ:
|
||||
payload["coi/mix"] = econ["coi_mix"]
|
||||
payload["train/coi_mix"] = econ["coi_mix"]
|
||||
if "coi_base" in econ:
|
||||
payload["coi/base"] = econ["coi_base"]
|
||||
payload["train/coi_base"] = econ["coi_base"]
|
||||
if "coi_leakage" in econ:
|
||||
payload["coi/leakage"] = econ["coi_leakage"]
|
||||
payload["train/coi_leakage"] = econ["coi_leakage"]
|
||||
if "coi_penalty" in econ:
|
||||
payload["coi/penalty"] = econ["coi_penalty"]
|
||||
payload["train/coi_penalty"] = econ["coi_penalty"]
|
||||
wandb.log(payload, step=t)
|
||||
|
||||
self._episode_revenues.append(econ["revenue"])
|
||||
@@ -76,8 +76,8 @@ class MetricsCallback(BaseCallback):
|
||||
return
|
||||
wandb.log(
|
||||
{
|
||||
"episode/mean_revenue": np.mean(self._episode_revenues),
|
||||
"episode/total_revenue": np.sum(self._episode_revenues),
|
||||
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
|
||||
"train/revenue_rollout_total": np.sum(self._episode_revenues),
|
||||
},
|
||||
step=self.num_timesteps,
|
||||
)
|
||||
@@ -164,8 +164,8 @@ class EvalMetricsCallback(EvalCallback):
|
||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||
wandb.log(
|
||||
{
|
||||
"eval/mean_reward": self.last_mean_reward,
|
||||
"eval/mean_revenue": np.mean(self._eval_revenues)
|
||||
"eval/reward_mean": self.last_mean_reward,
|
||||
"eval/revenue_mean": np.mean(self._eval_revenues)
|
||||
if self._eval_revenues
|
||||
else 0,
|
||||
},
|
||||
|
||||
101
engine/lib/tiers.py
Normal file
101
engine/lib/tiers.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PolicyLike(Protocol):
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True): ...
|
||||
|
||||
|
||||
class StaticPolicy:
|
||||
def __init__(self, n_actions: int):
|
||||
self._action = int(max(0, n_actions // 2))
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
return self._action, None
|
||||
|
||||
|
||||
class SurgePolicy:
|
||||
def __init__(
|
||||
self,
|
||||
n_actions: int,
|
||||
n_products: int,
|
||||
high_threshold: float = 60.0,
|
||||
low_threshold: float = 30.0,
|
||||
):
|
||||
self.n_actions = int(n_actions)
|
||||
self.n_products = int(n_products)
|
||||
self.mid = self.n_actions // 2
|
||||
self.high_t = float(high_threshold)
|
||||
self.low_t = float(low_threshold)
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
obs_arr = np.asarray(obs, dtype=np.float32)
|
||||
demand = obs_arr[: self.n_products]
|
||||
demand_mean = float(np.mean(demand)) if demand.size > 0 else 0.0
|
||||
if demand_mean >= self.high_t:
|
||||
return min(self.mid + 2, self.n_actions - 1), None
|
||||
if demand_mean <= self.low_t:
|
||||
return max(self.mid - 2, 0), None
|
||||
return self.mid, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearElasticityPolicy:
|
||||
n_actions: int
|
||||
n_products: int
|
||||
price_low: float
|
||||
price_high: float
|
||||
|
||||
def __post_init__(self):
|
||||
self.n_actions = int(self.n_actions)
|
||||
self.n_products = int(self.n_products)
|
||||
self.price_low = float(self.price_low)
|
||||
self.price_high = float(self.price_high)
|
||||
self._target_price = 0.5 * (self.price_low + self.price_high)
|
||||
self._action_scales = np.linspace(0.8, 1.2, self.n_actions)
|
||||
|
||||
def fit(self, env, warmup_steps: int = 800, seed: int = 42):
|
||||
rng = np.random.default_rng(int(seed))
|
||||
obs, _ = env.reset(seed=int(seed))
|
||||
prices: list[float] = []
|
||||
demands: list[float] = []
|
||||
|
||||
for _ in range(int(max(10, warmup_steps))):
|
||||
action = int(rng.integers(0, self.n_actions))
|
||||
obs, _, term, trunc, info = env.step(action)
|
||||
done = bool(term or trunc)
|
||||
|
||||
p = np.asarray(info.get("prices", []), dtype=np.float32)
|
||||
d = np.asarray(info.get("demand", []), dtype=np.float32)
|
||||
if p.size > 0 and d.size > 0:
|
||||
prices.append(float(np.mean(p)))
|
||||
demands.append(float(np.mean(d)))
|
||||
|
||||
if done:
|
||||
obs, _ = env.reset()
|
||||
|
||||
if len(prices) < 8:
|
||||
self._target_price = 0.5 * (self.price_low + self.price_high)
|
||||
return self
|
||||
|
||||
slope, intercept = np.polyfit(np.asarray(prices), np.asarray(demands), 1)
|
||||
if slope < -1e-6:
|
||||
p_star = -intercept / (2.0 * slope)
|
||||
self._target_price = float(np.clip(p_star, self.price_low, self.price_high))
|
||||
else:
|
||||
self._target_price = 0.5 * (self.price_low + self.price_high)
|
||||
return self
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
obs_arr = np.asarray(obs, dtype=np.float32)
|
||||
cur_prices = obs_arr[self.n_products : 2 * self.n_products]
|
||||
cur_mean = (
|
||||
float(np.mean(cur_prices)) if cur_prices.size > 0 else self._target_price
|
||||
)
|
||||
scale = self._target_price / max(cur_mean, 1e-6)
|
||||
action = int(np.argmin(np.abs(self._action_scales - scale)))
|
||||
return int(np.clip(action, 0, self.n_actions - 1)), None
|
||||
Reference in New Issue
Block a user