Files
PHANTOM/engine/jax/__init__.py

14 lines
281 B
Python

"""JAX-compatible training and environment modules for PHANTOM.

JAX is an optional dependency. A missing JAX installation does not make
this package fail to import. Instead, the module-level flag
``JAX_AVAILABLE`` records whether both ``jax`` and ``jax.numpy`` could be
imported.
"""
from __future__ import annotations

try:
    # Probe both the core package and its NumPy-compatible API. The bound
    # names are not used in this module, hence the ``noqa: F401`` markers.
    import jax  # noqa: F401
    import jax.numpy as jnp  # noqa: F401

    JAX_AVAILABLE = True
except ImportError:
    # Only an import failure (including ModuleNotFoundError) is treated as
    # "JAX unavailable"; any other exception raised while importing JAX
    # still propagates to the importer.
    JAX_AVAILABLE = False

# Only the availability flag is part of the public API; ``jax``/``jnp`` are
# bound only when the probe succeeds.
__all__ = ["JAX_AVAILABLE"]