[source code analysis] PyTorch distributed (11) -- Construction of distributeddataparallel Reducer and Join operation


0x00 summary

Since the previous article analyzed the member variables of the Reducer, this article turns to its dynamic logic. The goal is to connect the previous articles and lay the groundwork for the upcoming analysis of forward and backward propagation.

Other articles in this series are as follows:

Automatic differentiation of deep learning tools (1)

Automatic differentiation of deep learning tools (2)

[Source code analysis] automatic differentiation of deep learning tools (3) -- example interpretation

[Source code analysis] how PyTorch implements forward propagation (1) -- basic class (I)

[Source code analysis] how PyTorch implements forward propagation (2) -- basic classes (Part 2)

[Source code analysis] how PyTorch implements forward propagation (3) -- specific implementation

[Source code analysis] how PyTorch implements backward propagation (1) -- call engine

[Source code analysis] how PyTorch implements backward propagation (2) -- engine static structure

[Source code analysis] how PyTorch implements backward propagation (3) -- engine dynamic logic

[Source code analysis] how PyTorch implements backward propagation (4) -- specific algorithm

[Source code analysis] PyTorch distributed (1) -- history and overview

[Source code analysis] PyTorch distributed (2) -- dataparallel (Part 1)

[Source code analysis] PyTorch distributed (3) -- dataparallel (Part 2)

[Source code analysis] PyTorch distributed (4) -- basic concept of distributed application

[Source code analysis] PyTorch distributed (5) -- overview of distributeddataparallel & how to use

[Source code analysis] PyTorch distributed (6) -- distributeddataparallel -- initialization & store

[Source code analysis] PyTorch distributed (7) -- process group of distributeddataparallel

[Source code analysis] PyTorch distributed (8) -- distributed dataparallel

[Source code analysis] PyTorch distributed (9) -- initialization of distributeddataparallel

[Source code analysis] PyTorch distributed (10) -- Reducer static architecture of distributeddataparallel

0x01 introduction

To make the analysis easier to follow, let's first look at how the Reducer is invoked.

1.1 call

The Reducer is created in _ddp_init_helper, as follows:

        # Note: reverse list of buckets because we want to approximate the
        # order in which their gradients are produced, and assume they
        # are used in the forward pass in the order they are defined.
        self.reducer = dist.Reducer(
            parameters,  # parameters[0] is a list of tensors
            list(reversed(bucket_indices)),  # bucket information
            self.process_group,
            expect_sparse_gradient,
            self.bucket_bytes_cap,
            self.find_unused_parameters,
            self.gradient_as_bucket_view,
            param_to_name_mapping,
        )

1.2 parameter description

The arguments are as follows. parameters[0] holds the parameters of the local model replica. Because each process holds only one replica, only element [0] is meaningful. parameters[0] itself contains 20 elements:

parameters = {list: 1}
0 = {list: 20}
 0 = {Parameter: 10} Parameter containing:\ntensor([[-4.0381e-02,  3.8828e-02, 1  )
 1 = {Parameter: 10} Parameter containing:\ntensor([-0.0438, -0.2033,  0.2771,  0.0721,  )
 2 = {Parameter: 5} Parameter containing:\ntensor([[-0.0094, -0.1319,  0.0713,  0.3155,  )
 3 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 ...
 19 = {Parameter: 5} Parameter containing:\ntensor([-0.0008,  0.0582, -0.1245, -0.2538, )
 __len__ = {int} 20
__len__ = {int} 1

An example of bucket_indices is shown below.

Every tensor is given an index, increasing from 0 to tensors.size() - 1. If the model's parameters contain 20 tensors in total, the tensor indices 0 to 19 are divided into six buckets. Across the six buckets, each tensor index appears exactly once and no index is repeated.

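+-----------------------------------------------------------------------+
|                                                                       |
|  <tensor index 0, tensor index 1, tensor index 2, tensor index 3>     |
|                                                                       |
|                                                                       |
|  <tensor index 4, tensor index 5, tensor index 6>                     |
|                                                                       |
|                                                                       |
|  ......                                                               |
|                                                                       |
|                                                                       |
|  <tensor index 16, tensor index 17, tensor index 18, tensor index 19> |
|                                                                       |
+-----------------------------------------------------------------------+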
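
The following Python sketch shows one way such a partition could be produced and then reversed before it is handed to the Reducer. It is a simplified illustration and not DDP's actual algorithm. assign_buckets and its greedy byte-cap rule are hypothetical. The real assignment is computed before the Reducer is built, and it also separates tensors by device and dtype and gives sparse-gradient tensors their own bucket.

```python
from typing import List


def assign_buckets(param_bytes: List[int], cap_bytes: int) -> List[List[int]]:
    """Greedily pack consecutive tensor indices into buckets of at most cap_bytes.

    Hypothetical simplification: real DDP also groups by device/dtype and
    gives tensors that expect sparse gradients their own bucket.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for index, size in enumerate(param_bytes):
        # Close the current bucket if this tensor would overflow it.
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(index)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


# 20 parameter tensors of 100 bytes each, with a 400-byte cap per bucket.
bucket_indices = assign_buckets([100] * 20, cap_bytes=400)
# DDP hands the buckets to the Reducer in reverse order, approximating the
# order in which gradients become ready during backward.
reducer_buckets = list(reversed(bucket_indices))
print(reducer_buckets[0])  # [16, 17, 18, 19]
```

The exact policy can differ, but the invariant is the one stated in the text: every tensor index lands in exactly one bucket.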

Next, let's see how to initialize Reducer.

0x02 Reducer initialization

The code is located in torch/lib/c10d/reducer.h and torch/lib/c10d/reducer.cpp

2.1 constructor

The specific logic is as follows:

  • Check whether this module is a multi-device module. The constructor iterates over the tensors, gets each tensor's device, and inserts the device index into a set. If the set ends up with more than one device, the module is a multi-device module.
  • If expect_sparse_gradients is not specified, initialize expect_sparse_gradients_ so that no parameter expects a sparse gradient (all false).
  • Call initialize_buckets to initialize the buckets. Parameters are assigned to buckets in roughly reverse order, so that communicating bucket by bucket is efficient. The buckets may be reinitialized later at run time.
  • Register hooks on each parameter's grad_accumulator. These hooks perform gradient synchronization during backward.
    • These variables are leaf tensors of the autograd graph, so their gradients are handled by a gradient accumulation function.
    • The Reducer saves pointers to these functions so it can check whether they are used in an autograd pass. If they are not used, the corresponding gradient tensors can be marked as ready for reduction.
    • Iterate over the tensors and create a VariableIndex for each tensor.
    • Get grad_accumulator_ from Variable::AutogradMeta. This is the gradient accumulator that accumulates the leaf variable's gradient.
    • Add the Reducer's autograd_hook as a post hook to each grad_accumulator_, with the variable index as the hook's argument. The hook is attached to the autograd graph and runs after grad_accumulator has executed. It handles gradient synchronization during backward.
  • gradAccToVariableMap_ records the mapping from each grad_accumulator to its index, i.e. from function pointer to parameter tensor. This makes it easy to find unused parameters in the autograd graph later. It is only filled when find_unused_parameters is true.
  • Initialize backward_stats_.
  • Call initialize_local_used_map to initialize the maps of locally used parameters (only when find_unused_parameters is true).

// The constructor takes a list of variables for every model replica.
// The bucket assignment for this reducer is specified as a list of
// buckets, each of which is specified as a list of indices into the
// variables list for **a single replica** (i.e. `variables[0]`).
Reducer::Reducer(
    std::vector<std::vector<at::Tensor>> replicas, // tensors
    std::vector<std::vector<size_t>> bucket_indices, // bucket information
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    std::vector<std::vector<bool>> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> paramNames)
    : replicas_(std::move(replicas)),
      // ... other member initializers omitted, including process_group_,
      // expect_sparse_gradients_, find_unused_parameters_ and
      // gradient_as_bucket_view_ ...
      param_names_(std::move(paramNames)) {

  // Check whether the module is multi_device_module
  {
    std::set<int> unique_devices;
    for (const auto& v : replicas_[0]) { // iterate over the tensors
      auto device_idx = int(v.device().index()); // device of this tensor
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx); // record the device in a set
        if (unique_devices.size() > 1) { // more than one device: multi-device module
          is_multi_device_module_ = true;
          break;
        }
      }
    }
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector<std::vector<bool>>(
        replicas_.size(), std::vector<bool>(replicas_[0].size(), false));
  }

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    initialize_buckets(std::move(bucket_indices)); // initialize the buckets
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto replica_count = replicas_.size(); // in practice a single replica
    grad_accumulators_.resize(replica_count);
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      const auto variable_count = replicas_[replica_index].size(); // number of tensors
      grad_accumulators_[replica_index].resize(variable_count); // one slot per tensor
      for (size_t variable_index = 0; variable_index < variable_count;
           variable_index++) { // variable_index is the index of the tensor
        auto& variable = replicas_[replica_index][variable_index]; // the tensor itself
        const auto index = VariableIndex(replica_index, variable_index); // one VariableIndex per tensor

        // The gradient accumulator function is lazily initialized once.
        // Therefore we can use its presence in the autograd graph as
        // evidence that the parameter has participated in an iteration.
        // This is grad_accumulator_ of Variable::AutogradMeta, which
        // accumulates the gradient of this leaf variable.
        auto grad_accumulator =
            torch::autograd::impl::grad_accumulator(variable);

#ifndef _WIN32
        using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
        // Hook to execute after the gradient accumulator has executed.
        // It lives in the autograd graph and drives gradient synchronization
        // during backward: once grad_accumulator finishes, autograd_hook runs.
        hooks_.emplace_back(
            grad_accumulator->add_post_hook(
                torch::make_unique<torch::autograd::utils::LambdaPostHook>(
                    [=](const torch::autograd::variable_list& outputs,
                        const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                      this->rpc_context_.set(
                          ThreadLocalDistAutogradContext::getContextPtr());
#endif
                      this->autograd_hook(index); // run the Reducer's autograd_hook
                      return outputs;
                    })),
            grad_accumulator);

        // Map raw function pointer to replica index and parameter index.
        // This is used later on when the autograd graph is traversed
        // to check for parameters for which no gradient is computed, if
        // find_unused_parameters=True.
        // Note that the mapping of gradient accumulator to variable should be
        // one to one as we deduplicate shared parameters before constructing
        // Reducer.
        if (find_unused_parameters_) {
          gradAccToVariableMap_[grad_accumulator.get()] = index;
        }

        numGradHooksTriggeredMap_[index] = 0;

        // The gradient accumulator is stored as weak_ptr in the autograd
        // metadata of the variable, so we have to keep it alive here for
        // the raw pointer to be valid.
        TORCH_CHECK(
            grad_accumulators_[replica_index][variable_index] == nullptr,
            c10::str(
                "Reducer tried to register duplicate grad accumulator for replica ",
                replica_index,
                " variable ",
                variable_index));
        grad_accumulators_[replica_index][variable_index] =
            std::move(grad_accumulator);
      }
    }
  }

  // Initialize backward stats vector.
  {
    const auto replica_count = replicas_.size();
    backward_stats_.resize(replica_count);
    const auto variable_count = replicas_[0].size();
    std::for_each(
        backward_stats_.begin(),
        backward_stats_.end(),
        [=](std::vector<int64_t>& v) { v.resize(variable_count); });
  }

  // See Note [Skip allreducing local_used_maps_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map(); // initialize the local used maps
  }
}
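
The hook wiring in the constructor can be modeled in a few lines of Python. This is a toy sketch resting on several assumptions. ToyAccumulateGrad stands in for autograd's AccumulateGrad node, ToyReducer is hypothetical, devices are plain integers, and gradients are single floats. The sketch covers the multi-device check, one post hook per (replica_index, variable_index), the optional accumulator-to-index map, and the per-variable hook counter.

```python
from typing import Callable, Dict, List, Optional, Tuple

VariableIndex = Tuple[int, int]  # (replica_index, variable_index)


class ToyAccumulateGrad:
    """Hypothetical stand-in for autograd's AccumulateGrad node."""

    def __init__(self) -> None:
        self.grad: Optional[float] = None
        self.post_hooks: List[Callable[[], None]] = []

    def add_post_hook(self, hook: Callable[[], None]) -> None:
        self.post_hooks.append(hook)

    def __call__(self, incoming: float) -> None:
        # Accumulate first, then run post hooks (where autograd_hook fires).
        self.grad = incoming if self.grad is None else self.grad + incoming
        for hook in self.post_hooks:
            hook()


class ToyReducer:
    def __init__(self, replicas: List[List[int]], find_unused_parameters: bool) -> None:
        # replicas[r][v] is the device index of variable v in replica r.
        # Multi-device check: more than one distinct device in replica 0.
        self.is_multi_device_module = len(set(replicas[0])) > 1
        self.grad_accumulators: List[List[ToyAccumulateGrad]] = []
        self.grad_acc_to_variable: Dict[int, VariableIndex] = {}
        self.num_hooks_triggered: Dict[VariableIndex, int] = {}
        self.ready: List[VariableIndex] = []
        for replica_index, variables in enumerate(replicas):
            accumulators = []
            for variable_index in range(len(variables)):
                index = (replica_index, variable_index)
                acc = ToyAccumulateGrad()
                # Bind index by value, like the [=] capture of the C++ lambda.
                acc.add_post_hook(lambda index=index: self.autograd_hook(index))
                if find_unused_parameters:
                    # id() plays the role of the raw pointer grad_accumulator.get().
                    self.grad_acc_to_variable[id(acc)] = index
                self.num_hooks_triggered[index] = 0
                accumulators.append(acc)  # keep each accumulator alive
            self.grad_accumulators.append(accumulators)

    def autograd_hook(self, index: VariableIndex) -> None:
        # In the real Reducer this hook drives gradient synchronization;
        # here it only records that it was called.
        self.num_hooks_triggered[index] += 1
        self.ready.append(index)


reducer = ToyReducer([[0, 0, 0]], find_unused_parameters=True)
reducer.grad_accumulators[0][2](1.5)  # backward reaches parameter 2
print(reducer.ready)  # [(0, 2)]
```

The sketch keeps every accumulator alive in grad_accumulators for the same reason the C++ code does. The map is keyed by a raw identity (here id(), there a raw pointer), and that key is only meaningful while the object exists.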

Next, we analyze each part in detail.

2.2 initializing buckets

The initialize_buckets method initializes the buckets. For each bucket it creates one BucketReplica per model replica, and for each BucketReplica it builds a list of tensors:

  • Set rpc_context_ from the distributed autograd context.

    • If initialize_buckets is called inside the DDP constructor, it does not matter whether the rpc context pointer is null, because grad will not be mutated.
    • If initialize_buckets is called during the training loop, for example inside rebuild_buckets, grad may be mutated to point to a bucket_view. In that case the function must check whether the rpc context pointer is null.
    • If the rpc context pointer is null, variable.grad() is mutated. Otherwise, the grad stored in the rpc context is mutated.
  • Clear buckets_ and variable_locators_.

  • Resize variable_locators_ so that every variable has a bucket index.

  • Get the number of buckets and the number of replicas: bucket_count = bucket_indices.size(); replica_count = replicas_.size();

  • Iterate bucket_index from 0 to bucket_count and initialize the buckets one by one:

    • Create a Bucket.
    • If bucket_indices[bucket_index].size() == 1, the bucket holds a single variable, which may expect a sparse gradient. Set bucket.expect_sparse_gradient to expect_sparse_gradients_[0][variable_index] for that variable. A bucket with more than one variable must not contain any variable that expects a sparse gradient.
    • Iterate replica_index from 0 to replica_count and initialize the BucketReplicas one by one:
      • Create a BucketReplica.
      • If the bucket expects a sparse gradient:
        • Take the first element of bucket_indices[bucket_index] with front() and use it as variable_index.
        • Use variable_index to get the corresponding variable in the replica.
        • Set the replica's variable list with replica.variables = {variable}. This replica contains only one variable.
      • Otherwise the gradient is dense:
        • Iterate over the bucket's variables, getting each one with replicas_[replica_index][variable_index].
        • Record the device and dtype of the variables. All variables in a bucket must share the same device and dtype.
        • Append each variable to the replica with replica.variables.push_back(variable).
        • Record per-variable metadata about the flattened contents. For example, offsets stores the offset of each tensor inside the flat bucket contents.
        • Allocate memory for replica.contents.
        • Call initialize_bucket_views(replica, replica.contents) to initialize the contents and views.
      • In both cases, add the replica to the bucket with bucket.replicas.push_back(std::move(replica)).
    • Iterate over the variables in the bucket, i.e. bucket_indices[bucket_index]:
      • Set Reducer.variable_locators_ so the Reducer can locate each variable inside a bucket. bucket_index is the position in the buckets_ list and identifies one bucket there. intra_bucket_index is the variable's position in the bucket replica's variables vector.
    • Set the bucket's variables with bucket.variable_indices = std::move(bucket_indices[bucket_index]);
    • Add the bucket to the Reducer with buckets_.push_back(std::move(bucket)).

The specific code is:

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
      !expect_autograd_hooks_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(replicas_[0].size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  const auto replica_count = replicas_.size();
  buckets_.reserve(bucket_count);
  // bucket_index runs from 0 to bucket_count
  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
    Bucket bucket; // create a bucket

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    TORCH_CHECK(
        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      // Single-variable bucket: it may hold a variable expecting a sparse gradient
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient =
          expect_sparse_gradients_[0][variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_CHECK(
            !expect_sparse_gradients_[0][variable_index],
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    // Iterate over model replicas (0 to replica_count); every replica is
    // set up the same way.
    for (size_t replica_index = 0; replica_index < replica_count;
         replica_index++) {
      BucketReplica replica; // one BucketReplica per model replica

      if (bucket.expect_sparse_gradient) {
        // A sparse-gradient bucket holds exactly one variable
        const auto variable_index = bucket_indices[bucket_index].front(); // index of the tensor
        const auto& variable = replicas_[replica_index][variable_index]; // the tensor
        TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
        replica.variables = {variable}; // this replica contains only one variable
      } else {
        at::TensorOptions options;
        // The start index of the variable in the flattened tensor.
        size_t offset = 0;

        // Reserve enough space for the per-variable fields stored in bucket
        // replica for efficiency.
        const size_t num_variables = bucket_indices[bucket_index].size();
        replica.variables.reserve(num_variables);
        replica.offsets.reserve(num_variables);
        replica.lengths.reserve(num_variables);

        // Iterate over bucket variables.
        for (const auto variable_index : bucket_indices[bucket_index]) {
          TORCH_CHECK(
              variable_index < replicas_[replica_index].size(),
              "Out of range variable index specified.");
          const auto& variable = replicas_[replica_index][variable_index];
          if (!options.has_device()) {
            options = options.device(variable.device());
          } else {
            TORCH_CHECK(
                variable.device() == options.device(),
                "All parameters in a bucket must be ",
                "placed on the same device.");
          }
          if (!options.has_dtype()) {
            options = options.dtype(variable.dtype());
          } else {
            TORCH_CHECK(
                variable.dtype() == options.dtype(),
                "All parameters in a bucket must have the same dtype.");
          }
          const auto length = variable.numel();
          // Add the variable to the replica; the final count is the number
          // of variables in the bucket.
          replica.variables.push_back(variable);
          // Record per-variable metadata about the flattened contents.
          replica.offsets.push_back(offset);
          replica.lengths.push_back(length);
          offset += length;
        }

        // Allocate bucket contents tensor.
        replica.contents = at::empty({static_cast<long>(offset)}, options);

        initialize_bucket_views(replica, replica.contents); // initialize contents and views
      }

      // Add bucket replica to enclosing bucket.
      bucket.replicas.push_back(std::move(replica));
    }

    // Map participating variables to this bucket.
    // This is identical across replicas so we only need to do this once.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) {
      TORCH_CHECK(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      // Lets the Reducer locate this variable inside its bucket.
      variable_locators_[variable_index] =
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket)); // add the bucket to the Reducer
  }
}
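
To see what initialize_buckets produces, here is a simplified Python model of its bookkeeping. It assumes a single model replica, represents variables only by their shapes, and uses a plain list of floats in place of the flat contents tensor. ToyBucket, ToyReplica and this initialize_buckets function are illustrative and are not PyTorch APIs. The model reproduces the sparse-gradient rule, the offsets/lengths layout, and variable_locators_ as (bucket_index, intra_bucket_index) pairs.

```python
from dataclasses import dataclass, field
from math import prod
from typing import Dict, List, Sequence, Tuple


@dataclass
class ToyReplica:
    variables: List[int] = field(default_factory=list)   # variable indices
    offsets: List[int] = field(default_factory=list)     # start in flat contents
    lengths: List[int] = field(default_factory=list)     # numel of each variable
    contents: List[float] = field(default_factory=list)  # flattened storage


@dataclass
class ToyBucket:
    variable_indices: List[int]
    expect_sparse_gradient: bool = False
    replicas: List[ToyReplica] = field(default_factory=list)


def initialize_buckets(
    shapes: Sequence[Tuple[int, ...]],
    bucket_indices: List[List[int]],
    expect_sparse: Sequence[bool],
) -> Tuple[List[ToyBucket], Dict[int, Tuple[int, int]]]:
    buckets: List[ToyBucket] = []
    # variable index -> (bucket_index, intra_bucket_index), like variable_locators_
    locators: Dict[int, Tuple[int, int]] = {}
    for bucket_index, indices in enumerate(bucket_indices):
        if not indices:
            raise ValueError("Empty bucket specified.")
        bucket = ToyBucket(variable_indices=list(indices))
        # Variables that expect sparse gradients must have their own bucket.
        if len(indices) == 1:
            bucket.expect_sparse_gradient = expect_sparse[indices[0]]
        elif any(expect_sparse[v] for v in indices):
            raise ValueError("Buckets with more than one variable cannot include "
                             "variables that expect a sparse gradient.")
        replica = ToyReplica()  # single replica assumed
        if bucket.expect_sparse_gradient:
            replica.variables = [indices[0]]  # no flat contents for sparse grads
        else:
            offset = 0
            for v in indices:
                length = prod(shapes[v])
                replica.variables.append(v)
                replica.offsets.append(offset)
                replica.lengths.append(length)
                offset += length
            replica.contents = [0.0] * offset  # like at::empty({offset}, options)
        bucket.replicas.append(replica)
        for intra_bucket_index, v in enumerate(indices):
            locators[v] = (bucket_index, intra_bucket_index)
        buckets.append(bucket)
    return buckets, locators


# Two Linear layers: weight (10, 4), bias (10,), weight (5, 10), bias (5,).
shapes = [(10, 4), (10,), (5, 10), (5,)]
buckets, locators = initialize_buckets(shapes, [[2, 3], [0, 1]], [False] * 4)
print(buckets[0].replicas[0].offsets, locators[3])  # [0, 50] (0, 1)
```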

2.3 initializing views

initialize_bucket_views sets up the contents and views of a BucketReplica.

// (see Note:  "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    if (v.is_non_overlapping_and_dense()) { // dense tensor
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      replica.bucket_views_in.push_back( // bucket_views_in holds the views
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else { // non-dense tensor (overlapping, or with gaps in memory)
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      replica.bucket_views_in.push_back(
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    replica.bucket_views_out = replica.bucket_views_in; // out is also a view

    // If gradient_as_bucket_view_ is set as true, then there are two cases to
    // handle: initialize_bucket_views could be called inside initialize_buckets
    // when rebuild_buckets, if grad has already been defined/calculated in
    // previous iteration, old grad needs to be copied into new bucket_view and
    // let grad point to the new bucket_view, initialize_bucket_views could also
    // be called inside initialize_buckets during construction. Grads are not
    // defined during construction time, in this case, do not let grad point to
    // bucket_view, because grads should be kept as being undefined for globally
    // unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = replica.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad); // copy the old gradient into the new view
          grad = bucket_view; // let grad point to the bucket view
          // The grad is modified and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false;
      });
    }
  }
}

2.3.1 BucketReplica member variables

Let's first recall several member variables of BucketReplica.

  • at::Tensor contents: the flattened (1-dimensional) contents of the bucket.
  • std::vector<at::Tensor> bucket_views_in: views into contents, seen from the input side, through which each specific gradient can be accessed.
  • std::vector<at::Tensor> bucket_views_out: views into contents, seen from the output side, through which each specific gradient can be accessed.

Some more detail on bucket_views_in and bucket_views_out:

  • These two variables provide views through which the gradient of each tensor inside contents can be manipulated. They serve as the entry points for moving each gradient's data into and out of contents.
  • In PyTorch, a view is a convenient way of looking at existing data. A view shares memory with the original data. It only rearranges that data, exposing part of it directly or after reordering.
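
The key property of a view, namely that it shares memory with the data it looks at, can be shown without PyTorch. The sketch below uses Python's memoryview over an array of doubles as a stand-in for bucket contents. This substitution is only for illustration, since real buckets are at::Tensor objects. Slicing plays the role of narrow in the code above, and casting to a 2-D shape plays the role of view.

```python
from array import array

# Flat "bucket contents" for a 2x2 weight followed by a 2-element bias.
contents = array("d", [0.0] * 6)
flat = memoryview(contents)

# Slicing ~ narrow(0, offset, length); memoryview can only reshape through a
# byte format, hence cast("B") before cast("d", shape=...) ~ .view(sizes).
weight_view = flat[0:4].cast("B").cast("d", shape=[2, 2])
bias_view = flat[4:6]

# No data is copied: writes through one view are visible everywhere.
flat[1] = 3.0
bias_view[0] = 7.0
print(weight_view.tolist())  # [[0.0, 3.0], [0.0, 0.0]]
print(contents.tolist())     # [0.0, 3.0, 0.0, 0.0, 7.0, 0.0]
```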

Two PyTorch functions also need explaining.

  • as_strided: creates a view of an existing tensor with the given sizes, strides and storage offset. The result is still a tensor, and because it is a view it shares memory with the original tensor.
  • narrow: returns a narrowed version of the original tensor along one dimension. The result also shares memory with the original tensor.
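
The difference between the two branches of initialize_bucket_views comes down to index arithmetic. The Python sketch below uses hypothetical helper functions to compute which flat positions of contents each element of a view would read. as_strided with the parameter's own strides reproduces a dense but permuted layout such as a transposed weight. narrow(...).view(...) always uses C-contiguous strides. covers_block_exactly_once is a brute-force stand-in for is_non_overlapping_and_dense and does not reflect how PyTorch implements that check.

```python
from itertools import product
from typing import List, Sequence


def strided_positions(sizes: Sequence[int], strides: Sequence[int], offset: int) -> List[int]:
    """Flat storage position read by each element (in row-major index order)
    of a view like contents.as_strided(sizes, strides, offset)."""
    return [offset + sum(i * s for i, s in zip(idx, strides))
            for idx in product(*(range(n) for n in sizes))]


def contiguous_strides(sizes: Sequence[int]) -> List[int]:
    """C-contiguous strides, as used by narrow(0, offset, length).view(sizes)."""
    strides, step = [], 1
    for n in reversed(sizes):
        strides.append(step)
        step *= n
    return strides[::-1]


def covers_block_exactly_once(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    """Brute-force stand-in for is_non_overlapping_and_dense()."""
    positions = strided_positions(sizes, strides, 0)
    return sorted(positions) == list(range(len(positions)))


# A 2x3 parameter that is the transpose of a contiguous 3x2 tensor: strides (1, 2).
print(strided_positions((2, 3), (1, 2), 10))                      # [10, 12, 14, 11, 13, 15]
print(strided_positions((2, 3), contiguous_strides((2, 3)), 10))  # [10, 11, 12, 13, 14, 15]
```

For the transposed parameter, the as_strided view keeps the gradient in the same memory layout as the parameter. This is what the source comment means when it anticipates that AccumulateGrad will create gradients matching the parameter's layout.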

The BucketReplica logic is shown in the following figure:

+------------------------------------------+
| BucketReplica                            |
|                                          |
|       vector<Tensor> bucket_views_in +--------------------+
|                                          |                |
|                                          |                |
|       vector<Tensor> bucket_views_out +--------------+    |
|                                          |           |    |
|                                          |           |    |
|                                          |           v    v
|                                          |     +-----+----+--------------------------+
|       Tensor contents  +---------------------> |Flattened (Tensor1, Tensor2, Tensor3)|
|                                          |     +-------------------------------------+
|                                          |
|                                          |
|       vector<Tensor> variables  +------------>  [Tensor1,Tensor2,Tensor3]
|                                          |
|                                          |
|                                          |
+------------------------------------------+
2.3.2 calling

How is it called? If gradient_as_bucket_view_ is set to true, two cases need to be handled:

  • initialize_bucket_views can be called inside initialize_buckets during rebuild_buckets. If grad was already defined/computed in the previous iteration, the old grad must be copied into the new bucket_view, and grad must be pointed at the new bucket_view.
  • initialize_bucket_views can also be called inside initialize_buckets during construction. Gradients are not defined at construction time. In this case grad must not be pointed at bucket_view, because the gradients of globally unused parameters should remain undefined.
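
The two cases can be captured in a small Python model. ToyParam and point_grad_at_bucket_view are hypothetical. A Python list stands in for both the grad and the bucket view, and object identity (is) stands in for is_alias_of. In the real code, this logic runs inside runGradCallbackForVariable. The return value mirrors the lambda's: True means the grad was modified and must be written back.

```python
from typing import List, Optional


class ToyParam:
    """Hypothetical parameter; grad is None while undefined."""

    def __init__(self, grad: Optional[List[float]] = None) -> None:
        self.grad = grad


def point_grad_at_bucket_view(param: ToyParam, bucket_view: List[float]) -> bool:
    """Returns True if param.grad was modified and must be written back."""
    grad = param.grad
    if grad is not None and grad is not bucket_view:  # ~ !grad.is_alias_of(bucket_view)
        bucket_view[:] = grad     # ~ bucket_view.copy_(grad): keep the old values
        param.grad = bucket_view  # ~ grad = bucket_view: grad now aliases the bucket
        return True
    # Undefined grad (construction time): leave it undefined so that globally
    # unused parameters still end up without a gradient.
    return False


rebuilt = ToyParam([1.0, 2.0])  # grad computed in an earlier iteration
view = [0.0, 0.0]
print(point_grad_at_bucket_view(rebuilt, view), view)  # True [1.0, 2.0]
fresh = ToyParam()  # construction: no grad yet
print(point_grad_at_bucket_view(fresh, [0.0]), fresh.grad)  # False None
```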

2.4 initializing the local used map

initialize_local_used_map initializes local_used_maps_. Recall what the paper says about local_used_maps_, which is used to find globally unused parameters:

The gradients of globally unused parameters should stay intact during the forward and backward passes. Detecting unused parameters requires global information, because a parameter may be absent from one DDP process in an iteration yet participate in training in the same iteration of another process. Therefore, DDP keeps locally unused parameter information in a bitmap and launches an additional AllReduce to gather a global bitmap. The bitmap is much smaller than the tensors, so instead of creating per-bucket bitmaps, all parameters in the model share the same bitmap. The bitmap lives on the CPU to avoid launching a dedicated CUDA kernel for each update. However, some ProcessGroup backends may not be able to run AllReduce on CPU tensors. For example, ProcessGroupNCCL only supports CUDA tensors. Moreover, since DDP should work with any custom ProcessGroup backend, it cannot assume that all backends support CPU tensors. To solve this problem, DDP maintains a second bitmap on the same device as the first model parameter and issues a non-blocking copy to move the CPU bitmap into the device bitmap for collective communication.

The specific code is as follows:

void Reducer::initialize_local_used_map() {
  const auto replica_count = replicas_.size();
  const auto variable_count = replicas_[0].size();
  local_used_maps_.resize(replica_count);
  local_used_maps_dev_.resize(replica_count);

  for (size_t i = 0; i < replica_count; i++) {
    at::TensorOptions options;
    options = options.dtype(at::kInt);

    // Deliberately don't pin the memory even if local_used_maps_dev_ will
    // be cuda. See Note [local_used_maps_ -> local_used_maps_dev copying]
    local_used_maps_[i] =
        at::zeros({static_cast<long>(variable_count)}, options);

    // This tensor needs to be on the same device as replica because backend
    // such as NCCL may not support CPU tensors, and hence it might not work
    // if we always put it on CPU.
    options = options.device(replicas_[i][0].device());
    local_used_maps_dev_[i] =
        at::empty({static_cast<long>(variable_count)}, options);
  }
}
The initialization process is as follows:

                  rpc_context_ = ThreadLocalDistAutogradContext
                  buckets_ & variable_locators_ (clear & resize)
+----------------------->  from 0 ~ bucket_count :  +--------------------------->
|                                                                                +
|                                                                                |
|      +-------------------------------------------------------------------+     |
|      | init Bucket          set bucket_indices                           |     |
|      |                            +                                      |     |
|      |                            |                                      |     |
|      |                            |                                      |     |
|      |                            v                                      |     |
|      |   ^ +------------> from 0 ~ replica_count : +----------------->   |     |
|      |   |                                                           |   |     |
|      |   |  +---------------------------------------------------+    |   |     |
|      |   |  | init BucketReplica                                |    |   |     |
|      |   |  |                                                   |    |   |     |
<----+ |   +--+                                                   | <--+   | <---+
       |      |    bucket.replicas.push_back(std::move(replica))  |        |
       |      |                                                   |        |
       |      +----------------------+----------------------------+        |
       |                             |                                     |
       |                             |                                     |
       |                             v                                     |
       |             buckets_.push_back(std::move(bucket))                 |
       |                             +                                     |

The resulting Reducer looks roughly like this. Note that each Bucket contains only one BucketReplica:

            +----------------------------------------+                 +------------------+
            |tensor index 4, tensor index 5, tensor 6| <------+        | index 2, index 3 |
            +----------------------------------------+        |        +--------------+---+
                                                              |                       ^
                                                              |                       |
+---------------------------+   +---------------------------------------------------------+
| Reducer                   |   | +----------------------------------+     +------------+ |
|                           |   | |Bucket                     |      |     |Bucket    | | |
|                           |   | |                           +      |     |          | | |
| vector<Bucket> buckets_ +---> | | vector<size_t> variable_indices  |     | indices ++ | |
|                           |   | |                                  |     |            | |
|                           |   | |  vector<BucketReplica> replicas  | ... | replicas   | |
|                           |   | |                         +        |     |   +        | |
|                           |   | |                         |        |     |   |        | |
|                           |   | +----------------------------------+     +------------+ |
|                           |   |                           |                  |          |
+---------------------------+   +---------------------------------------------------------+
                                                            |                  |
                                                            |                  |
                                                            v                  v
                          +---------------------------------------+   +-------------------+
                          |  +----------------------------------+ |   | +---------------+ |
                          |  | BucketReplica                    | |   | | BucketReplica | |
                          |  |                                  | |   | |               | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_in  | |   | |   views_in    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> bucket_views_out | |   | |   views_out   | |
                          |  |                                  | |   | |               | |
                          |  |  Tensor contents                 | |   | |   contents    | |
                          |  |                                  | |   | |               | |
                          |  |  vector<Tensor> variables        | |   | |   variables   | |
                          |  |                     +            | |   | |      +        | |
                          |  +----------------------------------+ |   | +---------------+ |
                          +---------------------------------------+   +-------------------+
                                                   |                           |
                                                   |                           |
                                                   v                           v
                                   +---------------+------------+    +---------+----------+
                                   |Tensor 4, Tensor 5, Tensor 6|    | Tensor 2, Tensor 3 |
                                   +----------------------------+    +--------------------+

0x03 static graph

3.1 reasons

Although PyTorch builds its graph dynamically, users can explicitly tell DDP that the training graph is static. This can be done when:

  1. The set of used and unused parameters does not change during the whole training loop. In this case, it does not matter whether the user sets find_unused_parameters=True.

  2. How the graph is trained does not change during the whole training loop (meaning there is no control flow that depends on the iteration). When the graph is set to static, DDP supports cases it could not support before, such as:

    1. Reentrant backward passes.
    2. Activation checkpointing applied multiple times.
    3. Activation checkpointing together with find_unused_parameters=True.
    4. Not all output tensors are used in the loss calculation.
    5. A model parameter is used outside of the forward function.
    6. Potentially better performance when find_unused_parameters=True or there are unused parameters, because DDP no longer searches the graph in every iteration to detect unused parameters.

3.2 usage

_set_static_graph configures a static graph. It should be called after the DistributedDataParallel is constructed and before the training loop starts, and it must be called in the same way on all ranks. For example:

ddp_model = DistributedDataParallel(model)
ddp_model._set_static_graph()
for i in range(n):
    ...

The code of _set_static_graph is:

def _set_static_graph(self):
    """
    Users can explicitly let DDP know the trained graph is static,
    when 1) the set of used and unused parameters will not change
    during the whole training loop; in this case, it does not matter
    whether users set find_unsued_parameters = true or not.
    2) how the graph is trained will not change during the whole training
    loop (meaning there is no control flow depending on iterations).
    When graph is set to be static, DDP will support cases that can not
    be supported in the past: 1) reentrant backwards
    2) activation checkpointing multiple times 3)
    activation checkpointing with find_unused_parameters = true.
    4) not all output tensors are used in loss calculation.
    5) there is model parameter that is outside of forward function.
    6) potentially improve performance when find_unsued_parameters = true
    or there are unused parameters, as DDP will not search graph in each
    iteraton to detect unused parameters when static_graph is set to be True.

    This API should be called after DistributedDataParallel construction, and
    before training loops starts. Also it should be called in the same way for
    all ranks. For example:
        ddp_model = DistributedDataParallel(model)
        ddp_model._set_static_graph()
        for i in range(n):
            .....
    """
    self.static_graph = True
    self.reducer._set_static_graph()  # Call the Reducer to configure it
    if self.find_unused_parameters:
        warnings.warn(
            "You passed find_unused_parameters=true to DistributedDataParallel, "
            "`_set_static_graph` will detect unused parameters automatically, so "
            "you do not need to set find_unused_parameters=true, just be sure these "
            "unused parameters will not change during training loop while calling "
            "`_set_static_graph`."
        )

3.3 Reducer

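The Reducer can only treat the graph as static after the first iteration: PyTorch is still dynamic after all, so one iteration has to run before the graph exists and can be observed. That is why the flag must be set before training starts, and why the global unused parameters are detected during the first iteration.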

void Reducer::set_static_graph() {
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_CHECK(
      num_iterations_ == 0,
      "set_static_graph() should be called before training loop starts "
      "and after DistributedDataParallel is constructed.");
  static_graph_ = true;
  // when static_graph_ is set as true, always initialize_local_used_map
  // and detect the global unused parameters in the first iteration.
  initialize_local_used_map();
}

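0x04 rebuild buckets

4.1 why rebuild

Because PyTorch builds its computation graph dynamically, the order in which gradients become ready during the backward pass may differ from the order used for the initial bucket assignment, so the buckets are rebuilt accordingly. However, they are rebuilt only once, after the first iteration; if find_unused_parameters_ is set (and the graph is not static), they are never rebuilt.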

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }
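The condition is small enough to check exhaustively. The following Python sketch is a hypothetical stand-in (not a PyTorch API) that copies the boolean expression of should_rebuild_buckets() and enumerates all eight combinations of static_graph_, find_unused_parameters_ and has_rebuilt_bucket_:

```python
from itertools import product


def should_rebuild_buckets(static_graph, find_unused_parameters, has_rebuilt_bucket):
    """Hypothetical Python mirror of Reducer::should_rebuild_buckets()."""
    return (static_graph or not find_unused_parameters) and not has_rebuilt_bucket


def truth_table():
    """Map (static_graph, find_unused, has_rebuilt) -> result for all 8 cases."""
    return {
        flags: should_rebuild_buckets(*flags)
        for flags in product([False, True], repeat=3)
    }


if __name__ == "__main__":
    for (static, find_unused, rebuilt), result in truth_table().items():
        print(f"static_graph={static!s:5} find_unused={find_unused!s:5} "
              f"has_rebuilt={rebuilt!s:5} -> rebuild={result}")
```

Every row with has_rebuilt=True yields False, which is what enforces "rebuild only once"; the only other False row is a dynamic graph with find_unused_parameters=True.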

4.2 preparing for the rebuild

Let's first look at the preparation done before the rebuild.

push_rebuilt_params appends one parameter, together with its index, to the rebuild lists.

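void Reducer::push_rebuilt_params(const VariableIndex& index) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}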

Second, push_rebuilt_params_for_all_indices traverses every replica and pushes each variable in it.

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto replica_count = replicas_.size();
  for (size_t replica_index = 0; replica_index < replica_count;
       ++replica_index) {
    const auto variable_count = replicas_[replica_index].size();
    for (size_t variable_index = 0; variable_index < variable_count;
         ++variable_index) {
      const auto index = VariableIndex(replica_index, variable_index);
      push_rebuilt_params(index);
    }
  }
}

4.3 rebuilding

Let's look at the rebuild mechanism.

DDP rebuilds the buckets from rebuilt_params_ and rebuilt_param_indices_, which record the order in which tensors received their gradients during the backward pass.

rebuild_buckets issues a broadcast communication call that could overlap with the next forward() call, so it could be made asynchronous (the source leaves this as a TODO):

  • For find_unused_parameters=True, the plan is to make the rebuild asynchronous, because buckets could then be rebuilt more than once: subgraphs are trained, and the order of parameter indices may change more frequently.
  • For find_unused_parameters=False, buckets are rebuilt only once, so the performance cost is negligible. rebuild_buckets returns true if the buckets were rebuilt.

The code is as follows:
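bool Reducer::rebuild_buckets() {
  // Ensure reduction for previous backwards pass is finished. If user's model
  // has unused parameters for example, this will raise an error recommending to
  // run with find_unused_parameters=True, instead of the size mismatch
  // exception below.
  std::lock_guard<std::mutex> lock(mutex_);
  ensure_prior_reduction_finished();
  if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
    return false;
  }

  // ... (size consistency checks omitted)
  std::vector<std::vector<size_t>> rebuilt_bucket_indices;
  std::vector<size_t> bucket_size_limits;
  bucket_size_limits.push_back(kDefaultFirstBucketBytes);
  bucket_size_limits.push_back(bucket_bytes_cap_);
  rebuilt_bucket_indices = compute_bucket_assignment_by_size(
      rebuilt_params_,
      bucket_size_limits,
      expect_sparse_gradients_[0],
      rebuilt_param_indices_);

  // For rebuilt bucket indices, it needs to be synced across all ranks.
  // Broadcast the newly rebuilt bucket indices from rank 0 in default.
  // After syncing up rebuilt bucket indices, initialize buckets for reducer.
  sync_bucket_indices(rebuilt_bucket_indices);

  has_rebuilt_bucket_ = true; // Rebuild only once
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return true;
}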
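The heart of the rebuild is re-running the bucket assignment over parameters ordered by gradient arrival. The sketch below is a simplified, hypothetical Python version of that idea (assign_buckets_by_size is our name, not PyTorch's): it walks the parameters in the recorded order and closes a bucket once its byte size reaches the current limit, using a separate limit for the first bucket in the spirit of the bucket_size_limits vector above. It ignores the dtype/device grouping and sparse-gradient handling of the real compute_bucket_assignment_by_size, and the sizes are made up.

```python
from typing import List, Sequence


def assign_buckets_by_size(
    param_indices: Sequence[int],
    param_bytes: Sequence[int],
    size_limits: Sequence[int],
) -> List[List[int]]:
    """Group parameter indices into buckets, keeping the given arrival order.

    param_indices: parameter indices in the order their gradients became ready.
    param_bytes:   byte size of each parameter, indexed by parameter index.
    size_limits:   per-bucket byte limits; the last one is reused for all
                   remaining buckets (e.g. [first_bucket_limit, regular_cap]).
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    limit_pos = 0
    for idx in param_indices:
        current.append(idx)
        current_bytes += param_bytes[idx]
        # Close the bucket once it reaches the active limit; it may overshoot
        # by the size of the last tensor added.
        if current_bytes >= size_limits[limit_pos]:
            buckets.append(current)
            current, current_bytes = [], 0
            limit_pos = min(limit_pos + 1, len(size_limits) - 1)
    if current:  # leftover parameters form the final bucket
        buckets.append(current)
    return buckets


if __name__ == "__main__":
    sizes = [4, 4, 8, 8, 16, 16, 16]   # hypothetical bytes per parameter
    arrival = [6, 5, 4, 3, 2, 1, 0]    # hypothetical gradient arrival order
    print(assign_buckets_by_size(arrival, sizes, [16, 32]))
    # [[6], [5, 4], [3, 2, 1, 0]]
```

Since every rank must end up with the same layout, the real code does not trust each rank's local order: as the comment in rebuild_buckets says, the indices are synced across ranks by broadcasting them from rank 0.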

4.4 when to record rebuild information

Rebuild information is recorded only when all of the following hold:

  1. The buckets have not been rebuilt yet (this is the first rebuild).

  2. static_graph_ is true or find_unused_parameters_ is false.

  3. This backward pass needs to run allreduce.

Here, the tensors and their parameter indices are simply appended to rebuilt_params_ and rebuilt_param_indices_ in gradient-arrival order. Then, at the end of finalize_backward(), the buckets are rebuilt from rebuilt_params_ and rebuilt_param_indices_, broadcast, and initialized.

In addition, only the tensors and parameter indices of one replica need to be dumped.

Take mark_variable_ready as an example: it calls push_rebuilt_params(index) to append to these lists.

void Reducer::mark_variable_ready(VariableIndex index) {
  // Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
  // static_graph_ is true or find_unused_parameters_ is false,
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized. Also we only need to dump tensors
  // and parameter indices of one replica.
  if (should_rebuild_buckets()) {
    push_rebuilt_params(index); // Insert into the rebuild lists
  }

  const auto replica_index = index.replica_index;
  const auto variable_index = index.variable_index;

  if (replica_index == 0) {
    // ... (per-iteration bookkeeping omitted in this excerpt)
  }
  backward_stats_[replica_index][variable_index] =
      current_time_in_nanos() - cpu_timer_.backward_compute_start_time;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& replica = bucket.replicas[replica_index];

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(index);
  } else {
    mark_variable_ready_dense(index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // // Record event so that we can wait for all of them.
  // auto& event = replica.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--replica.pending == 0) {
    // Kick off reduction if all replicas for this bucket are ready.
    if (--bucket.pending == 0) {
      mark_bucket_ready(bucket_index.bucket_index);
    }
  }

  // Run finalizer function and kick off reduction for local_used_maps once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {
    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    // The autograd engine uses the default stream when running callbacks, so we
    // pass in the current CUDA stream in case it is not the default.
    const c10::Stream currentStream = get_current_stream();
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      if (should_collect_runtime_stats()) {
        // ... (record timing statistics)
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      this->finalize_backward();
    });
  }
}
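The pending counters and next_bucket_ above mean that a bucket's reduction starts only when every gradient in it is ready, and that buckets are launched strictly in bucket order. The following plain-Python sketch simulates that bookkeeping for a single replica and also records arrival order the way push_rebuilt_params does. ToyReducer and its members are hypothetical; the in-order launch loop is our model of the Reducer's mark_bucket_ready step, whose body is not shown here.

```python
from typing import Dict, List


class ToyReducer:
    """Hypothetical single-replica simulation of Reducer bookkeeping."""

    def __init__(self, buckets: List[List[int]], record_for_rebuild: bool = True):
        self.buckets = buckets
        # variable index -> bucket index (plays the role of variable_locators_)
        self.locator: Dict[int, int] = {
            v: b for b, vars_ in enumerate(buckets) for v in vars_
        }
        self.pending = [len(vars_) for vars_ in buckets]
        self.next_bucket = 0
        self.launched: List[int] = []            # order in which "allreduces" start
        self.rebuilt_param_indices: List[int] = []
        self.record_for_rebuild = record_for_rebuild

    def mark_variable_ready(self, variable_index: int) -> None:
        if self.record_for_rebuild:
            # Like push_rebuilt_params: remember the gradient arrival order.
            self.rebuilt_param_indices.append(variable_index)
        bucket_index = self.locator[variable_index]
        self.pending[bucket_index] -= 1
        if self.pending[bucket_index] == 0:
            self._mark_bucket_ready(bucket_index)

    def _mark_bucket_ready(self, bucket_index: int) -> None:
        # A later bucket that fills early must wait for the earlier ones.
        if bucket_index > self.next_bucket:
            return
        while (self.next_bucket < len(self.buckets)
               and self.pending[self.next_bucket] == 0):
            self.launched.append(self.next_bucket)  # launch this bucket
            self.next_bucket += 1

    def all_buckets_launched(self) -> bool:
        return self.next_bucket == len(self.buckets)


if __name__ == "__main__":
    r = ToyReducer([[3, 2], [1, 0]])
    for v in [1, 0, 3, 2]:
        r.mark_variable_ready(v)
    print(r.launched)               # [0, 1]
    print(r.rebuilt_param_indices)  # [1, 0, 3, 2]
```

Bucket 1 fills first here but is held back until bucket 0 is ready. Once next_bucket equals the bucket count, the real Reducer queues its finalize callback. Rebuilding from the recorded order [1, 0, 3, 2] would put parameters 1 and 0 into the first bucket, so later iterations could launch it as soon as it fills.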

4.5 direct call

_rebuild_buckets can also be called directly. For example, forward() calls it directly before the forward computation; even so, the buckets are rebuilt only once during the whole training period.

def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.num_iterations += 1
            # ... (other bookkeeping omitted)
        if self.ddp_uneven_inputs_config.ddp_join_enabled:
            ones = torch.ones(1, device=self.device)
            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
            if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                # Active ranks schedule an allreduce with zeros, inactive
                # ranks schedule them with 1. If the result != 0 it
                # indicates at least one rank has terminated and we should
                # throw.
                zeros = torch.zeros(1, device=self.device)
                dist.all_reduce(zeros, group=self.process_group)
                should_throw_stop_iteration = zeros.item()
                if should_throw_stop_iteration:
                    raise RuntimeError(
                        "Detected at least one rank that exhausted inputs. Throwing across all ranks."
                    )
            # ... (otherwise the work handle is passed on; omitted)

        # Calling _rebuild_buckets before forward computation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        # Make a direct call here
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():  # direct call
            logging.info("Reducer buckets have been rebuilt in this iteration.")

        # ... (rest of forward omitted)
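The throw_on_early_termination branch above relies on a simple sum: each rank that is still training contributes 0, and each rank that has run out of inputs contributes 1 (from the join loop), so a non-zero allreduce result means some rank has terminated. A minimal stdlib-only simulation, in which the allreduce is replaced by a plain sum over hypothetical per-rank values and the function names are ours:

```python
from typing import Sequence


def simulated_all_reduce_sum(values: Sequence[int]) -> int:
    """Stand-in for dist.all_reduce with SUM: every rank ends up with this total."""
    return sum(values)


def should_throw(active: Sequence[bool]) -> bool:
    """Active ranks contribute 0, exhausted (joined) ranks contribute 1."""
    contributions = [0 if is_active else 1 for is_active in active]
    return simulated_all_reduce_sum(contributions) != 0


if __name__ == "__main__":
    print(should_throw([True, True]))    # False: every rank still has inputs
    print(should_throw([True, False]))   # True: rank 1 has exhausted its inputs
```

Because every rank sees the same sum, all ranks reach the same decision in the same iteration and raise together, instead of some of them hanging in the next collective.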

As another example, the join() method also calls it directly to rebuild the buckets.

@contextmanager
def join(
    self,
    divide_by_initial_world_size=True,
    enable=True,
    throw_on_early_termination=False,
):
    # ... (other code omitted; the excerpt below runs inside the join loop)
                    # Some DDP process still needs to be joined.
                    if self.ddp_uneven_inputs_config.ddp_join_throw_on_early_termination:
                        # Schedule allreduce telling active ranks to terminate
                        ones = torch.ones(1, device=self.device)
                        dist.all_reduce(ones, group=self.process_group)
                        # Raising StopIteration doesn't throw error in python 3.6
                        # and throws RuntimeError in 3.7+ (PEP 479), so just
                        # raise RuntimeError here.
                        raise RuntimeError(
                            f"Rank {self._distributed_rank} exhausted all inputs."
                        )
                    if is_last_joiner:
                        is_last_joiner = False
                    # It will rebuild buckets only once during training period
                    self.reducer._rebuild_buckets()  # Direct call here
                    # Schedule a corresponding broadcast if we are syncing module
                    # buffers in the forward pass.
                    self._check_and_sync_module_buffers()
                    # ... (remaining shadow operations omitted)

Now that we have mentioned Join, let's take a look at this concept.

0x05 Join

Join addresses uneven training inputs: it lets workers with fewer inputs (which have already joined) keep participating in collective communications with the workers that have not finished yet, by shadowing those communications.

5.1 origin

Behind DDP are the allreduce operations of collective communication libraries, which synchronize gradients across workers. When the training inputs are unevenly distributed across ranks, DDP may hang. Collective communication requires every rank in the process group to participate, so if one rank runs out of inputs early, the other ranks hang or raise an error (depending on the backend). Any class that performs synchronous collective communication in every iteration runs into this problem.

Therefore, DDP provides a "Join" API, a context manager wrapped around each rank's training loop. A rank with fewer inputs exhausts them early; from then on it creates an illusion for the collective communication, issuing dummy allreduce calls that match those of the ranks that still have data. How this illusion is created is specified by registered hooks.

The general idea is as follows:

                |             Data           |
                |   +--------+   +--------+  |
                |   |        |   | Empty  |  |
                |   |        |   |        |  |
                |   +-----+--+   +--------+  |
                |         |                  |
                |         |                  |
        +------------+    |               +------------+
        |            |    |               |            |
+---->  |    Model   |    |               |   Model    | <-----+
|       |            |    |               |            |       |
|       +------+-----+    |               +------+-----+       |
|              |          |                      |             |
|              |          |                      |             |
|              v          |                      v             |
|       +------+-----+    |             +--------+----------+  |
|       |  Forward   +<---+             | _JoinHook         |  |
|       |  (local)   |                  |                   |  |
|       +------+-----+                  |                   |  |
|              |                        |                   |  |
|              |                        |                   |  |
|              v                        | +---------------+ |  |
|       +------+-----+                  | | main_hook     | |  |
|       |  Backward  |                  | |               | |  |
|       |  (local)   |                  | |               | |  |
|       +------+-----+                  | |               | |  |
|              |                        | |               | |  |
|              |                        | |               | |  |
|              v                        | |               | |  |
|       +------+-----+                  | |               | |  |
|       | All-Reduce |     Sync grads   | |   All-Reduce  | |  |
|       |            | <--------------> | |   (Dummy)     | |  |
|       +------+-----+                  | |               | |  |
|              |                        | +---------------+ |  |
|              |                        +-------------------+  |
|              v                                 |             |
|     +--------+-------+                         |             |
|     | Update Weights |                         |             |
|     |                |                         |             |
|     +--------+-------+                         |             |
|              |                                 |             |
|              |                                 |             |
+--------------+                                 +-------------+
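To see why a dummy allreduce keeps the arithmetic consistent, consider gradient averaging. In the plain-Python sketch below, a joined rank contributes zeros to the summed gradient (our simplifying assumption about what the shadow allreduce sends), and the sum is divided either by the initial world size or by the number of ranks still training, which is the choice DDP's divide_by_initial_world_size flag controls. The function name and gradient values are hypothetical.

```python
from typing import List, Optional, Sequence


def average_gradients(
    per_rank_grads: Sequence[Optional[List[float]]],
    divide_by_initial_world_size: bool = True,
) -> List[float]:
    """Simulate the averaged gradient every rank receives after allreduce.

    per_rank_grads: one gradient list per rank; None marks a joined rank,
                    which takes part in the allreduce with zeros instead.
    """
    active = [g for g in per_rank_grads if g is not None]
    if not active:
        raise ValueError("at least one rank must still be training")
    length = len(active[0])
    # Joined ranks "shadow" the allreduce by contributing zero tensors.
    contributions = [g if g is not None else [0.0] * length for g in per_rank_grads]
    summed = [sum(col) for col in zip(*contributions)]
    divisor = len(per_rank_grads) if divide_by_initial_world_size else len(active)
    return [s / divisor for s in summed]


if __name__ == "__main__":
    grads = [[2.0, 4.0], None]                # rank 1 has already joined
    print(average_gradients(grads, True))     # [1.0, 2.0]
    print(average_gradients(grads, False))    # [2.0, 4.0]
```

Dividing by the initial world size keeps every input sample's weight the same as before the other rank joined; dividing by the number of active ranks keeps the remaining ranks' gradients from being scaled down.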

5.2 usage

5.2.1 DistributedDataParallel

Join can be used together with DistributedDataParallel. In the following example, two workers are started, rank 0 and rank 1. Rank 0 gets 5 inputs and rank 1 gets 6, so the inputs are uneven.

Without Join, rank 1 hangs while processing its sixth input: rank 0 has no input left to take part in that iteration's allreduce, so rank 1 can only wait. With Join, this problem does not occur and both ranks finish cleanly.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

This produces the following output (the prints from rank 0 and rank 1 may appear in either order):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

5.2.2 ZeroRedundancyOptimizer

The Join context manager works not only with a single class but also with several classes together, for example with PyTorch's ZeroRedundancyOptimizer.

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

# Reuses the imports, constants and main() of the previous example.

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

This produces the same output as before. The notable change is that the ZeroRedundancyOptimizer instance must also be passed into Join().

ZeroRedundancyOptimizer and similar mechanisms will be analyzed in later articles.

5.3 principle

The latest PyTorch tutorial, https://pytorch.org/tutorials/advanced/generic_join.html, explains the design; we paraphrase it below.

To use Join well, we introduce the Join class and its supporting classes Joinable and JoinHook.

Note: this part refers to the v1.10.0 code.

5.3.1 Joinable

First, a class compatible with the Join context manager must inherit from the abstract base class Joinable. In particular, a Joinable must implement:

  • join_hook(self, **kwargs) -> JoinHook

This returns the JoinHook instance for the Joinable, which determines how joined processes should shadow the per-iteration collective communications performed by the Joinable.

  • join_device(self) -> torch.device

This returns the device that the Join context manager uses to perform collective communication, such as torch.device("cuda:0") or torch.device("cpu").

  • join_process_group(self) -> ProcessGroup

This returns the process group that the Join context manager uses to perform collective communication.

To sum up, the JoinHook defines the concrete shadowing behavior, while join_device and join_process_group determine the device and process group on which the context manager runs its own collective communications.

Note that join_device and join_process_group are required so that the context manager can schedule collective communications between joined and not-yet-joined processes. One use is counting the number of not-yet-joined processes in each iteration with an all-reduce. Another is implementing the mechanism required by throw_on_early_termination=True, explained below.
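To make the required interface concrete, here is a stdlib-only sketch of its shape. JoinableSketch, JoinHookSketch and ToyJoinable are hypothetical names, not the real torch.distributed.algorithms.join classes; following our reading of the v1.10 code, join_device and join_process_group are modeled as abstract properties, and the two hook entry points default to no-ops.

```python
from abc import ABC, abstractmethod


class JoinHookSketch:
    """Both entry points default to no-ops; subclasses override what they need."""

    def main_hook(self) -> None:
        pass

    def post_hook(self, is_last_joiner: bool) -> None:
        pass


class JoinableSketch(ABC):
    """The three members a Joinable-like class must provide."""

    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHookSketch:
        """Return the hook that shadows this object's collectives."""

    @property
    @abstractmethod
    def join_device(self) -> str:
        """Device used for the context manager's own collectives."""

    @property
    @abstractmethod
    def join_process_group(self) -> object:
        """Process group used for the context manager's own collectives."""


class ToyJoinable(JoinableSketch):
    """Hypothetical Joinable without any real collectives."""

    def join_hook(self, **kwargs) -> JoinHookSketch:
        return JoinHookSketch()

    @property
    def join_device(self) -> str:
        return "cpu"

    @property
    def join_process_group(self) -> object:
        return None  # a real implementation would return a ProcessGroup


if __name__ == "__main__":
    print(ToyJoinable().join_device)  # cpu
```

Leaving out any of the three members makes the subclass abstract, so instantiating it raises TypeError, which mirrors the "must implement" requirement above.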

DistributedDataParallel and ZeroRedundancyOptimizer already inherit from Joinable and implement the methods above, which is why we could use them directly in the earlier examples.

class DistributedDataParallel(Module, Joinable):

class ZeroRedundancyOptimizer(Optimizer, Joinable):

DDP consumes the training inputs and runs collective communications in every iteration, so it naturally inherits Joinable. Why does ZeroRedundancyOptimizer need to inherit it too? Because it is used together with DDP and performs collective operations of its own inside step(), so it must also be managed by Join.

A Joinable subclass must make sure the Joinable constructor is called, because it initializes a JoinConfig instance that the context manager uses internally to ensure correctness. The JoinConfig is stored in each Joinable's _join_config field.


5.3.2 JoinHook

Next, let's break down the JoinHook class. JoinHook provides two entry points into the context manager:

  • main_hook(self) -> None

While at least one rank has not joined, each joined rank calls this hook repeatedly. It is meant to shadow the collective communications performed by the Joinable in each training iteration (for example, in one forward pass, backward pass, and optimizer step), that is, to define how a joined rank matches the collective communications of the ranks that have not joined.

  • post_hook(self, is_last_joiner: bool) -> None

This hook is called once all ranks have joined. It receives an additional bool argument, is_last_joiner, which indicates whether this rank was one of the last to join. The argument may be useful for synchronization.

ZeroRedundancyOptimizer

To give a concrete example of a hook, the built-in ZeroRedundancyOptimizer main hook still performs an optimizer step, because a joined rank is still responsible for updating and synchronizing its shard of the parameters.

class _ZeROJoinHook(_JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), \
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " \
            "instance as the state"
        self.zero = zero

    def main_hook(self):
        """
        Performs an optimizer step, which updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

The step function is as follows:

def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    _Join.notify_join_context(self)  # The Join context manager is notified here
    # Check if the model trainability has changed
    is_trainable_mask = self._get_is_trainable_mask()
    if is_trainable_mask != self._is_trainable_mask:
        # ... (handling of the changed trainable set omitted)
        self._is_trainable_mask = is_trainable_mask

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Run the optimizer step on this shard only
    if closure is not None:
        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore[call-arg]
    else:
        loss = self.optim.step(**kwargs)

    # Sync all of the updated parameter shards across the ranks
    # ... (broadcast of the updated shards omitted in this excerpt)

    # Sync any updated attributes in the local optimizer to the exposed
    # `param_groups`
    self._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss
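Conceptually, step() does two things: each rank updates only its own shard of the parameters, then the updated shards are synchronized across ranks. Since the ZeRO main hook performs an optimizer step (per its docstring above), a joined rank still updates and publishes its shard. The plain-Python sketch below simulates this with SGD on scalar parameters; the round-robin partitioning and the function names are hypothetical simplifications (ZeroRedundancyOptimizer has its own partitioning), and the gradients are assumed to be already averaged by DDP.

```python
from typing import Dict, List


def partition(num_params: int, world_size: int) -> Dict[int, List[int]]:
    """Hypothetical round-robin assignment of parameter indices to ranks."""
    return {r: [i for i in range(num_params) if i % world_size == r]
            for r in range(world_size)}


def sharded_sgd_step(params: List[float], grads: List[float],
                     world_size: int, lr: float) -> List[List[float]]:
    """Return every rank's copy of the parameters after one sharded step."""
    shards = partition(len(params), world_size)
    # Every rank starts from identical parameters, as DDP guarantees.
    copies = [list(params) for _ in range(world_size)]
    # 1) Each rank runs the optimizer on its own shard only.
    for rank, owned in shards.items():
        for i in owned:
            copies[rank][i] -= lr * grads[i]
    # 2) "Broadcast": each owner's updated values overwrite the other copies.
    for rank, owned in shards.items():
        for i in owned:
            for other in range(world_size):
                copies[other][i] = copies[rank][i]
    return copies


if __name__ == "__main__":
    print(sharded_sgd_step([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], world_size=2, lr=0.5))
    # [[0.5, 1.5, 2.5], [0.5, 1.5, 2.5]]
```

If a joined rank simply stopped calling step(), its shard would never be broadcast and the other ranks would wait in step 2; running the step inside main_hook is what avoids that.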

Now let's look at DistributedDataParallel:

  • main_hook still performs the corresponding series of collective operations, shadowing those of the other ranks.
  • post_hook broadcasts the final updated model from one of the last joined ranks to ensure that the model is the same on all ranks.

The code is as follows:
class _DDPJoinHook(_JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """
        Sets config variables for internal usage.
        """
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size

    def main_hook(self):
        """
        Shadows the DDP collective communication operations in the forward and
        backward passes.
        """
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
        work.wait()
        should_sync_backwards = work.result()[0].item() != 0
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """
        Syncs the final model to ensure that the model is the same across all
        processes.
        """
        self.ddp._sync_final_model(is_last_joiner)

_sync_final_model broadcasts the latest model:

# When running in join mode, agrees upon a common rank and broadcast model
# parameters to all other ranks.
def _sync_final_model(self, is_last_joiner):
    # Agree upon the process that will be the authoritative model copy.
    # The current rank is a candidate for being the authoritative copy if
    # is_last_joiner=True. We break ties via picking the larger rank.
    self._authoritative_rank = self._find_common_rank(
        self._distributed_rank, is_last_joiner
    )
    self._sync_params_and_buffers(authoritative_rank=self._authoritative_rank)
5.3.3 Join

Finally, let's see how these basic classes fit into the Join class itself.

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

As in the earlier examples, the constructor receives the list of Joinables participating in the training loop. These should be the classes that perform collective communications in each iteration.

enable is a bool that can be set to False if you know there will be no uneven inputs, in which case the context manager becomes vacuous, similar to contextlib.nullcontext(). This may also disable join-related computation in the participating Joinables.

throw_on_early_termination is a bool that can be set to True so that every rank raises an exception as soon as uneven inputs are detected. This is useful when the requirements of the context manager are not met, typically when collective communications from different classes may be interleaved arbitrarily, such as when DistributedDataParallel is used with a model that has SyncBatchNorm layers. In that case, set it to True so that the application logic can catch the exception and decide how to proceed.

  • The core logic lives in the __exit__() method: while there is a rank that has not joined, it repeatedly calls the main hook of each Joinable, and once all ranks have joined, it calls their post hooks. Both the main hooks and the post hooks are called in the order in which the Joinables were passed in.
  • The context manager requires a heartbeat from processes that have not joined. Therefore, each Joinable class should call Join.notify_join_context() before its per-iteration collective communications. The context manager ensures that only the first Joinable passed in actually sends the heartbeat.
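The interplay between heartbeats, main hooks and post hooks can be simulated without any process group. In the sketch below, each pass of the outer loop is one lockstep iteration: ranks that still have inputs send a heartbeat of 1, joined ranks send 0 and run their main hook instead, and the summed heartbeat (standing in for the context manager's allreduce) tells every rank how many ranks are still active. When it reaches zero, each rank would run its post hook with the recorded is_last_joiner flag. simulate_join and its return keys are hypothetical; this models the protocol described above, not the actual Join implementation.

```python
from typing import Dict, List


def simulate_join(num_inputs: List[int]) -> Dict[str, list]:
    """Simulate ranks with uneven inputs under a Join-like protocol.

    num_inputs[r] is how many inputs rank r has (hypothetical data).
    Returns real iterations and shadow (main_hook) calls per rank, plus the
    is_last_joiner value each rank would pass to its post hook.
    """
    world = len(num_inputs)
    done = [0] * world               # real iterations performed per rank
    shadow_calls = [0] * world       # main_hook invocations per rank
    is_last_joiner = [True] * world  # cleared once a joined rank shadows a round
    joined = [False] * world
    while True:
        heartbeats = []
        for r in range(world):
            if not joined[r] and done[r] < num_inputs[r]:
                heartbeats.append(1)  # notify_join_context-style heartbeat
            else:
                joined[r] = True
                heartbeats.append(0)  # joined ranks contribute 0
        num_nonjoined = sum(heartbeats)  # the allreduce result every rank sees
        if num_nonjoined == 0:
            break                       # all joined: post hooks run next
        for r in range(world):
            if joined[r]:
                shadow_calls[r] += 1    # main_hook shadows this iteration
                is_last_joiner[r] = False
            else:
                done[r] += 1            # real forward/backward/step
    return {"done": done, "shadow_calls": shadow_calls,
            "is_last_joiner": is_last_joiner}


if __name__ == "__main__":
    print(simulate_join([5, 6]))
    # {'done': [5, 6], 'shadow_calls': [1, 0], 'is_last_joiner': [False, True]}
```

With the 5 and 6 inputs of the example in 5.2.1, rank 0 shadows exactly one iteration, and rank 1 is the last joiner, so it is the candidate for the authoritative model copy in the post hook.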

5.4 examples

Let's look at a concrete example. In the following code, each rank prints (1) the number of inputs processed across all ranks before it joined, and (2) the total number of inputs processed across all ranks.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        # A zero contribution keeps the non-joined ranks' sum unchanged
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t, group=self.counter.process_group)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        # Agree on one last joiner to act as the broadcast source
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        # Heartbeat: tell the Join context manager this rank has not joined yet
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t, group=self.process_group)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    # Joinable declares join_device and join_process_group as abstract properties
    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # Determine the broadcast source once all ranks have joined. More than one
    # rank may join last, so pick the largest such rank.
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

Since rank 0 sees 5 inputs and rank 1 sees 6, the program produces the following output (the order of lines from different ranks may vary):

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
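The arithmetic behind these numbers: in each iteration every rank that still has an input all-reduces a one, while a joined rank contributes a zero through its main hook. The helper below is a hypothetical, pure-Python recomputation (no torch, no process group) of the printed values from the per-rank input counts NUM_INPUTS + rank; it assumes, as the post hook does, that the broadcast value is the count of a last joiner, which is the largest count.

```python
from typing import List, Tuple


def expected_counts(num_inputs: List[int]) -> Tuple[List[int], int]:
    """Recompute the values printed by the Counter example (toy model).

    At iteration k the all-reduced tensor equals the number of ranks that
    still have an input at k (joined ranks all-reduce a zero).
    """
    counts = []
    for n in num_inputs:
        # A rank only accumulates during its own n iterations.
        counts.append(sum(sum(1 for m in num_inputs if m > k) for k in range(n)))
    # The post hook broadcasts the count of a last joiner, i.e. the maximum.
    return counts, max(counts)


NUM_INPUTS, WORLD_SIZE = 5, 2
print(expected_counts([NUM_INPUTS + rank for rank in range(WORLD_SIZE)]))
# ([10, 11], 11)
```

For rank 0, the five iterations each reduce to 2, giving 10; rank 1 gets the same 10 plus 1 from its sixth iteration, in which rank 0 only contributes a zero.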

Some key points to emphasize:

  • The Counter instance performs one all-reduce per iteration, so:
    • On a rank that has already joined, the main hook performs a single all-reduce to shadow it in the collective communication. Because this all-reduce uses a zero tensor, it has no effect on the result.
    • The ranks that have not joined still see what looks like a complete collective operation.
    • This is how uneven inputs are handled.
  • The Counter class calls Join.notify_join_context() at the beginning of its __call__() method, because that is the place right before each of its collective operations (the all-reduce), so the context manager must be notified there. A rank that has already joined never reaches this call.
  • The is_last_joiner argument is used to determine the broadcast source in the post hook.
  • We pass the sync_max_count keyword argument to the context manager, which forwards it to Counter's join hook.
  • In the post hook, self.counter.max_count is broadcast from that source rank to all other ranks.
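The choice of broadcast source can be checked without a process group. The sketch below is a single-process illustration: its list-based signature is hypothetical, and Python's max over every rank's contribution stands in for dist.all_reduce(..., op=dist.ReduceOp.MAX). It follows the same rule as find_common_rank: a last joiner contributes its rank, any other rank contributes -1, so ties go to the larger rank.

```python
from typing import List


def find_common_rank(is_last_joiner: List[bool]) -> int:
    """Toy version of Counter.find_common_rank for all ranks at once.

    Each rank contributes its own rank if it is a last joiner and -1
    otherwise; a MAX reduction then picks the largest candidate, so ties
    between several last joiners are broken in favour of the larger rank.
    """
    contributions = [rank if last else -1 for rank, last in enumerate(is_last_joiner)]
    return max(contributions)  # stands in for the MAX all-reduce


print(find_common_rank([False, True]))        # 1
print(find_common_rank([True, True, False]))  # 1
```

Because every rank computes the same maximum, all ranks agree on src without any extra communication.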

Posted on Wed, 24 Nov 2021 21:05:30 -0500 by noobcody