mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
19 lines
486 B
Python
19 lines
486 B
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Mapping
|
|
|
|
from ..jax import JAX_AVAILABLE
|
|
|
|
|
|
def train_jax_backend(
    cfg: Mapping[str, Any],
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
    """Dispatch a training run to the JAX implementation.

    Args:
        cfg: Training configuration. It is copied into a plain ``dict``
            before being passed on, so ``train_jax`` receives a fresh
            mutable mapping rather than the caller's object.

    Returns:
        Whatever ``train_jax`` returns. It is annotated here as a pair of
        ``(dict[str, Any], dict[str, float | int | str])``. The meaning of
        each element is defined by ``train_jax``, not by this wrapper.

    Raises:
        ImportError: If ``JAX_AVAILABLE`` is falsy.
    """
    if JAX_AVAILABLE:
        # Imported only after the availability check, presumably so this
        # module stays importable when JAX is absent. TODO confirm.
        from ..jax.train import train_jax

        config_copy = dict(cfg)
        return train_jax(config_copy)

    message = (
        "JAX backend requested but JAX is not installed. "
        "Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
    )
    raise ImportError(message)
|