Google Introduces Flax: A Neural Network Library for JAX

Google recently introduced Flax, a neural network library for JAX that is designed for flexibility. To train a neural network with Flax, developers can start by forking one of the examples in its official GitHub repository. To change how a model is trained, developers no longer need to add features to the framework. Instead, they can simply modify the training loop (for example, the train_step function) to achieve the same result. At its core, Flax is built around parameterised functions called Modules. Each Module overrides an apply method and can then be called like a normal function.
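
To make the two ideas above concrete, here is a minimal sketch in plain JAX of a parameterised function paired with a hand-written train_step. It is not Flax's own Module API. The names init_params, apply, loss_fn, train_step and LEARNING_RATE are hypothetical and chosen for illustration. The model is a toy single-layer linear regression trained with plain gradient descent.

```python
import jax
import jax.numpy as jnp

# Hypothetical step size for plain gradient descent (illustrative only).
LEARNING_RATE = 0.1


def init_params(key, in_dim, out_dim):
    """Create the parameters of a single dense layer (hypothetical helper)."""
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.1,
        "b": jnp.zeros((out_dim,)),
    }


def apply(params, x):
    """The parameterised function: called like an ordinary function of (params, x)."""
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    """Mean squared error between predictions and targets."""
    pred = apply(params, x)
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One gradient-descent update. Changing the training behaviour means editing
    this function (e.g. swapping the update rule), not extending a framework."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Usage: fit y = 2x + 1 on four points.
key = jax.random.PRNGKey(0)
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * x + 1.0
params = init_params(key, 1, 1)
for _ in range(200):
    params, loss = train_step(params, x, y)
print(float(loss), params["w"], params["b"])
```

Because the parameters live in an explicit dictionary and train_step is an ordinary Python function, swapping the update rule (for example, adding momentum or weight decay) means editing only train_step. The model definition stays untouched.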


