flax>=0.8.0
optax>=0.2.0
distrax>=0.1.5
orbax-checkpoint>=0.5.0
chex>=0.1.8
