mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
refactoring training spc setup and benchmarking
This commit is contained in:
18
engine/backends/jax.py
Normal file
18
engine/backends/jax.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..jax import JAX_AVAILABLE
|
||||
|
||||
|
||||
def train_jax_backend(
|
||||
cfg: Mapping[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
|
||||
if not JAX_AVAILABLE:
|
||||
raise ImportError(
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
from ..jax.train import train_jax
|
||||
|
||||
return train_jax(dict(cfg))
|
||||
Reference in New Issue
Block a user