updating engine training for training

This commit is contained in:
2026-03-15 21:14:11 +01:00
parent 19b47aa699
commit 52b4dcdce3
13 changed files with 544 additions and 160 deletions

View File

@@ -15,15 +15,19 @@ class MetricsCallback(BaseCallback):
self,
log_histograms: bool = False,
log_freq: int = 100,
hist_freq: int = 500,
step_offset: int = 0,
verbose: int = 0,
):
super().__init__(verbose)
self.log_histograms = log_histograms
self.log_freq = max(1, int(log_freq))
self.hist_freq = max(1, int(hist_freq))
self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._price_samples: list[float] = []
self._demand_samples: list[float] = []
self._window_sums = {
"train/revenue_mean": 0.0,
"train/margin_mean": 0.0,
@@ -74,35 +78,100 @@ class MetricsCallback(BaseCallback):
)
self._window_count += 1
def _flush(self, step: int) -> None:
if self._window_count <= 0:
def _accumulate_histograms(self, info: dict[str, Any]) -> None:
if not self.log_histograms:
return
denom = float(self._window_count)
payload = {
key: (value / denom)
for key, value in self._window_sums.items()
if value != 0.0
or key
in {
"train/revenue_mean",
"train/margin_mean",
"train/coi_level_mean",
"train/regret_mean",
for key in ("effective_prices", "prices"):
if key not in info:
continue
try:
values = np.asarray(info.get(key), dtype=float).reshape(-1)
except Exception:
continue
if values.size <= 0:
continue
finite_values = values[np.isfinite(values)]
if finite_values.size > 0:
self._price_samples.extend(finite_values.tolist())
break
if "demand" in info:
try:
demand_values = np.asarray(info.get("demand"), dtype=float).reshape(-1)
except Exception:
demand_values = np.array([], dtype=float)
if demand_values.size > 0:
finite_demand = demand_values[np.isfinite(demand_values)]
if finite_demand.size > 0:
self._demand_samples.extend(finite_demand.tolist())
def _flush_histograms(self, step: int, force: bool = False) -> None:
if not self.log_histograms:
return
if not force and step % self.hist_freq != 0:
return
if not self._price_samples and not self._demand_samples:
return
if self._wandb is None:
self._price_samples.clear()
self._demand_samples.clear()
return
payload: dict[str, Any] = {}
if self._price_samples:
payload["train/price_dist"] = self._wandb.Histogram(
np.asarray(self._price_samples, dtype=np.float32)
)
if self._demand_samples:
payload["train/demand_dist"] = self._wandb.Histogram(
np.asarray(self._demand_samples, dtype=np.float32)
)
if payload and self._wandb_live:
try:
self._wandb.log(payload, step=self.step_offset + int(step))
except Exception:
self._wandb_live = False
self._price_samples.clear()
self._demand_samples.clear()
def _flush(self, step: int, *, force_hist: bool = False) -> None:
if self._window_count > 0:
denom = float(self._window_count)
payload = {
key: (value / denom)
for key, value in self._window_sums.items()
if value != 0.0
or key
in {
"train/revenue_mean",
"train/margin_mean",
"train/coi_level_mean",
"train/regret_mean",
}
}
}
payload["train/global_step"] = int(step)
if self._wandb_live:
self._wandb.log(dict(payload), step=self.step_offset + int(step))
else:
self.events.append(payload)
for key in self._window_sums:
self._window_sums[key] = 0.0
self._window_count = 0
payload["train/global_step"] = int(step)
if self._wandb_live:
try:
self._wandb.log(dict(payload), step=self.step_offset + int(step))
except Exception:
self._wandb_live = False
self.events.append(payload)
else:
self.events.append(payload)
for key in self._window_sums:
self._window_sums[key] = 0.0
self._window_count = 0
self._flush_histograms(step=step, force=force_hist)
def _on_step(self) -> bool:
for info in self.locals.get("infos", []):
if isinstance(info, dict):
self._accumulate(info)
self._accumulate_histograms(info)
if self.num_timesteps % self.log_freq == 0:
self._flush(step=self.num_timesteps)
@@ -110,39 +179,81 @@ class MetricsCallback(BaseCallback):
return True
def _on_training_end(self) -> None:
self._flush(step=self.num_timesteps)
self._flush(step=self.num_timesteps, force_hist=True)
class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation collector detached from logging backends."""
def __init__(
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
self,
eval_env,
eval_freq: int = 1000,
n_eval_episodes: int = 5,
step_offset: int = 0,
**kwargs,
):
super().__init__(
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
)
self._eval_revenues: list[float] = []
self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._eval_stats: dict[str, list[float]] = {
"eval/revenue_mean": [],
"eval/margin_mean": [],
"eval/coi_level_mean": [],
"eval/coi_leakage_mean": [],
"eval/volatility_mean": [],
"eval/agent_prob_mean": [],
}
self.events: list[dict[str, float | int]] = []
def _on_step(self) -> bool:
result = super()._on_step()
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
self.events.append(
{
"eval/reward_mean": float(self.last_mean_reward),
"eval/revenue_mean": float(np.mean(self._eval_revenues))
if self._eval_revenues
else 0.0,
"train/global_step": int(self.num_timesteps),
}
)
self._eval_revenues = []
payload: dict[str, float | int] = {
"eval/reward_mean": float(self.last_mean_reward),
"train/global_step": int(self.num_timesteps),
}
for key, values in self._eval_stats.items():
payload[key] = float(np.mean(values)) if values else 0.0
if self._wandb_live:
try:
self._wandb.log(
dict(payload),
step=self.step_offset + int(self.num_timesteps),
)
except Exception:
self._wandb_live = False
self.events.append(payload)
else:
self.events.append(payload)
for values in self._eval_stats.values():
values.clear()
return result
def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
# called after each eval episode
info = locals_.get("info", {})
if "economics" in info:
self._eval_revenues.append(info["economics"]["revenue"])
econ = info.get("economics") if isinstance(info, dict) else None
if not isinstance(econ, dict):
return
self._eval_stats["eval/revenue_mean"].append(float(econ.get("revenue", 0.0)))
self._eval_stats["eval/margin_mean"].append(float(econ.get("margin", 0.0)))
self._eval_stats["eval/coi_level_mean"].append(
float(econ.get("coi_level", 0.0))
)
self._eval_stats["eval/coi_leakage_mean"].append(
float(econ.get("coi_leakage", 0.0))
)
self._eval_stats["eval/volatility_mean"].append(
float(econ.get("volatility", 0.0))
)
self._eval_stats["eval/agent_prob_mean"].append(
float(econ.get("agent_prob", 0.0))
)