mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -1,150 +1,96 @@
|
||||
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
|
||||
"""Training callbacks with algorithm-agnostic metric extraction."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||
import numpy as np
|
||||
|
||||
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
HAS_WANDB = True
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
|
||||
|
||||
class MetricsCallback(BaseCallback):
|
||||
"""Training metrics logger - reads info['economics'], logs to W&B."""
|
||||
"""Collects interval train metrics from env info dictionaries."""
|
||||
|
||||
def __init__(
|
||||
self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0
|
||||
self,
|
||||
log_histograms: bool = False,
|
||||
log_freq: int = 100,
|
||||
verbose: int = 0,
|
||||
):
|
||||
super().__init__(verbose)
|
||||
self.log_histograms = log_histograms
|
||||
self.log_freq = log_freq
|
||||
self._episode_revenues: list[float] = []
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if not HAS_WANDB or wandb.run is None:
|
||||
return True
|
||||
|
||||
for info in self.locals.get("infos", []):
|
||||
if "economics" not in info:
|
||||
continue
|
||||
|
||||
econ = info["economics"]
|
||||
t = self.num_timesteps
|
||||
|
||||
payload = {
|
||||
"train/revenue_step": econ["revenue"],
|
||||
"train/margin_step": econ["margin"],
|
||||
"train/coi_level": econ["coi_level"],
|
||||
"train/regret_step": econ["regret"],
|
||||
}
|
||||
if "coi_mix" in econ:
|
||||
payload["train/coi_mix"] = econ["coi_mix"]
|
||||
if "coi_base" in econ:
|
||||
payload["train/coi_base"] = econ["coi_base"]
|
||||
if "coi_leakage" in econ:
|
||||
payload["train/coi_leakage"] = econ["coi_leakage"]
|
||||
if "coi_penalty" in econ:
|
||||
payload["train/coi_penalty"] = econ["coi_penalty"]
|
||||
wandb.log(payload, step=t)
|
||||
|
||||
self._episode_revenues.append(econ["revenue"])
|
||||
|
||||
# histograms at log_freq intervals
|
||||
if self.log_histograms and self.num_timesteps % self.log_freq == 0:
|
||||
for info in self.locals.get("infos", []):
|
||||
if "prices" in info:
|
||||
wandb.log(
|
||||
{"distributions/prices": wandb.Histogram(info["prices"])},
|
||||
step=self.num_timesteps,
|
||||
)
|
||||
if "demand" in info:
|
||||
wandb.log(
|
||||
{"distributions/demand": wandb.Histogram(info["demand"])},
|
||||
step=self.num_timesteps,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _on_rollout_end(self) -> None:
|
||||
if not HAS_WANDB or wandb.run is None or not self._episode_revenues:
|
||||
return
|
||||
wandb.log(
|
||||
{
|
||||
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
|
||||
"train/revenue_rollout_total": np.sum(self._episode_revenues),
|
||||
},
|
||||
step=self.num_timesteps,
|
||||
)
|
||||
self._episode_revenues = []
|
||||
|
||||
|
||||
class CheckpointArtifactCallback(BaseCallback):
|
||||
"""Periodic SB3 checkpoint uploader backed by W&B artifacts."""
|
||||
|
||||
def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
|
||||
super().__init__(verbose)
|
||||
self.cfg = dict(cfg)
|
||||
self.interval = max(1, int(interval))
|
||||
self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._next_checkpoint = self.interval
|
||||
self._last_saved_step = -1
|
||||
|
||||
def _artifact_name(self) -> str:
|
||||
sweep_id = (
|
||||
getattr(wandb.run, "sweep_id", None)
|
||||
if HAS_WANDB and wandb.run is not None
|
||||
else None
|
||||
)
|
||||
return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)
|
||||
|
||||
def _checkpoint_file(self) -> Path:
|
||||
algo = str(self.cfg.get("algo", "model"))
|
||||
base = self.model_dir / f"phantom_{algo}_checkpoint"
|
||||
self.model.save(str(base))
|
||||
return base.with_suffix(".zip")
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
if not HAS_WANDB or wandb.run is None:
|
||||
return
|
||||
step = int(self.num_timesteps)
|
||||
if step <= self._last_saved_step:
|
||||
return
|
||||
checkpoint_path = self._checkpoint_file()
|
||||
metadata = {
|
||||
"step": step,
|
||||
"algo": str(self.cfg.get("algo", "unknown")),
|
||||
"sweep_id": getattr(wandb.run, "sweep_id", None),
|
||||
self.log_freq = max(1, int(log_freq))
|
||||
self._window_sums = {
|
||||
"train/revenue_mean": 0.0,
|
||||
"train/margin_mean": 0.0,
|
||||
"train/coi_level_mean": 0.0,
|
||||
"train/regret_mean": 0.0,
|
||||
"train/coi_mix": 0.0,
|
||||
"train/coi_base": 0.0,
|
||||
"train/coi_leakage": 0.0,
|
||||
"train/coi_penalty": 0.0,
|
||||
}
|
||||
saved = log_checkpoint_file(
|
||||
self._artifact_name(),
|
||||
file_path=checkpoint_path,
|
||||
artifact_file_name=checkpoint_path.name,
|
||||
metadata=metadata,
|
||||
)
|
||||
if saved:
|
||||
self._last_saved_step = step
|
||||
self._window_count = 0
|
||||
self.events: list[dict[str, Any]] = []
|
||||
|
||||
def _accumulate(self, info: dict[str, Any]) -> None:
|
||||
econ = info.get("economics")
|
||||
if not isinstance(econ, dict):
|
||||
return
|
||||
self._window_sums["train/revenue_mean"] += float(econ.get("revenue", 0.0))
|
||||
self._window_sums["train/margin_mean"] += float(econ.get("margin", 0.0))
|
||||
self._window_sums["train/coi_level_mean"] += float(econ.get("coi_level", 0.0))
|
||||
self._window_sums["train/regret_mean"] += float(econ.get("regret", 0.0))
|
||||
if "coi_mix" in econ:
|
||||
self._window_sums["train/coi_mix"] += float(econ.get("coi_mix", 0.0))
|
||||
if "coi_base" in econ:
|
||||
self._window_sums["train/coi_base"] += float(econ.get("coi_base", 0.0))
|
||||
if "coi_leakage" in econ:
|
||||
self._window_sums["train/coi_leakage"] += float(
|
||||
econ.get("coi_leakage", 0.0)
|
||||
)
|
||||
if "coi_penalty" in econ:
|
||||
self._window_sums["train/coi_penalty"] += float(
|
||||
econ.get("coi_penalty", 0.0)
|
||||
)
|
||||
self._window_count += 1
|
||||
|
||||
def _flush(self, step: int) -> None:
|
||||
if self._window_count <= 0:
|
||||
return
|
||||
denom = float(self._window_count)
|
||||
payload = {
|
||||
key: (value / denom)
|
||||
for key, value in self._window_sums.items()
|
||||
if value != 0.0
|
||||
or key
|
||||
in {
|
||||
"train/revenue_mean",
|
||||
"train/margin_mean",
|
||||
"train/coi_level_mean",
|
||||
"train/regret_mean",
|
||||
}
|
||||
}
|
||||
payload["train/global_step"] = int(step)
|
||||
self.events.append(payload)
|
||||
for key in self._window_sums:
|
||||
self._window_sums[key] = 0.0
|
||||
self._window_count = 0
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self.num_timesteps < self._next_checkpoint:
|
||||
return True
|
||||
self._save_checkpoint()
|
||||
while self._next_checkpoint <= self.num_timesteps:
|
||||
self._next_checkpoint += self.interval
|
||||
for info in self.locals.get("infos", []):
|
||||
if isinstance(info, dict):
|
||||
self._accumulate(info)
|
||||
|
||||
if self.num_timesteps % self.log_freq == 0:
|
||||
self._flush(step=self.num_timesteps)
|
||||
|
||||
return True
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
self._save_checkpoint()
|
||||
self._flush(step=self.num_timesteps)
|
||||
|
||||
|
||||
class EvalMetricsCallback(EvalCallback):
|
||||
"""Deterministic evaluation - true performance without exploration noise."""
|
||||
"""Deterministic evaluation collector detached from logging backends."""
|
||||
|
||||
def __init__(
|
||||
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
|
||||
@@ -153,23 +99,19 @@ class EvalMetricsCallback(EvalCallback):
|
||||
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
|
||||
)
|
||||
self._eval_revenues: list[float] = []
|
||||
self.events: list[dict[str, float | int]] = []
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
result = super()._on_step()
|
||||
|
||||
if not HAS_WANDB or wandb.run is None:
|
||||
return result
|
||||
|
||||
# log eval metrics after evaluation runs
|
||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||
wandb.log(
|
||||
self.events.append(
|
||||
{
|
||||
"eval/reward_mean": self.last_mean_reward,
|
||||
"eval/revenue_mean": np.mean(self._eval_revenues)
|
||||
"eval/reward_mean": float(self.last_mean_reward),
|
||||
"eval/revenue_mean": float(np.mean(self._eval_revenues))
|
||||
if self._eval_revenues
|
||||
else 0,
|
||||
},
|
||||
step=self.num_timesteps,
|
||||
else 0.0,
|
||||
"train/global_step": int(self.num_timesteps),
|
||||
}
|
||||
)
|
||||
self._eval_revenues = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user