mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
refactoring training spc setup and benchmarking
This commit is contained in:
40
engine/train_core.py
Normal file
40
engine/train_core.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .spec import TrainSpec
|
||||
from .telemetry.metrics import canonicalize_metrics
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainResult:
|
||||
spec: TrainSpec
|
||||
metrics: dict[str, Any]
|
||||
artifacts: dict[str, str]
|
||||
|
||||
|
||||
def run_train(spec: TrainSpec) -> TrainResult:
|
||||
cfg = spec.to_flat_dict()
|
||||
algo = spec.algorithm.name
|
||||
|
||||
if spec.runtime.use_jax or spec.runtime.backend == "jax":
|
||||
from .backends.jax import train_jax_backend
|
||||
|
||||
_, raw_metrics = train_jax_backend(cfg)
|
||||
elif algo == "qtable":
|
||||
from .backends.qtable import train_qtable
|
||||
|
||||
_, raw_metrics = train_qtable(cfg)
|
||||
else:
|
||||
from .backends.sb3 import train_sb3
|
||||
|
||||
_, raw_metrics = train_sb3(cfg)
|
||||
|
||||
metrics = canonicalize_metrics(raw_metrics, spec)
|
||||
artifacts: dict[str, str] = {}
|
||||
model_path = raw_metrics.get("model/path")
|
||||
if isinstance(model_path, str):
|
||||
artifacts["model/path"] = model_path
|
||||
|
||||
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)
|
||||
Reference in New Issue
Block a user