c10::intrusive_ptr<c10::ivalue::Future> MachineFairring::allReduce()

in fairring/machine.cc [222:242]


// Launches an all-reduce over the tensors of all local devices of this
// machine.
//
// Each tensor in `tensors` is handed to the DeviceFairring node at the same
// offset. The resulting per-device futures are combined into a single future
// via mergeMultiDeviceFutures(), which is returned.
//
// Preconditions (enforced with MY_CHECK):
//  - `opType` is c10d::ReduceOp::SUM (the only reduction supported so far);
//  - there is exactly one tensor per device, and tensors[i] is located on
//    devices_[i];
//  - there is exactly one DeviceFairring node per tensor.
//
// NOTE(review): presumably the merged future completes once all per-device
// futures have completed; confirm against mergeMultiDeviceFutures().
c10::intrusive_ptr<c10::ivalue::Future> MachineFairring::allReduce(
    c10d::ReduceOp opType,
    std::vector<at::Tensor> tensors) {
  // FIXME Support more operation types
  MY_CHECK(opType == c10d::ReduceOp::SUM);

  MY_CHECK(tensors.size() == devices_.size());
  for (const auto deviceOffset : c10::irange(tensors.size())) {
    MY_CHECK(tensors[deviceOffset].device() == devices_[deviceOffset]);
  }
  // The dispatch loop below pairs nodes_[idx] with tensors[idx]. If the sizes
  // differed, we would either index `tensors` out of bounds or silently skip
  // some tensors.
  MY_CHECK(nodes_.size() == tensors.size());

  // The list's runtime element type must describe its elements, which are
  // futures, so it has to be a FutureType. Passing ListType::ofTensors()
  // directly would claim the elements are tensor lists.
  // NOTE(review): the per-device futures are assumed to carry a list of
  // tensors; TODO confirm against DeviceFairring::allReduce().
  c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures(
      c10::FutureType::create(c10::ListType::ofTensors()));
  futures.reserve(nodes_.size());
  for (const auto idx : c10::irange(nodes_.size())) {
    const std::unique_ptr<DeviceFairring>& node = nodes_[idx];
    futures.push_back(node->allReduce(opType, tensors[idx]));
  }

  return mergeMultiDeviceFutures(std::move(futures));
}