fixing models for gcp

This commit is contained in:
2026-02-17 16:54:55 +01:00
parent 802f31b4a1
commit 9acc998cc9
5 changed files with 497 additions and 193 deletions

View File

@@ -2,7 +2,7 @@ from .demand import estimate_demand, estimate_weighted_demand, generate_demand_f
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
from .render import DashboardRenderer, style_axis
from .wrappers import EconomicMetricsWrapper
from .callbacks import MetricsCallback, EvalMetricsCallback
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
from .providers import (
ProviderBenchmark,
ProviderResult,

View File

@@ -1,8 +1,12 @@
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
from pathlib import Path
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
import numpy as np
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
try:
import wandb
@@ -80,6 +84,65 @@ class MetricsCallback(BaseCallback):
self._episode_revenues = []
class CheckpointArtifactCallback(BaseCallback):
"""Periodic SB3 checkpoint uploader backed by W&B artifacts."""
def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
super().__init__(verbose)
self.cfg = dict(cfg)
self.interval = max(1, int(interval))
self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
self.model_dir.mkdir(parents=True, exist_ok=True)
self._next_checkpoint = self.interval
self._last_saved_step = -1
def _artifact_name(self) -> str:
sweep_id = (
getattr(wandb.run, "sweep_id", None)
if HAS_WANDB and wandb.run is not None
else None
)
return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)
def _checkpoint_file(self) -> Path:
algo = str(self.cfg.get("algo", "model"))
base = self.model_dir / f"phantom_{algo}_checkpoint"
self.model.save(str(base))
return base.with_suffix(".zip")
def _save_checkpoint(self) -> None:
if not HAS_WANDB or wandb.run is None:
return
step = int(self.num_timesteps)
if step <= self._last_saved_step:
return
checkpoint_path = self._checkpoint_file()
metadata = {
"step": step,
"algo": str(self.cfg.get("algo", "unknown")),
"sweep_id": getattr(wandb.run, "sweep_id", None),
}
saved = log_checkpoint_file(
self._artifact_name(),
file_path=checkpoint_path,
artifact_file_name=checkpoint_path.name,
metadata=metadata,
)
if saved:
self._last_saved_step = step
def _on_step(self) -> bool:
if self.num_timesteps < self._next_checkpoint:
return True
self._save_checkpoint()
while self._next_checkpoint <= self.num_timesteps:
self._next_checkpoint += self.interval
return True
def _on_training_end(self) -> None:
self._save_checkpoint()
class EvalMetricsCallback(EvalCallback):
"""Deterministic evaluation - true performance without exploration noise."""