Abstractions in JAX

Originally posted on LinkedIn.

I’m considering a transition from TensorFlow to JAX and, so far, I’m loving how effectively the latter exposes low-level behavior while still providing useful abstractions.

For example, the code snippet¹ below shows how one can perform gradient descent while utilising multiple devices:

  1. the gradients are computed independently on each device
  2. they are synchronised across the devices and averaged
  3. the new parameters are computed by adjusting the current ones in the direction opposite to the gradient
import functools

import jax

@functools.partial(jax.pmap, axis_name="num_devices")
def update(params, xs, ys, learning_rate=0.005):
    # 1. Compute the gradients on the given minibatch
    # (individually on each device).
    grads = jax.grad(loss_fn)(params, xs, ys)

    # 2. Combine the gradients across all devices
    # (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name="num_devices")

    # 3. Each device performs its own update, but since we
    # start with the same params and synchronise gradients,
    # the params stay in sync.
    new_params = jax.tree_map(
        lambda param, g: param - g * learning_rate,
        params,
        grads,
    )

    return new_params
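
Calling update requires that every mapped argument carry a leading axis equal to the number of devices: the parameters are replicated onto each device, and the data is reshaped so every device gets its own minibatch. The sketch below shows one way to set that up; the linear model, the mean-squared-error loss_fn and the synthetic data are my own assumptions rather than part of the original tutorial.

import jax
import jax.numpy as jnp
import numpy as np

num_devices = jax.local_device_count()

# A hypothetical loss for a linear model (assumed, not from the original post).
def loss_fn(params, xs, ys):
    preds = xs @ params["w"] + params["b"]
    return jnp.mean((preds - ys) ** 2)

# Initialise the parameters once, then replicate them onto every local device.
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
replicated_params = jax.device_put_replicated(params, jax.local_devices())

# Synthetic data, reshaped to (num_devices, per_device_batch, ...).
xs = np.random.normal(size=(num_devices, 32, 3))
ys = np.random.normal(size=(num_devices, 32))

for step in range(100):
    replicated_params = update(replicated_params, xs, ys)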

¹ adapted from a tutorial by DeepMind’s Vladimir Mikulik and Roman Ring