Flax is a high-performance neural network library for JAX that is designed for
flexibility: try new forms of training by forking an example and by modifying
the training loop, not by adding features to a framework.

WWW: https://github.com/google/flax
