mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating engine training for training
This commit is contained in:
@@ -15,15 +15,19 @@ class MetricsCallback(BaseCallback):
|
||||
self,
|
||||
log_histograms: bool = False,
|
||||
log_freq: int = 100,
|
||||
hist_freq: int = 500,
|
||||
step_offset: int = 0,
|
||||
verbose: int = 0,
|
||||
):
|
||||
super().__init__(verbose)
|
||||
self.log_histograms = log_histograms
|
||||
self.log_freq = max(1, int(log_freq))
|
||||
self.hist_freq = max(1, int(hist_freq))
|
||||
self.step_offset = max(0, int(step_offset))
|
||||
self._wandb = get_wandb_module()
|
||||
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
|
||||
self._price_samples: list[float] = []
|
||||
self._demand_samples: list[float] = []
|
||||
self._window_sums = {
|
||||
"train/revenue_mean": 0.0,
|
||||
"train/margin_mean": 0.0,
|
||||
@@ -74,35 +78,100 @@ class MetricsCallback(BaseCallback):
|
||||
)
|
||||
self._window_count += 1
|
||||
|
||||
def _flush(self, step: int) -> None:
|
||||
if self._window_count <= 0:
|
||||
def _accumulate_histograms(self, info: dict[str, Any]) -> None:
|
||||
if not self.log_histograms:
|
||||
return
|
||||
denom = float(self._window_count)
|
||||
payload = {
|
||||
key: (value / denom)
|
||||
for key, value in self._window_sums.items()
|
||||
if value != 0.0
|
||||
or key
|
||||
in {
|
||||
"train/revenue_mean",
|
||||
"train/margin_mean",
|
||||
"train/coi_level_mean",
|
||||
"train/regret_mean",
|
||||
|
||||
for key in ("effective_prices", "prices"):
|
||||
if key not in info:
|
||||
continue
|
||||
try:
|
||||
values = np.asarray(info.get(key), dtype=float).reshape(-1)
|
||||
except Exception:
|
||||
continue
|
||||
if values.size <= 0:
|
||||
continue
|
||||
finite_values = values[np.isfinite(values)]
|
||||
if finite_values.size > 0:
|
||||
self._price_samples.extend(finite_values.tolist())
|
||||
break
|
||||
|
||||
if "demand" in info:
|
||||
try:
|
||||
demand_values = np.asarray(info.get("demand"), dtype=float).reshape(-1)
|
||||
except Exception:
|
||||
demand_values = np.array([], dtype=float)
|
||||
if demand_values.size > 0:
|
||||
finite_demand = demand_values[np.isfinite(demand_values)]
|
||||
if finite_demand.size > 0:
|
||||
self._demand_samples.extend(finite_demand.tolist())
|
||||
|
||||
def _flush_histograms(self, step: int, force: bool = False) -> None:
|
||||
if not self.log_histograms:
|
||||
return
|
||||
if not force and step % self.hist_freq != 0:
|
||||
return
|
||||
if not self._price_samples and not self._demand_samples:
|
||||
return
|
||||
if self._wandb is None:
|
||||
self._price_samples.clear()
|
||||
self._demand_samples.clear()
|
||||
return
|
||||
|
||||
payload: dict[str, Any] = {}
|
||||
if self._price_samples:
|
||||
payload["train/price_dist"] = self._wandb.Histogram(
|
||||
np.asarray(self._price_samples, dtype=np.float32)
|
||||
)
|
||||
if self._demand_samples:
|
||||
payload["train/demand_dist"] = self._wandb.Histogram(
|
||||
np.asarray(self._demand_samples, dtype=np.float32)
|
||||
)
|
||||
|
||||
if payload and self._wandb_live:
|
||||
try:
|
||||
self._wandb.log(payload, step=self.step_offset + int(step))
|
||||
except Exception:
|
||||
self._wandb_live = False
|
||||
|
||||
self._price_samples.clear()
|
||||
self._demand_samples.clear()
|
||||
|
||||
def _flush(self, step: int, *, force_hist: bool = False) -> None:
|
||||
if self._window_count > 0:
|
||||
denom = float(self._window_count)
|
||||
payload = {
|
||||
key: (value / denom)
|
||||
for key, value in self._window_sums.items()
|
||||
if value != 0.0
|
||||
or key
|
||||
in {
|
||||
"train/revenue_mean",
|
||||
"train/margin_mean",
|
||||
"train/coi_level_mean",
|
||||
"train/regret_mean",
|
||||
}
|
||||
}
|
||||
}
|
||||
payload["train/global_step"] = int(step)
|
||||
if self._wandb_live:
|
||||
self._wandb.log(dict(payload), step=self.step_offset + int(step))
|
||||
else:
|
||||
self.events.append(payload)
|
||||
for key in self._window_sums:
|
||||
self._window_sums[key] = 0.0
|
||||
self._window_count = 0
|
||||
payload["train/global_step"] = int(step)
|
||||
if self._wandb_live:
|
||||
try:
|
||||
self._wandb.log(dict(payload), step=self.step_offset + int(step))
|
||||
except Exception:
|
||||
self._wandb_live = False
|
||||
self.events.append(payload)
|
||||
else:
|
||||
self.events.append(payload)
|
||||
for key in self._window_sums:
|
||||
self._window_sums[key] = 0.0
|
||||
self._window_count = 0
|
||||
|
||||
self._flush_histograms(step=step, force=force_hist)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
for info in self.locals.get("infos", []):
|
||||
if isinstance(info, dict):
|
||||
self._accumulate(info)
|
||||
self._accumulate_histograms(info)
|
||||
|
||||
if self.num_timesteps % self.log_freq == 0:
|
||||
self._flush(step=self.num_timesteps)
|
||||
@@ -110,39 +179,81 @@ class MetricsCallback(BaseCallback):
|
||||
return True
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
self._flush(step=self.num_timesteps)
|
||||
self._flush(step=self.num_timesteps, force_hist=True)
|
||||
|
||||
|
||||
class EvalMetricsCallback(EvalCallback):
|
||||
"""Deterministic evaluation collector detached from logging backends."""
|
||||
|
||||
def __init__(
|
||||
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
|
||||
self,
|
||||
eval_env,
|
||||
eval_freq: int = 1000,
|
||||
n_eval_episodes: int = 5,
|
||||
step_offset: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
|
||||
)
|
||||
self._eval_revenues: list[float] = []
|
||||
self.step_offset = max(0, int(step_offset))
|
||||
self._wandb = get_wandb_module()
|
||||
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
|
||||
self._eval_stats: dict[str, list[float]] = {
|
||||
"eval/revenue_mean": [],
|
||||
"eval/margin_mean": [],
|
||||
"eval/coi_level_mean": [],
|
||||
"eval/coi_leakage_mean": [],
|
||||
"eval/volatility_mean": [],
|
||||
"eval/agent_prob_mean": [],
|
||||
}
|
||||
self.events: list[dict[str, float | int]] = []
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
result = super()._on_step()
|
||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||
self.events.append(
|
||||
{
|
||||
"eval/reward_mean": float(self.last_mean_reward),
|
||||
"eval/revenue_mean": float(np.mean(self._eval_revenues))
|
||||
if self._eval_revenues
|
||||
else 0.0,
|
||||
"train/global_step": int(self.num_timesteps),
|
||||
}
|
||||
)
|
||||
self._eval_revenues = []
|
||||
payload: dict[str, float | int] = {
|
||||
"eval/reward_mean": float(self.last_mean_reward),
|
||||
"train/global_step": int(self.num_timesteps),
|
||||
}
|
||||
for key, values in self._eval_stats.items():
|
||||
payload[key] = float(np.mean(values)) if values else 0.0
|
||||
|
||||
if self._wandb_live:
|
||||
try:
|
||||
self._wandb.log(
|
||||
dict(payload),
|
||||
step=self.step_offset + int(self.num_timesteps),
|
||||
)
|
||||
except Exception:
|
||||
self._wandb_live = False
|
||||
self.events.append(payload)
|
||||
else:
|
||||
self.events.append(payload)
|
||||
|
||||
for values in self._eval_stats.values():
|
||||
values.clear()
|
||||
|
||||
return result
|
||||
|
||||
def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
|
||||
# called after each eval episode
|
||||
info = locals_.get("info", {})
|
||||
if "economics" in info:
|
||||
self._eval_revenues.append(info["economics"]["revenue"])
|
||||
econ = info.get("economics") if isinstance(info, dict) else None
|
||||
if not isinstance(econ, dict):
|
||||
return
|
||||
|
||||
self._eval_stats["eval/revenue_mean"].append(float(econ.get("revenue", 0.0)))
|
||||
self._eval_stats["eval/margin_mean"].append(float(econ.get("margin", 0.0)))
|
||||
self._eval_stats["eval/coi_level_mean"].append(
|
||||
float(econ.get("coi_level", 0.0))
|
||||
)
|
||||
self._eval_stats["eval/coi_leakage_mean"].append(
|
||||
float(econ.get("coi_leakage", 0.0))
|
||||
)
|
||||
self._eval_stats["eval/volatility_mean"].append(
|
||||
float(econ.get("volatility", 0.0))
|
||||
)
|
||||
self._eval_stats["eval/agent_prob_mean"].append(
|
||||
float(econ.get("agent_prob", 0.0))
|
||||
)
|
||||
|
||||
@@ -156,14 +156,17 @@ class ProviderBenchmark:
|
||||
|
||||
# log to wandb if available
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
f"benchmark/{name}/revenue": result.mean_revenue,
|
||||
f"benchmark/{name}/coi_preserved": result.coi_preserved_pct,
|
||||
f"benchmark/{name}/margin": result.margin_integrity,
|
||||
"benchmark/alpha": alpha,
|
||||
}
|
||||
)
|
||||
try:
|
||||
wandb.log(
|
||||
{
|
||||
f"benchmark/{name}/revenue": result.mean_revenue,
|
||||
f"benchmark/{name}/coi_preserved": result.coi_preserved_pct,
|
||||
f"benchmark/{name}/margin": result.margin_integrity,
|
||||
"benchmark/alpha": alpha,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return self.results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user