Awesome JAX 
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
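The sketch below shows the three ingredients named above: a NumPy-like API, automatic differentiation, and XLA compilation. It applies `jax.grad`, `jax.jit`, and `jax.vmap` to a toy least-squares loss. The loss, data, and helper names are illustrative assumptions and do not come from any listed library.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a linear model, written with NumPy-style jnp calls."""
    return jnp.mean((x @ w - y) ** 2)

# jax.grad: derivative of `loss` with respect to its first argument, w.
grad_loss = jax.grad(loss)

# jax.jit: trace once per input shape/dtype and compile the result with XLA.
fast_grad_loss = jax.jit(grad_loss)

def predict_one(w, xi):
    """Prediction for a single example (a 1-D feature vector)."""
    return jnp.dot(xi, w)

# jax.vmap: map the per-example function over the leading axis of x; w is shared.
batched_predict = jax.vmap(predict_one, in_axes=(None, 0))

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
w_true = jnp.array([1.0, 2.0])

print(loss(w_true, x, y))                  # 0.0: x @ w_true reproduces y exactly
print(batched_predict(w_true, x))          # [1. 2. 3.]
print(fast_grad_loss(jnp.zeros(2), x, y))  # analytically -(2/3) * x.T @ y
```

The transformations compose, so `jax.jit(jax.vmap(jax.grad(...)))` is also valid. This composability is what most of the libraries below build on.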
Libraries
- Neural Network Libraries
    - Flax - Centered on flexibility and clarity.
    - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    - Objax - Has an object oriented design similar to PyTorch.
    - Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    - Trax - “Batteries included” deep learning library focused on providing solutions for common workloads.
    - Jraph - Lightweight graph neural network library.
    - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
    - HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
    - Scenic - A Jax Library for Computer Vision Research and Beyond.
    - Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
    - EasyLM - LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library.
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.
- Coax - Turn RL papers into code, the easy way.
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers - Construct differentiable convex optimization layers.
- TensorLy - Tensor learning made simple.
- NetKet - Machine Learning toolbox for Quantum Physics.
- Fortuna - AWS library for Uncertainty Quantification in Deep Learning.
- BlackJAX - Library of samplers for JAX.
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.
- Neural Network Libraries
    - FedJAX - Federated learning in JAX, built on Optax and Haiku.
    - Equivariant MLP - Construct equivariant neural network layers.
    - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
    - Parallax - Immutable Torch Modules for JAX.
- jax-unirep - Library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
- jax-cosmo - Differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- imax - Image augmentations and transformations.
- FlaxVision - Flax version of TorchVision.
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation.
- jaxlie - Lie theory library for rigid body transformations and optimization.
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels - Pretrained models for JAX/Flax.
- CR.Sparse - XLA-accelerated algorithms for sparse representations and compressive sensing.
- exojax - Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
- JAXopt - Hardware-accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX - PIX is an image processing library in JAX, for JAX.
- bayex - Bayesian Optimization powered by JAX.
- JaxDF - Framework for differentiable simulators with arbitrary discretizations.
- tree-math - Convert functions that operate on arrays into functions that operate on PyTrees.
- jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX.
- PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
- EvoJAX - Hardware-Accelerated Neuroevolution.
- evosax - JAX-Based Evolution Strategies.
- SymJAX - Symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
- Einshape - DSL-based reshaping library for JAX and other frameworks.
- ALX - Open-source library for distributed matrix factorization using Alternating Least Squares; more info in ALX: Large Scale Matrix Factorization on TPUs.
- Diffrax - Numerical differential equation solvers in JAX.
- tinygp - The tiniest of Gaussian process libraries in JAX.
- gymnax - Reinforcement Learning Environments with the well-known gym API.
- Mctx - Monte Carlo tree search algorithms in native JAX.
- KFAC-JAX - Second Order Optimization with Approximate Curvature for NNs.
- TF2JAX - Convert TensorFlow functions/graphs to JAX functions.
- jwave - A library for differentiable acoustic simulations.
- GPJax - Gaussian processes in JAX.
- Jumanji - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.
- Eqxvision - Equinox version of Torchvision.
- JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
- econpizza - Solve macroeconomic models with heterogeneous agents using JAX.
- SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
- jax-tqdm - Add a tqdm progress bar to JAX scans and loops.
- safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.
- Kernex - Differentiable stencil decorators in JAX.
- MaxText - A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
- Pax - A JAX-based machine learning framework for training large-scale models.
- Praxis - The layer library for Pax, intended to also be usable by other JAX-based ML projects.
- purejaxrl - Vectorisable, end-to-end RL algorithms in JAX.
- Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.).
- SCICO - Scientific computational imaging in JAX.
- Spyx - Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
- BrainPy - Brain Dynamics Programming in Python.
- OTT-JAX - Optimal transport tools in JAX.
- QDax - Quality Diversity optimization in JAX.
- JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
- Pgx - Vectorized board game environments for RL with an AlphaZero example.
- EasyDeL - EasyDeL 🔮 is an open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.
- XLB - A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
- dynamiqs - High-performance and differentiable simulations of quantum systems with JAX.
Models and Projects
JAX
Flax
Haiku
Trax
- Reformer - Implementation of the Reformer (efficient transformer) architecture.
NumPyro
Videos
- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX.
- [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://youtu.be/z-WSrQDXkuM) - JAX’s core design, how it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro - Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- [JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne](https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python) - JAX intro presentation at the Program Transformations for Machine Learning workshop.
- [JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury](https://drive.google.com/file/d/1jKxefZT1xJDUxMman6qrQVed7vWI0MIn/edit) - Presentation of TPU host access with demo.
- [Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020](https://slideslive.com/38935810/deep-implicit-layers-neural-odes-equilibrium-models-and-beyond) - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc). Papers implemented in JAX are listed in the Models/Projects section.
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - White paper describing an early version of JAX, detailing how computation is traced and compiled.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - Uses JAX’s JIT and VMAP to achieve faster differentially private SGD than existing libraries.
- XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. - White paper describing the XLB library: benchmarks, validations, and more details about the library.
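The differentially private SGD paper above relies on per-example gradients, which `jax.vmap` over `jax.grad` provides directly. Below is a minimal sketch of a DP-SGD style gradient for a toy linear model: clip each example's gradient, sum, add Gaussian noise, and average. `private_grad`, `clip_norm`, and `noise_multiplier` are hypothetical names and values. This is not the paper's implementation.

```python
import jax
import jax.numpy as jnp

def example_loss(w, xi, yi):
    """Squared error of a linear model on a single example."""
    return (jnp.dot(xi, w) - yi) ** 2

# Gradient w.r.t. w for one example, mapped over the batch axis of (x, y); w is shared.
per_example_grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))

@jax.jit
def private_grad(w, x, y, key, clip_norm=1.0, noise_multiplier=1.1):
    """Hypothetical DP-SGD style gradient: clip per-example grads, sum, add noise, average."""
    g = per_example_grads(w, x, y)                     # shape (batch, dim)
    norms = jnp.linalg.norm(g, axis=1, keepdims=True)  # shape (batch, 1)
    # Rescale each row so its L2 norm is at most clip_norm (guarding against 0/0).
    g = g * jnp.minimum(1.0, clip_norm / jnp.maximum(norms, 1e-12))
    # Gaussian noise with std noise_multiplier * clip_norm, added to the summed gradient.
    noise = noise_multiplier * clip_norm * jax.random.normal(key, w.shape)
    return (g.sum(axis=0) + noise) / x.shape[0]

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)
key = jax.random.PRNGKey(0)

print(per_example_grads(w, x, y).shape)  # (3, 2)
print(private_grad(w, x, y, key))
```

Passing `noise_multiplier=0.0` and a very large `clip_norm` should recover the ordinary mean gradient, which makes a convenient sanity check.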
Tutorials and Blog Posts
- Using JAX to accelerate our research by David Budden and Matteo Hessel - Describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange - Neural network building blocks from scratch with the basic JAX operators.
- Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh - A gentle introduction to JAX that uses it to implement linear regression, logistic regression, and neural network models, and applies them to real-world problems.
- Tutorial: image classification with JAX and Flax Linen by 8bitmp3 - Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
- Plugging Into JAX by Nick Doiron - Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
- Meta-Learning in 50 Lines of JAX by Eric Jang - Introduction to both JAX and Meta-Learning.
- Normalizing Flows in 100 Lines of JAX by Eric Jang - Concise implementation of RealNVP.
- Differentiable Path Tracing on the GPU/TPU by Eric Jang - Tutorial on implementing path tracing.
- Ensemble networks by Mat Kelcey - Ensemble nets are a method of representing an ensemble of models as one single logical model.
- Out of distribution (OOD) detection by Mat Kelcey - Implements different methods for OOD detection.
- Understanding Autodiff with JAX by Srihari Radhakrishna - Understand how autodiff works using JAX.
- From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke - Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
- Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey - Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
- Evolving Neural Networks in JAX by Robert Tjarko Lange - Explores how JAX can power the next generation of scalable neuroevolution algorithms.
- Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz - Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.
- Deterministic ADVI in JAX by Martin Ingram - Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.
- Evolved channel selection by Mat Kelcey - Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.
- Introduction to JAX by Kevin Murphy - Colab that introduces various aspects of the language and applies them to simple ML problems.
- Writing an MCMC sampler in JAX by Jeremie Coullon - Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.
- How to add a progress bar to JAX scans and loops by Jeremie Coullon - Tutorial on how to add a progress bar to compiled loops in JAX using the host_callback module.
- Get started with JAX by Aleksa Gordić - A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.
- Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit - A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.
- Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar - A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.
- Deep Learning tutorials with JAX+Flax by Phillip Lippe - A series of notebooks explaining various deep learning concepts, from basics (e.g. intro to JAX/Flax, activation functions) to recent advances (e.g. Vision Transformers, SimCLR), with translations to PyTorch.
- Achieving 4000x Speedups with PureJaxRL - A blog post on how JAX can massively speed up RL training through vectorisation.
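Several tutorials above, such as the MCMC sampler and progress-bar posts, revolve around compiled loops. As an independent sketch (not code from those posts), a random-walk Metropolis sampler can be written with `jax.lax.scan`, so the whole chain compiles into a single XLA program. The standard normal target, step size, burn-in, and chain length are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    """Unnormalized log density of a standard normal target (illustrative choice)."""
    return -0.5 * x ** 2

def rw_metropolis(key, x0, num_steps, step_size=1.0):
    """Random-walk Metropolis chain of length num_steps, starting at x0."""
    def step(x, step_key):
        key_prop, key_accept = jax.random.split(step_key)
        proposal = x + step_size * jax.random.normal(key_prop)
        # Accept with probability min(1, p(proposal) / p(x)), computed in log space.
        log_ratio = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new  # (carry, per-step output)

    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples

# num_steps sets the number of PRNG keys, so it must be static under jit.
sample_chain = jax.jit(rw_metropolis, static_argnums=2)

samples = sample_chain(jax.random.PRNGKey(0), jnp.zeros(()), 20_000)
tail = samples[5_000:]  # discard burn-in
print(tail.mean(), tail.std())  # should be close to 0 and 1
```

Changing `num_steps` triggers a recompile because it is static. Everything else is traced, so the step itself runs without Python overhead.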
Books
- Jax in Action - A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.
Contributing
Contributions welcome! Read the contribution guidelines first.