void DoAllGatherAnyShape()

in lib/distributed_runtime/kernels.cc [704:830]
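
Gathers tensors that may differ in size along a single axis from all members of a collective group into one output tensor. `shapes_tensor` holds the flattened shapes of every participant's input (presumably exchanged in an earlier step); the function validates the shapes, derives the output shape and per-tensor copy offsets, and then dispatches the element-wise gather to DoAllGather<T>.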


template <typename T>
void DoAllGatherAnyShape(const ExecutionContext& exec_ctx,
                         AsyncValueRef<DistributedContext> dist_ctx,
                         const InstanceKey& instance_key,
                         const CollectiveGroup& collective_group,
                         const int my_index, const DenseHostTensor& in_tensor,
                         AsyncValueRef<DenseHostTensor> out_tensor,
                         DenseHostTensor& shapes_tensor,
                         MutableDHTArrayView<Index> shape_tensor_view,
                         AsyncValueRef<Chain> out_chain, size_t axis,
                         size_t kGroupSize, size_t kRank) {
  llvm::SmallVector<llvm::SmallVector<size_t, 4>, 4> offsets;
  llvm::SmallVector<size_t, 4> step_sizes;
  llvm::SmallVector<Index, 4> dimensions;
  llvm::SmallVector<Index, 4> tensor_sizes;
  size_t gathered_dimension = 0;
  DHTArrayView<size_t> shapes_array(&shapes_tensor);
  // Check that all participating tensors have the same rank.
  // E.g. AllGather on tensors of shape 2x2 and 2x2x2 does not make sense.
  if (shapes_array.NumElements() / kGroupSize != kRank) {
    out_chain.SetError("All workers must have tensors of the same rank.");
    return;
  }
  // Go through the participating shapes and compute the following:
  // - the dimensions of the output tensor
  // - the number of elements in each participating tensor
  // - the step_size (contiguous run length copied per offset) for each tensor
  for (size_t i = 0; i < kGroupSize; ++i) {
    size_t num_element = 1;
    size_t step_size = 1;
    for (size_t j = 0; j < kRank; ++j) {
      // This refers to i-th tensor's j-th dimension.
      // Example:
      //   Suppose there are 2 tensors to be gathered.
      //   T0 has shape 1x2 and T1 has shape 1x3.
      //   Then, we have
      //      shapes_array = [1,2,1,3]
      //   dim indexes the j-th dimension of the i-th tensor, so T0's
      //   {1,2} live at indices 0,1 and T1's {1,3} at indices 2,3.
      size_t dim = kRank * i + j;
      if (j != axis) {
        if (shapes_array[dim] != shape_tensor_view.Elements()[j]) {
          out_chain.SetError(
              "All dimensions in the input must be equal except the axis");
          return;
        }
      } else {
        gathered_dimension += shapes_array[dim];
      }
      num_element *= shapes_array[dim];
      if (j >= axis) {
        step_size *= shapes_array[dim];
      }
    }
    tensor_sizes.push_back(num_element);
    step_sizes.push_back(step_size);
  }
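  // Continuing the example above (T0 = 1x2, T1 = 1x3, axis = 1):
  //   tensor_sizes = [2, 3]  (elements per tensor)
  //   step_sizes   = [2, 3]  (contiguous elements copied per offset)
  //   gathered_dimension = 2 + 3 = 5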
  // Assemble the output tensor's dimensions.
  for (size_t l = 0; l < kRank; ++l) {
    if (l != axis) {
      dimensions.push_back(shape_tensor_view.Elements()[l]);
    } else {
      dimensions.push_back(gathered_dimension);
    }
  }
  // Compute the interval (distance between consecutive copy offsets in the
  // output). See line 658 for more details.
  size_t interval = 1;
  for (size_t m = axis; m < kRank; ++m) {
    interval *= dimensions[m];
  }
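  // In the example, dimensions = [1, 5] and axis = 1, so interval = 5:
  // consecutive offsets for the same tensor are one output row apart.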
  // For each participating tensor, compute the offsets (the positions in the
  // output tensor that its contiguous chunks will be copied to).
  size_t pos = 0;
  for (size_t n = 0; n < kGroupSize; ++n) {
    llvm::SmallVector<size_t, 4> offset;
    size_t num_offsets = tensor_sizes[n] / step_sizes[n];
    size_t each_offset = pos;
    for (size_t o = 0; o < num_offsets; ++o) {
      offset.push_back(each_offset);
      each_offset += interval;
    }
    offsets.push_back(offset);
    pos += step_sizes[n];
  }
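  // In the example this yields offsets = [[0], [2]]: T0's 2 elements land at
  // output positions 0..1 and T1's 3 elements at positions 2..4 of the 1x5
  // output. With T0 = 2x2 and T1 = 2x3 it would be offsets = [[0, 5], [2, 7]],
  // one offset per output row.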
  // Create an output tensor
  TensorShape shape(dimensions);
  TensorMetadata md(in_tensor.metadata().dtype, shape);
  auto output_tensor =
      DenseHostTensor::MakeConstructedAsyncValueRef(md, exec_ctx.host());

  auto done = [output_tensor = output_tensor.CopyRef(),
               out_chain = std::move(out_chain),
               out_tensor = std::move(out_tensor),
               dist_ctx = dist_ctx.CopyRef()](Error e) mutable {
    if (e) {
      out_chain.SetError(e);
    } else {
      out_chain.emplace();
      out_tensor.emplace(std::move(output_tensor.get()));
    }
  };

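  // refcounted_done invokes `done` exactly once, after the last reference to
  // the RefCountedCallback is dropped; if the final release happens off a
  // worker thread, the call is re-enqueued onto one.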
  auto refcounted_done = TakeRef(
      new RefCountedCallback([host = dist_ctx->GetHostContext(), exec_ctx,
                              done = std::move(done)](Error e) mutable {
        if (host->IsInWorkerThread()) {
          done(std::move(e));
        } else {
          EnqueueWork(exec_ctx,
                      [done = std::move(done), e = std::move(e)]() mutable {
                        done(std::move(e));
                      });
        }
      }));

  // Do a final AllGather
  EnqueueWork(exec_ctx, [exec_ctx, my_index, instance_key, axis,
                         collective_group, dist_ctx = dist_ctx.CopyRef(),
                         in_tensor_ref = in_tensor.CopyRef(),
                         output_tensor = output_tensor.CopyRef(),
                         refcounted_done = refcounted_done,
                         offsets = std::move(offsets),
                         step_sizes = std::move(step_sizes)] {
    DoAllGather<T>(exec_ctx, dist_ctx.CopyRef(), instance_key, collective_group,
                   my_index, in_tensor_ref.CopyRef(), output_tensor.get(),
                   refcounted_done, axis, offsets, step_sizes);
  });
}
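
To make the offset arithmetic concrete, the following is a minimal standalone sketch of the same shape/offset computation for the running example (T0 = 1x2, T1 = 1x3, axis = 1). It is an illustration of the logic above, not part of the kernel; all names are local to the sketch, and the validation of non-axis dimensions is omitted.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t kGroupSize = 2, kRank = 2, axis = 1;
  // Flattened shapes, laid out like shapes_array: T0 = {1, 2}, T1 = {1, 3}.
  const std::vector<size_t> shapes = {1, 2, 1, 3};

  std::vector<size_t> tensor_sizes, step_sizes, dimensions(kRank, 0);
  size_t gathered = 0;
  for (size_t i = 0; i < kGroupSize; ++i) {
    size_t num_element = 1, step_size = 1;
    for (size_t j = 0; j < kRank; ++j) {
      const size_t d = shapes[kRank * i + j];
      if (j == axis) gathered += d;
      else dimensions[j] = d;  // Non-axis dims assumed equal (not re-checked).
      num_element *= d;
      if (j >= axis) step_size *= d;
    }
    tensor_sizes.push_back(num_element);
    step_sizes.push_back(step_size);
  }
  dimensions[axis] = gathered;  // Output shape: 1x5.

  size_t interval = 1;
  for (size_t m = axis; m < kRank; ++m) interval *= dimensions[m];

  size_t pos = 0;
  for (size_t n = 0; n < kGroupSize; ++n) {
    std::printf("tensor %zu copies %zu elements at offsets:", n, step_sizes[n]);
    for (size_t o = 0; o < tensor_sizes[n] / step_sizes[n]; ++o)
      std::printf(" %zu", pos + o * interval);
    std::printf("\n");
    pos += step_sizes[n];
  }
  // Prints:
  //   tensor 0 copies 2 elements at offsets: 0
  //   tensor 1 copies 3 elements at offsets: 2
  return 0;
}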