DeepMind Haiku for JAX
Haiku is currently in alpha, and you're welcome to jump in and test it.
JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.
Haiku is a simple neural network library for JAX that lets users write models in a familiar object-oriented style while retaining full access to JAX's pure function transformations.
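In practice, a model is written as an ordinary function that instantiates Haiku modules, and `hk.transform` converts it into a pair of pure `init`/`apply` functions. A minimal sketch (the MLP sizes, input shape, and seed below are illustrative):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Modules are created inside the function; hk.transform makes it pure.
    mlp = hk.nets.MLP([32, 1])
    return mlp(x)

model = hk.transform(forward)

rng = jax.random.PRNGKey(42)        # illustrative seed
x = jnp.ones([8, 4])                # illustrative batch of inputs
params = model.init(rng, x)         # pure: returns the model parameters
out = model.apply(params, rng, x)   # pure: runs the model with those parameters
```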
So why Haiku for JAX?
It’s a library designed with specific goals in mind:
- Make managing model parameters and other model state simpler.
- Compose well with other libraries and work with the rest of JAX (see the sketch after this list).
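To illustrate both goals, here is a minimal sketch of how Haiku threads parameters and mutable state (for example, a step counter) through pure functions via `hk.transform_with_state`; the names, shapes, and seed are illustrative:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def scale_and_count(x):
    # Parameters and mutable state live behind simple named lookups.
    count = hk.get_state("count", shape=[], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state("count", count + 1)
    w = hk.get_parameter("w", shape=x.shape[-1:], init=jnp.ones)
    return x * w

f = hk.transform_with_state(scale_and_count)

rng = jax.random.PRNGKey(0)   # illustrative seed
x = jnp.ones([3])             # illustrative input
params, state = f.init(rng, x)
out, state = f.apply(params, state, rng, x)

# The transformed functions are pure, so they compose with the rest of JAX:
fast_apply = jax.jit(f.apply)
out, state = fast_apply(params, state, rng, x)
```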
DeepMind has reproduced a number of experiments in Haiku and JAX with relative ease. These include large-scale results in image and language processing, generative models, and reinforcement learning.
JAX itself was built by Google researchers as a domain-specific tracing JIT compiler that generates high-performance accelerator code from pure Python and NumPy machine learning programs. It combines Autograd and XLA for high-performance machine learning research. At its core, it is an extensible system for transforming numerical functions.
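Those transformations compose freely. A small sketch (the loss function and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

# Compose transformations: differentiate, compile, then vectorize over a batch.
grad_fn = jax.jit(jax.grad(loss))
batched_grads = jax.vmap(grad_fn, in_axes=(None, 0))

w = jnp.ones([4])      # illustrative weights
xs = jnp.ones([8, 4])  # illustrative batch
print(batched_grads(w, xs).shape)  # (8, 4): one gradient per example
```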
Autograd helps JAX automatically differentiate native Python and NumPy code. It handles a large subset of Python features, including loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (backpropagation) via `grad` as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
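For example (a minimal sketch; the function itself is illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Plain Python loops trace fine; data-dependent branches use jnp.where.
    for _ in range(3):
        x = jnp.where(x > 0, jnp.sin(x) ** 2.0, -x)
    return x

df = jax.grad(f)                         # reverse-mode (backpropagation)
d2f = jax.jacfwd(jax.grad(f))            # forward-over-reverse composition
d3f = jax.grad(jax.jacfwd(jax.grad(f)))  # a derivative of a derivative of a derivative
print(df(1.0), d2f(1.0), d3f(1.0))
```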
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra, originally used to optimize TensorFlow computations. JAX uses XLA to run NumPy programs on GPUs and TPUs: library calls are compiled and executed just-in-time. JAX also lets you compile your own Python functions just-in-time into XLA-optimized kernels with a one-function API, `jit`.
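A brief sketch of `jit` (the function below is the standard SELU activation with its usual constants; the input is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # Compiled once per input shape/dtype into a fused XLA kernel.
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

x = jnp.arange(5.0)   # illustrative input
print(selu(x))        # first call compiles; later calls reuse the kernel
```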