flax>=0.8.0
optax>=0.2.0
distrax>=0.1.5
orbax-checkpoint>=0.5.0
chex>=0.1.8