refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

40
engine/train_core.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from .spec import TrainSpec
from .telemetry.metrics import canonicalize_metrics
@dataclass(frozen=True)
class TrainResult:
spec: TrainSpec
metrics: dict[str, Any]
artifacts: dict[str, str]
def run_train(spec: TrainSpec) -> TrainResult:
cfg = spec.to_flat_dict()
algo = spec.algorithm.name
if spec.runtime.use_jax or spec.runtime.backend == "jax":
from .backends.jax import train_jax_backend
_, raw_metrics = train_jax_backend(cfg)
elif algo == "qtable":
from .backends.qtable import train_qtable
_, raw_metrics = train_qtable(cfg)
else:
from .backends.sb3 import train_sb3
_, raw_metrics = train_sb3(cfg)
metrics = canonicalize_metrics(raw_metrics, spec)
artifacts: dict[str, str] = {}
model_path = raw_metrics.get("model/path")
if isinstance(model_path, str):
artifacts["model/path"] = model_path
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)