Originally posted on LinkedIn.
For example, the code snippet below shows how one can perform gradient descent while utilising multiple devices:
- the gradients are computed on multiple devices
- they are synchronised across the devices by taking their mean
- the new parameters are computed by stepping in the direction opposite to the averaged gradient
```python
import functools

import jax


@functools.partial(jax.pmap, axis_name="num_devices")
def update(params, xs, ys, learning_rate=0.005):
    # 1. Compute the gradients on the given minibatch
    #    (individually on each device).
    grads = jax.grad(loss_fn)(params, xs, ys)

    # 2. Combine the gradients across all devices
    #    (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name="num_devices")

    # 3. Each device performs its own update, but since we
    #    start with the same params and synchronise gradients,
    #    the params stay in sync.
    new_params = jax.tree_map(
        lambda param, g: param - g * learning_rate,
        params,
        grads,
    )
    return new_params
```
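The snippet assumes a `loss_fn` defined elsewhere and inputs whose leading axis matches the device count. As a minimal usage sketch (my own illustration, not from the original post), one could supply a toy linear-regression `loss_fn`, copy the initial parameters onto every device with `jax.device_put_replicated`, and shape the minibatch so its leading axis equals `jax.local_device_count()`:

```python
import jax
import jax.numpy as jnp

# Hypothetical loss for illustration: linear regression with squared error.
def loss_fn(params, xs, ys):
    preds = xs @ params["w"] + params["b"]
    return jnp.mean((preds - ys) ** 2)

n_devices = jax.local_device_count()

# Initial parameters, copied onto every device so all replicas start in sync.
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
replicated_params = jax.device_put_replicated(params, jax.local_devices())

# Shard the minibatch: the leading axis must equal the number of devices.
xs = jnp.ones((n_devices, 8, 3))  # (devices, per-device batch, features)
ys = jnp.ones((n_devices, 8, 1))

# One synchronised gradient-descent step across all devices.
replicated_params = update(replicated_params, xs, ys)
```

Because `pmap` maps over the leading axis, each device sees one slice of the batch and one replica of the parameters; the `pmean` inside `update` is the only cross-device communication needed to keep the replicas identical.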