Awesome JAX
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
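As a minimal illustration of the NumPy-like API and the transformations described above, the sketch below differentiates, compiles, and vectorizes a toy squared-error loss with `jax.grad`, `jax.jit`, and `jax.vmap`. The linear model, the helper names `loss` and `example_loss`, and the data are illustrative assumptions, not taken from any library in this list.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model, written with NumPy-style jnp calls.
    return jnp.mean((x @ w - y) ** 2)

def example_loss(w, xi, yi):
    # Squared error for a single example (xi is one row of x, yi one target).
    return (xi @ w - yi) ** 2

# jax.grad builds a function computing d(loss)/dw (gradient w.r.t. the first argument);
# jax.jit compiles that gradient function with XLA the first time it is called.
grad_loss = jax.jit(jax.grad(loss))

# jax.vmap maps example_loss over axis 0 of x and y while sharing the same w.
per_example_loss = jax.vmap(example_loss, in_axes=(None, 0, 0))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

total = loss(w, x, y)              # scalar batch loss
g = grad_loss(w, x, y)             # gradient with the same shape as w
per_ex = per_example_loss(w, x, y) # one loss value per row of x
```

Because each transformation returns a new function rather than a value, they compose: `jax.jit(jax.grad(loss))` is itself an ordinary function that could be passed to `jax.vmap`.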
Libraries
- Neural Network Libraries
  - Flax - Centered on flexibility and clarity.
  - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
  - Objax - Has an object-oriented design similar to PyTorch.
  - Elegy - A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
  - Trax - “Batteries included” deep learning library focused on providing solutions for common workloads.
  - Jraph - Lightweight graph neural network library.
  - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
  - HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
  - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
  - Scenic - A JAX library for computer vision research and beyond.
  - Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
  - EasyLM - LLMs made easy: pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library.
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.
- Coax - Turn RL papers into code, the easy way.
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers - Construct differentiable convex optimization layers.
- TensorLy - Tensor learning made simple.
- NetKet - Machine learning toolbox for quantum physics.
- Fortuna - AWS library for uncertainty quantification in deep learning.
- BlackJAX - Library of samplers for JAX.
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.
- Neural Network Libraries
  - FedJAX - Federated learning in JAX, built on Optax and Haiku.
  - Equivariant MLP - Construct equivariant neural network layers.
  - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
  - Parallax - Immutable Torch Modules for JAX.
- jax-unirep - Library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
- jax-cosmo - Differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- imax - Image augmentations and transformations.
- FlaxVision - Flax version of TorchVision.
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation.
- jaxlie - Lie theory library for rigid body transformations and optimization.
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels - Pretrained models for JAX/Flax.
- CR.Sparse - XLA-accelerated algorithms for sparse representations and compressive sensing.
- exojax - Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
- JAXopt - Hardware-accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX - PIX is an image processing library in JAX, for JAX.
- bayex - Bayesian Optimization powered by JAX.
- JaxDF - Framework for differentiable simulators with arbitrary discretizations.
- tree-math - Convert functions that operate on arrays into functions that operate on PyTrees.
- jax-models - Implementations of research papers originally without code or with code written in frameworks other than JAX.
- PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
- EvoJAX - Hardware-Accelerated Neuroevolution.
- evosax - JAX-Based Evolution Strategies.
- SymJAX - Symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
- Einshape - DSL-based reshaping library for JAX and other frameworks.
- ALX - Open-source library for distributed matrix factorization using Alternating Least Squares; more info in ALX: Large Scale Matrix Factorization on TPUs.
- Diffrax - Numerical differential equation solvers in JAX.
- tinygp - The tiniest of Gaussian process libraries in JAX.
- gymnax - Reinforcement learning environments with the well-known gym API.
- Mctx - Monte Carlo tree search algorithms in native JAX.
- KFAC-JAX - Second-order optimization with approximate curvature for NNs.
- TF2JAX - Convert functions/graphs to JAX functions.
- jwave - A library for differentiable acoustic simulations.
- GPJax - Gaussian processes in JAX.
- Jumanji - A suite of industry-driven, hardware-accelerated RL environments written in JAX.
- Eqxvision - Equinox version of Torchvision.
- JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
- econpizza - Solve macroeconomic models with heterogeneous agents using JAX.
- SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
- jax-tqdm - Add a tqdm progress bar to JAX scans and loops.
- safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.
- Kernex - Differentiable stencil decorators in JAX.
- MaxText - A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
- Pax - A JAX-based machine learning framework for training large-scale models.
- Praxis - The layer library for Pax, designed to be usable by other JAX-based ML projects.
- purejaxrl - Vectorisable, end-to-end RL algorithms in JAX.
- Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.).
- SCICO - Scientific computational imaging in JAX.
- Spyx - Spiking neural networks in JAX for machine learning on neuromorphic hardware.
- BrainPy - Brain Dynamics Programming in Python.
- OTT-JAX - Optimal transport tools in JAX.
- QDax - Quality Diversity optimization in JAX.
- JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
- Pgx - Vectorized board game environments for RL with an AlphaZero example.
- EasyDeL - An open-source library for faster, more optimized training and serving of models such as Llama, MPT, Mixtral, and Falcon in JAX.
- XLB - A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
- dynamiqs - High-performance and differentiable simulations of quantum systems with JAX.
Models and Projects
JAX
Flax
Haiku
Trax
- Reformer - Implementation of the Reformer (efficient transformer) architecture.
NumPyro
Videos
- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and the JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX.
- [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://youtu.be/z-WSrQDXkuM) - JAX’s core design, how it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- [JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne](https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python) - JAX intro presentation in the Program Transformations for Machine Learning workshop.
- [JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury](https://drive.google.com/file/d/1jKxefZT1xJDUxMman6qrQVed7vWI0MIn/edit) - Presentation of TPU host access with demo.
- [Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020](https://slideslive.com/38935810/deep-implicit-layers-neural-odes-equilibrium-models-and-beyond) - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling, and other great topics.
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models and Projects section.
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - White paper describing an early version of JAX, detailing how computation is traced and compiled.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - Uses JAX’s JIT and VMAP to achieve faster differentially private SGD than existing libraries.
- XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. - White paper describing the XLB library: benchmarks, validations, and more details about the library.
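The differentially private SGD paper above relies on per-example gradients, which `jax.vmap` provides without a Python loop over the batch. A minimal sketch of one DP-SGD update in the standard clip-then-add-noise form follows; it is not the paper's code, and the linear model, the helper names `example_loss` and `dp_sgd_step`, and the `lr`, `clip_norm`, and `noise_mult` values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def example_loss(w, x, y):
    # Hypothetical model: squared error of a linear predictor on ONE example.
    return (jnp.dot(x, w) - y) ** 2

@jax.jit
def dp_sgd_step(w, xs, ys, key, lr=0.1, clip_norm=1.0, noise_mult=1.0):
    # 1. Per-example gradients: vmap grad over the batch axis, sharing w.
    grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(w, xs, ys)
    # 2. Clip each example's gradient so its L2 norm is at most clip_norm.
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    clipped = grads / jnp.maximum(1.0, norms / clip_norm)
    # 3. Sum, add Gaussian noise with std noise_mult * clip_norm, then average.
    noise = noise_mult * clip_norm * jax.random.normal(key, w.shape)
    noisy_mean = (clipped.sum(axis=0) + noise) / xs.shape[0]
    # 4. Ordinary gradient-descent update with the privatized gradient.
    return w - lr * noisy_mean

key = jax.random.PRNGKey(0)
w0 = jnp.zeros(2)
xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
ys = jnp.array([1.0, 2.0])
w1 = dp_sgd_step(w0, xs, ys, key)  # noisy update with the default noise_mult=1.0
```

Setting `noise_mult=0.0` reduces the step to SGD with per-example clipping, which is a convenient way to check the clipping logic. A real DP training run would also need a privacy accountant to choose `noise_mult` for a target (ε, δ); this sketch does not include one.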
Tutorials and Blog Posts
- Using JAX to accelerate our research by David Budden and Matteo Hessel - Describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange - Neural network building blocks from scratch with the basic JAX operators.
- Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh - A gentle introduction to JAX that uses it to implement linear regression, logistic regression, and neural network models, and applies them to real-world problems.
- Tutorial: image classification with JAX and Flax Linen by 8bitmp3 - Learn how to create a simple convolutional network with Flax's Linen API and train it to recognize handwritten digits.
- Plugging Into JAX by Nick Doiron - Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
- Meta-Learning in 50 Lines of JAX by Eric Jang - Introduction to both JAX and meta-learning.
- Normalizing Flows in 100 Lines of JAX by Eric Jang - Concise implementation of RealNVP.
- Differentiable Path Tracing on the GPU/TPU by Eric Jang - Tutorial on implementing path tracing.
- Ensemble networks by Mat Kelcey - Ensemble nets are a method of representing an ensemble of models as one single logical model.
- Out of distribution (OOD) detection by Mat Kelcey - Implements different methods for OOD detection.
- Understanding Autodiff with JAX by Srihari Radhakrishna - Understand how autodiff works using JAX.
- From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke - Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
- Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey - Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
- Evolving Neural Networks in JAX by Robert Tjarko Lange - Explores how JAX can power the next generation of scalable neuroevolution algorithms.
- Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz - Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.
- Deterministic ADVI in JAX by Martin Ingram - Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.
- Evolved channel selection by Mat Kelcey - Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.
- Introduction to JAX by Kevin Murphy - Colab that introduces various aspects of the language and applies them to simple ML problems.
- Writing an MCMC sampler in JAX by Jeremie Coullon - Tutorial on the different ways to write an MCMC sampler in JAX, along with speed benchmarks.
- How to add a progress bar to JAX scans and loops by Jeremie Coullon - Tutorial on how to add a progress bar to compiled loops in JAX using the host_callback module.
- Get started with JAX by Aleksa Gordić - A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.
- Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit - A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax, and Optax.
- Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar - A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.
- Deep Learning tutorials with JAX+Flax by Phillip Lippe - A series of notebooks explaining various deep learning concepts, from basics (e.g. intro to JAX/Flax, activation functions) to recent advances (e.g. Vision Transformers, SimCLR), with translations to PyTorch.
- Achieving 4000x Speedups with PureJaxRL - A blog post on how JAX can massively speed up RL training through vectorisation.
Books
- JAX in Action - A hands-on guide to using JAX for deep learning and other mathematically intensive applications.
Contributing
Contributions welcome! Read the contribution guidelines first.