cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -12,17 +12,14 @@ class TrainResult:
spec: TrainSpec
metrics: dict[str, Any]
artifacts: dict[str, str]
events: list[dict[str, Any]]
def run_train(spec: TrainSpec) -> TrainResult:
cfg = spec.to_flat_dict()
algo = spec.algorithm.name
if spec.runtime.use_jax or spec.runtime.backend == "jax":
from .backends.jax import train_jax_backend
_, raw_metrics = train_jax_backend(cfg)
elif algo == "qtable":
if algo == "qtable":
from .backends.qtable import train_qtable
_, raw_metrics = train_qtable(cfg)
@@ -31,10 +28,13 @@ def run_train(spec: TrainSpec) -> TrainResult:
_, raw_metrics = train_sb3(cfg)
events_raw = raw_metrics.pop("_train_events", [])
events = [evt for evt in events_raw if isinstance(evt, dict)]
metrics = canonicalize_metrics(raw_metrics, spec)
artifacts: dict[str, str] = {}
model_path = raw_metrics.get("model/path")
if isinstance(model_path, str):
artifacts["model/path"] = model_path
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts, events=events)