refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

18
engine/backends/jax.py Normal file
View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any, Mapping
from ..jax import JAX_AVAILABLE
def train_jax_backend(
cfg: Mapping[str, Any],
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
from ..jax.train import train_jax
return train_jax(dict(cfg))