diff --git a/sim/rl/jax_core/__init__.py b/sim/rl/jax_core/__init__.py new file mode 100644 index 0000000..99d5a87 --- /dev/null +++ b/sim/rl/jax_core/__init__.py @@ -0,0 +1,11 @@ +"""JAX-accelerated simulation core for PHANTOM environment.""" +from .transitions import TransitionData, compile_transitions, fallback_transitions, JAX_AVAILABLE +from .simulation import SessionBatch, SimResult, sample_sessions, compute_metrics +from .features import session_features, compute_session_transitions +from .separability import compute_divergences, estimate_alpha_batch + +__all__ = [ + "JAX_AVAILABLE", "TransitionData", "compile_transitions", "fallback_transitions", + "SessionBatch", "SimResult", "sample_sessions", "compute_metrics", + "session_features", "compute_session_transitions", "compute_divergences", "estimate_alpha_batch", +]