Asynchronous Model Average

In synchronous communication algorithms such as Gradient AllReduce, every worker must be in the same iteration, advancing in lock-step. When there are no stragglers in the system, such synchronous algorithms are reasonably efficient and give deterministic training results that are easier to reason about. When there are stragglers, however, the faster workers have to wait for the slowest worker in every iteration, which can dramatically degrade the performance of the whole system. To deal with stragglers, we can use asynchronous algorithms, in which workers are not required to be synchronized. The Asynchronous Model Average algorithm provided by Bagua is one such algorithm.
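The following toy calculation makes the cost of lock-step execution concrete. It counts how many gradient steps a cluster completes in a fixed time budget. The function name and the per-worker step times are made up for illustration. The sketch ignores communication time, and it ignores the fact that asynchronous steps are not equivalent to synchronous ones for convergence. It only shows the waiting effect.

```python
def steps_completed(step_times, budget):
    """Hypothetical helper: total gradient steps finished within `budget` seconds.

    Synchronous (lock-step): every iteration lasts as long as the slowest worker.
    Asynchronous: each worker proceeds at its own pace and never waits.
    Communication cost is ignored in both cases (simplifying assumption).
    """
    n = len(step_times)
    slowest = max(step_times)
    sync_total = n * int(budget // slowest)
    async_total = sum(int(budget // t) for t in step_times)
    return sync_total, async_total


# Made-up cluster: three workers at 1 s/step and one straggler at 4 s/step.
sync_steps, async_steps = steps_completed([1.0, 1.0, 1.0, 4.0], budget=60.0)
print(sync_steps, async_steps)  # prints: 60 195
```

With a single straggler, the three fast workers in the synchronous case spend three quarters of their time waiting.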


The Asynchronous Model Average algorithm can be described as follows:

Suppose there are $n$ workers. Every worker maintains a local copy of the model: the $i$-th worker maintains $\mathbf{x}^{(i)}$. Every worker runs two threads in parallel. One thread computes gradients (the computation thread) and the other communicates (the communication thread). For each worker $i$, a lock $m_i$ controls access to its model.

The computation thread on the $i$-th worker repeats the following steps:

  1. Acquire lock $m_i$.
  2. Calculate a local gradient $\mathbf{g}^{(i)}$ of the model $\mathbf{x}^{(i)}$ on a batch of input data.
  3. Release lock $m_i$.
  4. Update the model with the local gradient: $\mathbf{x}^{(i)} \leftarrow \mathbf{x}^{(i)} - \gamma \mathbf{g}^{(i)}$, where $\gamma$ is the learning rate.

The communication thread on the $i$-th worker repeats the following steps:

  1. Acquire lock $m_i$.
  2. Average the local model with all other workers' models: $\mathbf{x}^{(i)} \leftarrow \frac{1}{n} \sum_{j=1}^{n} \mathbf{x}^{(j)}$.
  3. Release lock $m_i$.

Every worker runs the two threads independently and concurrently.
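The two-thread scheme above can be sketched as a toy, single-process simulation. It is not Bagua's implementation. It makes the following assumptions:

* Models are scalars kept in a shared Python list.
* Worker $i$ minimizes the hypothetical loss $\frac{1}{2}(x - t_i)^2$, so its gradient is $x - t_i$.
* Reading the other workers' list entries stands in for the real collective communication.
* Different sleep times play the role of stragglers.

The function name and its parameters are hypothetical.

```python
import random
import threading
import time


def run_async_model_average(targets, lr=0.1, steps=200, seed=0):
    """Toy simulation of asynchronous model averaging (hypothetical helper).

    Worker i minimizes 0.5 * (x - targets[i]) ** 2, so its gradient is
    x - targets[i]. Reading other workers' entries of `models` stands in
    for real collective communication (simplifying assumption).
    """
    n = len(targets)
    rng = random.Random(seed)
    models = [rng.uniform(min(targets), max(targets)) for _ in range(n)]
    locks = [threading.Lock() for _ in range(n)]  # lock m_i guards models[i]
    stop = threading.Event()                      # tells communication threads to exit

    def computation(i):
        for _ in range(steps):
            with locks[i]:                        # 1. acquire m_i
                grad = models[i] - targets[i]     # 2. local gradient g_i
            # 3. m_i is released on leaving the `with` block
            models[i] -= lr * grad                # 4. x_i <- x_i - lr * g_i
            time.sleep(0.001 * (i + 1))           # workers run at different speeds

    def communication(i):
        while not stop.is_set():
            with locks[i]:                        # 1. acquire m_i
                models[i] = sum(models) / n       # 2. average with all workers
            # 3. m_i is released on leaving the `with` block
            time.sleep(0.002)

    comp = [threading.Thread(target=computation, args=(i,)) for i in range(n)]
    comm = [threading.Thread(target=communication, args=(i,)) for i in range(n)]
    for t in comp + comm:
        t.start()
    for t in comp:
        t.join()
    stop.set()  # training is done: stop the communication threads explicitly
    for t in comm:
        t.join()
    return models


if __name__ == "__main__":
    print(run_async_model_average([0.0, 1.0, 2.0, 3.0]))
```

No computation thread ever waits for another worker's computation thread. The slowest worker only affects how stale the models it contributes to the average are.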

Example usage

First, initialize the algorithm (see the API documentation for more options):

from bagua.torch_api.algorithms import async_model_average
algorithm = async_model_average.AsyncModelAverageAlgorithm()

Then apply the algorithm to the model:

model = model.with_bagua([optimizer], algorithm)

Unlike with synchronous algorithms, you need to stop the communication thread explicitly when training is done (for example, before running a test):


To resume the communication thread when you start training again, do:


A complete example of running the Asynchronous Model Average algorithm can be found in the Bagua examples. Run it with the --algorithm async command-line argument.
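The need to stop the communication thread before testing follows from the design above. While that thread runs, it keeps overwriting the local model with averages, so an evaluation could see parameters that change mid-test. The generic Python sketch below uses a hypothetical class, not Bagua's API. It shows a background loop that can be stopped and restarted, and its state stays frozen while it is stopped.

```python
import threading
import time


class BackgroundLoop:
    """Hypothetical stand-in for a background communication thread.

    Not Bagua's implementation: it only shows why such a thread must be
    stopped explicitly before evaluation and started again for training.
    """

    def __init__(self, interval=0.001):
        self.interval = interval
        self.rounds = 0                    # rounds performed so far
        self._running = threading.Event()
        self._thread = None

    def _loop(self):
        while self._running.is_set():
            self.rounds += 1               # placeholder for one averaging round
            time.sleep(self.interval)

    def start(self):
        if self._thread is None or not self._thread.is_alive():
            self._running.set()
            self._thread = threading.Thread(target=self._loop, daemon=True)
            self._thread.start()

    def stop(self):
        self._running.clear()
        if self._thread is not None:
            self._thread.join()            # wait until the loop has exited


loop = BackgroundLoop()
loop.start()          # training: communication runs in the background
time.sleep(0.05)
loop.stop()           # evaluation: state no longer changes
frozen = loop.rounds
time.sleep(0.02)
assert loop.rounds == frozen
loop.start()          # training again: restart the background loop
time.sleep(0.05)
loop.stop()
```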