mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -12,17 +12,14 @@ class TrainResult:
|
||||
spec: TrainSpec
|
||||
metrics: dict[str, Any]
|
||||
artifacts: dict[str, str]
|
||||
events: list[dict[str, Any]]
|
||||
|
||||
|
||||
def run_train(spec: TrainSpec) -> TrainResult:
|
||||
cfg = spec.to_flat_dict()
|
||||
algo = spec.algorithm.name
|
||||
|
||||
if spec.runtime.use_jax or spec.runtime.backend == "jax":
|
||||
from .backends.jax import train_jax_backend
|
||||
|
||||
_, raw_metrics = train_jax_backend(cfg)
|
||||
elif algo == "qtable":
|
||||
if algo == "qtable":
|
||||
from .backends.qtable import train_qtable
|
||||
|
||||
_, raw_metrics = train_qtable(cfg)
|
||||
@@ -31,10 +28,13 @@ def run_train(spec: TrainSpec) -> TrainResult:
|
||||
|
||||
_, raw_metrics = train_sb3(cfg)
|
||||
|
||||
events_raw = raw_metrics.pop("_train_events", [])
|
||||
events = [evt for evt in events_raw if isinstance(evt, dict)]
|
||||
|
||||
metrics = canonicalize_metrics(raw_metrics, spec)
|
||||
artifacts: dict[str, str] = {}
|
||||
model_path = raw_metrics.get("model/path")
|
||||
if isinstance(model_path, str):
|
||||
artifacts["model/path"] = model_path
|
||||
|
||||
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)
|
||||
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts, events=events)
|
||||
|
||||
Reference in New Issue
Block a user