void AllGatherFixedShape()

in lib/distributed_runtime/kernels.cc [557:701]


// Kernel implementing a fixed-shape all-gather: every member of the collective
// group contributes an identically-shaped `in_tensor`, and each member's
// contribution is copied into `out_tensor` concatenated along `axis`.
// `out_chain` is fulfilled when the gather completes, or set to an error if
// this worker is not in the group or the output dimensions are inconsistent.
void AllGatherFixedShape(Argument<DistributedContext> dist_ctx,
                         Argument<std::string> collective_group_name,
                         Argument<InstanceKey> instance_key,
                         Argument<DenseHostTensor> in_tensor,
                         Argument<DenseHostTensor> out_tensor,
                         Argument<size_t> axis, Argument<Chain> in_chain,
                         Result<Chain> out_chain,
                         const ExecutionContext& exec_ctx) {
  auto out_chain_indirect = out_chain.Allocate();
  const auto& collective_group =
      dist_ctx->GetCollectiveGroup(collective_group_name.get());
  int my_index =
      FindMyIndex(collective_group.members, dist_ctx->GetTaskHandle());
  if (my_index == -1) {
    out_chain_indirect.SetError(
        "This worker is not part of the collective group ");
    return;
  }
  const auto kGroupSize = collective_group.members.size();

  // Compute offsets and step_sizes for all workers.
  //
  // NOTE: this (and its dimension validation) must run *before*
  // out_chain_indirect is moved into the completion callback below; the
  // previous ordering called SetError on the moved-from chain ref when the
  // output dimensions were inconsistent (use-after-move), so the error was
  // never delivered to the caller.
  //
  // offsets: offsets in output tensor for each tensor
  // step_sizes: # of elements to be copied for each offset
  // Example:
  //    t0 = [[0,1,2],[3,4,5]] dimension = 2x3
  //    t1 = [[60,70,80],[90,100,110]] dimension = 2x3
  llvm::SmallVector<llvm::SmallVector<size_t, 4>, 4> offsets;
  llvm::SmallVector<size_t, 4> step_sizes;
  size_t num_elements = in_tensor->NumElements();
  if (*axis == 0) {
    // if axis = 0, whole tensor (# elements = 6) is concat after one another
    // offsets = [[0],[6]]
    // step_sizes = [6,6]
    // result = [[0,1,2],[3,4,5],[60,70,80],[90,100,110]] dimension = 4x3
    size_t pos = 0;
    for (size_t i = 0; i < kGroupSize; ++i) {
      step_sizes.push_back(num_elements);
      offsets.push_back({pos});
      pos += num_elements;
    }
  } else {
    // otherwise, more work to be done to figure out the offsets and step_sizes
    // as we no longer simply concat the whole tensor.
    // offsets = [[0,6],[3,9]]
    // step_sizes = [3,3]
    // For worker_0's tensor, it is split into 2 chunks of 3 elements each.
    //   each chunk should be copied into out_tensor at position 0 and 6.
    // For worker_1's tensor, it is split into 2 chunks of 3 elements each.
    //   each chunk should be copied into out_tensor at position 3 and 9.
    // result = [[0,1,2,60,70,80],[3,4,5,90,100,110]] dimension = 2x6
    llvm::SmallVector<Index, 4> in_dimension;
    in_tensor->shape().GetDimensions(&in_dimension);
    llvm::SmallVector<Index, 4> out_dimension;
    out_tensor->shape().GetDimensions(&out_dimension);
    // step_size refers to # of elements of input tensor to be copied into the
    // output tensor for each offset. step_size is determined by the axis.
    //
    // step_size = dimension[axis] * dimension[axis+1] * ... * dimension[n-1]
    // when n is the # of dimensions.
    // In other words, step_size is the product of dimensions starting from the
    // axis to the end.
    //
    // Why this works:
    // To gather on axis k means to gather all elements starting from that axis.
    //
    // Example:
    //    t0 has dimension 2x3
    //    t1 has dimension 2x3.
    //    Gathering on axis 0: step_size = dimension[0] * dimension[1] = 2*3 = 6
    //      This means the whole tensor is copied into the output after another.
    //    Gathering on axis 1: step_size = dimension[1] = 3
    //      This means 3 elements from each tensor is copied for each offset.
    size_t step_size = 1;
    // interval refers to the distance between offsets.
    // interval is determined by the axis and output dimension.
    //
    // interval = out_dim[axis] * out_dim[axis+1] * ... * out_dim[n-1]
    // when n is the # of dimensions and
    // out_dim is the dimension of output tensor
    //
    // Why we need this:
    // When gathering on a non-zero axis, a tensor is split into chunks. Each
    // chunk, whose size is step_size, is copied into the output tensor at some
    // offset. Each offset is some distance apart from each other. The distance
    // is referred to as interval.
    size_t interval = 1;
    for (size_t j = *axis; j < in_dimension.size(); ++j) {
      if (j != *axis && in_dimension[j] != out_dimension[j]) {
        // out_chain_indirect has not been moved yet, so the error reaches the
        // caller (this was previously a use-after-move).
        out_chain_indirect.SetError(
            "Incorrect output dimension. All dimensions in the output must be "
            "equal except the axis.");
        return;
      }
      step_size *= in_dimension[j];
      interval *= out_dimension[j];
    }
    size_t pos = 0;
    for (size_t i = 0; i < kGroupSize; ++i) {
      step_sizes.push_back(step_size);
      size_t num_offsets = num_elements / step_size;
      llvm::SmallVector<size_t, 4> offset;
      size_t each_offset = pos;
      for (size_t l = 0; l < num_offsets; ++l) {
        offset.push_back(each_offset);
        each_offset += interval;
      }
      offsets.push_back(offset);
      pos += step_size;
    }
  }

  // Fulfills (or errors) the output chain when the all-gather finishes. The
  // captured ValueRefs keep the output tensor and distributed context alive
  // until then.
  auto done = [out_tensor = out_tensor.ValueRef(), instance_key = *instance_key,
               out_chain = std::move(out_chain_indirect),
               dist_ctx = dist_ctx.ValueRef()](Error e) mutable {
    if (e) {
      out_chain.SetError(e);
    } else {
      out_chain.emplace();
    }
  };
  // Ref-counted wrapper so multiple in-flight sub-operations can share one
  // completion callback; `done` runs exactly once, on a worker thread. If the
  // last ref is dropped off the worker threadpool, hop back via EnqueueWork.
  auto refcounted_done = TakeRef(
      new RefCountedCallback([host = dist_ctx->GetHostContext(), exec_ctx,
                              done = std::move(done)](Error e) mutable {
        if (host->IsInWorkerThread()) {
          done(std::move(e));
        } else {
          EnqueueWork(exec_ctx,
                      [done = std::move(done), e = std::move(e)]() mutable {
                        done(std::move(e));
                      });
        }
      }));

  // Run the actual gather off the caller's thread.
  EnqueueWork(exec_ctx,
              [exec_ctx, my_index, instance_key = *instance_key, axis = *axis,
               collective_group, dist_ctx = dist_ctx.ValueRef(),
               in_tensor_ref = in_tensor.ValueRef(),
               out_tensor_ref = out_tensor.ValueRef(),
               refcounted_done = refcounted_done, offsets = std::move(offsets),
               step_sizes = std::move(step_sizes)] {
                DoAllGather<T>(exec_ctx, dist_ctx.CopyRef(), instance_key,
                               collective_group, my_index, in_tensor_ref.get(),
                               out_tensor_ref.get(), refcounted_done, axis,
                               offsets, step_sizes);
              });
}