feature: telemetry logging

This commit is contained in:
2026-03-10 14:23:17 +01:00
parent be03b2d4d5
commit 4c7d911043
14 changed files with 454 additions and 104 deletions

View File

@@ -5,6 +5,8 @@ from typing import Any
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
import numpy as np
from ..telemetry.wandb import get_wandb_module
class MetricsCallback(BaseCallback):
"""Collects interval train metrics from env info dictionaries."""
@@ -13,16 +15,25 @@ class MetricsCallback(BaseCallback):
self,
log_histograms: bool = False,
log_freq: int = 100,
step_offset: int = 0,
verbose: int = 0,
):
super().__init__(verbose)
self.log_histograms = log_histograms
self.log_freq = max(1, int(log_freq))
self.step_offset = max(0, int(step_offset))
self._wandb = get_wandb_module()
self._wandb_live = bool(self._wandb is not None and self._wandb.run is not None)
self._window_sums = {
"train/revenue_mean": 0.0,
"train/margin_mean": 0.0,
"train/coi_level_mean": 0.0,
"train/regret_mean": 0.0,
"train/profit_mean": 0.0,
"train/agent_prob": 0.0,
"train/alpha_adv": 0.0,
"train/ux_penalty": 0.0,
"train/volatility": 0.0,
"train/coi_mix": 0.0,
"train/coi_base": 0.0,
"train/coi_leakage": 0.0,
@@ -39,6 +50,16 @@ class MetricsCallback(BaseCallback):
self._window_sums["train/margin_mean"] += float(econ.get("margin", 0.0))
self._window_sums["train/coi_level_mean"] += float(econ.get("coi_level", 0.0))
self._window_sums["train/regret_mean"] += float(econ.get("regret", 0.0))
if "profit" in econ:
self._window_sums["train/profit_mean"] += float(econ.get("profit", 0.0))
if "agent_prob" in econ:
self._window_sums["train/agent_prob"] += float(econ.get("agent_prob", 0.0))
if "alpha_adv" in econ:
self._window_sums["train/alpha_adv"] += float(econ.get("alpha_adv", 0.0))
if "ux_penalty" in econ:
self._window_sums["train/ux_penalty"] += float(econ.get("ux_penalty", 0.0))
if "volatility" in econ:
self._window_sums["train/volatility"] += float(econ.get("volatility", 0.0))
if "coi_mix" in econ:
self._window_sums["train/coi_mix"] += float(econ.get("coi_mix", 0.0))
if "coi_base" in econ:
@@ -70,7 +91,10 @@ class MetricsCallback(BaseCallback):
}
}
payload["train/global_step"] = int(step)
self.events.append(payload)
if self._wandb_live:
self._wandb.log(dict(payload), step=self.step_offset + int(step))
else:
self.events.append(payload)
for key in self._window_sums:
self._window_sums[key] = 0.0
self._window_count = 0

View File

@@ -57,7 +57,21 @@ class EconomicMetricsWrapper(gym.Wrapper):
"coi_level": coi_level,
"regret": regret,
}
for key in ("coi_mix", "coi_base", "coi_leakage", "coi_penalty"):
for key in (
"coi_mix",
"coi_base",
"coi_leakage",
"coi_penalty",
"ux_penalty",
"volatility",
"profit",
"cost_floor",
"reward_revenue",
"reward_total",
"agent_prob",
"alpha_adv",
"alpha_nominal",
):
if key in info:
info["economics"][key] = info[key]
info["prices"] = prices.copy()