c10::intrusive_ptr<c10::ivalue::Future> MachineFairring::reduceScatter()

in fairring/machine.cc [244:261]


// Launches a reduce-scatter on every DeviceFairring held in nodes_ and returns
// the result of merging the per-node futures.
//
// tensors: one TensorPair per entry of devices_. The pair at position i must
//          have its `output` tensor on devices_[i]; both conditions are
//          enforced with MY_CHECK. Only `output` is device-checked here.
//
// Returns whatever mergeMultiDeviceFutures() builds from the futures handed
// back by each node's reduceScatter() call.
//
// NOTE(review): `tensors` is indexed by node position below, while its length
// is only validated against devices_. This presumably relies on nodes_ and
// devices_ always having the same length -- confirm against the constructor.
c10::intrusive_ptr<c10::ivalue::Future> MachineFairring::reduceScatter(
    std::vector<TensorPair> tensors) {
  // Validate the caller's layout before any work is dispatched.
  MY_CHECK(tensors.size() == devices_.size());
  for (size_t deviceOffset = 0; deviceOffset < tensors.size();
       ++deviceOffset) {
    MY_CHECK(tensors[deviceOffset].output.device() == devices_[deviceOffset]);
  }

  // One pending future per node, collected in node order.
  c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures(
      c10::ListType::ofTensors());
  futures.reserve(nodes_.size());

  size_t position = 0;
  for (const std::unique_ptr<DeviceFairring>& node : nodes_) {
    // The pair at the same position as the node supplies its input/output.
    TensorPair& pair = tensors[position];
    futures.push_back(node->reduceScatter(pair.input, pair.output));
    ++position;
  }

  return mergeMultiDeviceFutures(std::move(futures));
}