refactoring training spec setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -0,0 +1,23 @@
from .metrics import canonicalize_metrics
from .wandb import (
current_config,
finish_run,
get_wandb_module,
init_run,
log_metrics,
run_agent,
update_run_config,
update_summary,
)
# Public API of this package: the names re-exported from the sibling
# ``.metrics`` and ``.wandb`` modules imported above.
__all__ = [
"canonicalize_metrics",
"current_config",
"finish_run",
"get_wandb_module",
"init_run",
"log_metrics",
"run_agent",
"update_run_config",
"update_summary",
]

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from typing import Any, Mapping
from ..spec import TrainSpec
# Alternate metric key -> canonical metric key. ``canonicalize_metrics`` looks
# every incoming key up here and stores the value under the canonical name;
# keys that are not listed are kept unchanged.
_ALIASES = {
"train/reward": "train/reward_mean",
"train/revenue": "train/revenue_mean",
"train/dqn_loss": "train/loss",
"eval/reward": "eval/reward_mean",
"eval/revenue": "eval/revenue_mean",
"train/steps_per_second": "runtime/steps_per_second",
}
def _as_float(value: Any, default: float | None = None) -> float | None:
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, Any]:
    """Build a new metrics dict with canonical key names and derived fields.

    Every key of *raw* is stringified and translated through ``_ALIASES``.
    A key that had to be renamed never replaces an entry already stored under
    the canonical name, whereas a key that is already canonical always
    overwrites it.

    On top of the renamed values the result receives:

    * ``train/global_step`` – set to ``spec.runtime.total_timesteps`` only
      when not already present.
    * ``objective/score`` – ``eval/reward_mean`` plus
      ``spec.study.revenue_weight`` times ``eval/revenue_mean``; a missing or
      non-numeric input counts as ``0.0``.
    * ``objective/constraint_margin`` – ``eval/margin_mean`` minus
      ``spec.env.margin_floor``; written only when the margin is numeric.
    * ``objective/coi_preserved`` – ``eval/coi_level_mean`` as a float, or
      ``0.0`` when missing or non-numeric.
    * ``study/*`` and ``runtime/*`` – values copied from *spec*, replacing
      any same-named entries that came from *raw*.
    """
    out: dict[str, Any] = {}
    for raw_key, raw_value in raw.items():
        name = str(raw_key)
        target = _ALIASES.get(name, name)
        # A renamed key must not clobber a value already stored under the
        # canonical name.
        shadowed = target in out and target != raw_key
        if not shadowed:
            out[target] = raw_value

    out.setdefault("train/global_step", spec.runtime.total_timesteps)

    reward = _as_float(out.get("eval/reward_mean"), 0.0) or 0.0
    revenue = _as_float(out.get("eval/revenue_mean"), 0.0) or 0.0
    out["objective/score"] = reward + spec.study.revenue_weight * revenue

    margin = _as_float(out.get("eval/margin_mean"), None)
    if margin is not None:
        out["objective/constraint_margin"] = margin - spec.env.margin_floor

    coi = _as_float(out.get("eval/coi_level_mean"), None)
    if coi is None:
        coi = 0.0
    out["objective/coi_preserved"] = coi

    # Echo the study and runtime settings so each record is self-describing.
    study, runtime = spec.study, spec.runtime
    out["study/alpha"] = study.alpha
    out["study/lambda_coi"] = study.lambda_coi
    out["study/robust_radius"] = study.robust_radius
    out["study/info_value"] = study.info_value
    out["runtime/backend"] = runtime.backend
    out["runtime/device"] = runtime.device
    out["runtime/seed"] = runtime.seed
    return out

98
engine/telemetry/wandb.py Normal file
View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any, Callable, Iterable, Mapping
def get_wandb_module():
    """Return the imported ``wandb`` module, or ``None`` if it cannot be imported."""
    try:
        import wandb as module
    except ImportError:
        module = None
    return module
def _require_wandb():
    """Return the ``wandb`` module.

    Raises:
        ImportError: if ``get_wandb_module()`` returns ``None``.
    """
    module = get_wandb_module()
    if module is not None:
        return module
    raise ImportError("wandb is required for this workflow")
def init_run(
    *,
    mode: str,
    project: str | None = None,
    config: Mapping[str, Any] | None = None,
    name: str | None = None,
    tags: Iterable[str] | None = None,
    group: str | None = None,
    sweep_mode: bool = False,
):
    """Start a wandb run and return whatever ``wandb.init`` returns.

    Args:
        mode: Forwarded to ``wandb.init`` as ``mode``.
        project: Forwarded as ``project`` (even when ``None``) outside sweep
            mode; not passed in sweep mode.
        config: Copied into a plain ``dict`` and forwarded as ``config`` when
            not ``None``; not passed in sweep mode.
        name: Run name. Outside sweep mode it is forwarded to ``wandb.init``
            when truthy; in sweep mode it is assigned to ``run.name`` after
            the run has been created.
        tags: Materialised into a list and forwarded when truthy; not passed
            in sweep mode.
        group: Forwarded as ``group`` when truthy, in both modes.
        sweep_mode: When true, call ``wandb.init`` with only ``mode`` and
            ``group``.

    Raises:
        ImportError: if wandb cannot be imported.
    """
    wb = _require_wandb()

    shared: dict[str, Any] = {"mode": mode}
    if group:
        shared["group"] = group

    if not sweep_mode:
        options = {**shared, "project": project}
        if config is not None:
            options["config"] = dict(config)
        if name:
            options["name"] = name
        if tags:
            options["tags"] = list(tags)
        return wb.init(**options)

    # Sweep mode: only mode/group go to init; the name is set afterwards.
    sweep_run = wb.init(**shared)
    if name and sweep_run is not None:
        sweep_run.name = name
    return sweep_run
def finish_run() -> None:
    """Call ``wandb.finish()`` if wandb is importable and a run is active; otherwise do nothing."""
    wb = get_wandb_module()
    if wb is None or wb.run is None:
        return
    wb.finish()
def current_config() -> dict[str, Any]:
    """Return the entries of ``wandb.config`` as a plain ``dict``.

    An empty dict is returned when wandb is not importable or no run is
    active.
    """
    snapshot: dict[str, Any] = {}
    wb = get_wandb_module()
    if wb is not None and wb.run is not None:
        for name in wb.config.keys():
            snapshot[name] = wb.config[name]
    return snapshot
def update_run_config(config: Mapping[str, Any]) -> None:
    """Merge *config* into ``wandb.config`` of the active run.

    Does nothing when wandb is not importable or no run is active. The update
    is first attempted with ``allow_val_change=True``; if that call raises
    ``TypeError`` it is repeated without the keyword.
    """
    wb = get_wandb_module()
    if wb is not None and wb.run is not None:
        try:
            wb.config.update(dict(config), allow_val_change=True)
        except TypeError:
            # Retry without the keyword argument.
            wb.config.update(dict(config))
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
    """Pass a dict copy of *metrics* to ``wandb.log`` with the given *step*.

    Does nothing when wandb is not importable or no run is active.
    """
    wb = get_wandb_module()
    if wb is not None and wb.run is not None:
        wb.log(dict(metrics), step=step)
def update_summary(metrics: Mapping[str, Any]) -> None:
    """Write each item of *metrics* into ``wandb.run.summary``.

    Does nothing when wandb is not importable or no run is active.
    """
    wb = get_wandb_module()
    if wb is not None and wb.run is not None:
        for name, value in metrics.items():
            wb.run.summary[name] = value
def run_agent(
    sweep_id: str,
    fn: Callable[[], None],
    *,
    count: int | None = None,
) -> None:
    """Invoke ``wandb.agent`` for *sweep_id* with *fn* as its ``function``.

    Args:
        sweep_id: Passed positionally to ``wandb.agent``.
        fn: Zero-argument callable passed as ``function``.
        count: Passed through as ``count``.

    Raises:
        ImportError: if wandb cannot be imported.
    """
    _require_wandb().agent(sweep_id, function=fn, count=count)