
DeepMind Haiku for JAX

New neural network library for JAX.

Przemek Chojecki
2 min read · Feb 20, 2020

DeepMind has just released Haiku, a simple neural network library for JAX, developed by the authors of Sonnet, a neural network library for TensorFlow.

This is an alpha release, and you're welcome to jump in and test it.

DeepMind Haiku

Recall that JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
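Here is a quick sketch of what that looks like in practice (the function f below is just an illustrative toy):

import jax
import jax.numpy as jnp

def f(x):
    # ordinary Python control flow: a branch and a loop
    if x > 0:
        for _ in range(3):
            x = jnp.sin(x) * x
    return x ** 2

df = jax.grad(f)                       # reverse-mode differentiation (backpropagation)
df_fwd = jax.jacfwd(f)                 # forward-mode differentiation
d3f = jax.grad(jax.grad(jax.grad(f)))  # derivatives of derivatives of derivatives

print(df(1.5), df_fwd(1.5), d3f(1.5))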

JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX’s pure function transformations.
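As a rough sketch of that pattern (the two-layer model below is made up for illustration; hk.Module, hk.Linear, and hk.transform are the Haiku pieces being shown):

import haiku as hk
import jax
import jax.numpy as jnp

class MyModel(hk.Module):
    # modules are written in a familiar object-oriented style
    def __call__(self, x):
        x = jax.nn.relu(hk.Linear(32)(x))
        return hk.Linear(1)(x)

def forward(x):
    return MyModel()(x)

# hk.transform turns the module-using function into a pair of pure
# functions (init, apply) that JAX can transform freely
model = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 4])
params = model.init(rng, x)        # gathers the parameters created in __call__
out = model.apply(params, rng, x)  # runs the forward pass as a pure function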

So why Haiku for JAX?

It’s a library designed with specific goals in mind:

  • make managing model parameters and other model state simpler;
  • compose well with other libraries and with the rest of JAX (see the sketch after this list).
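For instance, because the transformed apply function is pure and its parameters are plain nested structures of arrays, it composes directly with jax.grad and jax.jit. A minimal sketch, where the loss function and the SGD update are my own illustrative assumptions rather than part of Haiku:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.Linear(1)(x)

# without_apply_rng drops the rng argument from apply for deterministic models
model = hk.without_apply_rng(hk.transform(forward))

def loss_fn(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

rng = jax.random.PRNGKey(0)
x, y = jnp.ones([8, 4]), jnp.zeros([8, 1])
params = model.init(rng, x)

# Haiku parameters compose with the rest of JAX: grad and jit just work
grad_fn = jax.jit(jax.grad(loss_fn))
grads = grad_fn(params, x, y)

# a plain SGD step over the parameter tree
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)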
