c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allgather()

in src/torch_ucc.cpp [764:866]


c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& /* unused */) {
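  // Trivial case: with a single rank the allgather is just a local copy of the
  // input into the (only) output tensor, so no UCC collective is posted.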
  if (size_ == 1) {
    outputTensors[0][0].copy_(inputTensors[0]);
    return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
        OpType::ALLGATHER,
        torch_ucc_config.enable_profiling ? "ucc:allgather" : nullptr);
  }
  auto& tensor = inputTensors[0];
  check_device(tensor.device(), outputTensors[0][0].device());
  initComm(tensor.device());

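  // CPU tensors (and the use_allgatherv override) take the UCC allgatherv path
  // so results land directly in the user-provided output tensors; the CUDA
  // default uses a dense allgather into a flat staging buffer instead.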
  if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
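    // Counts are expressed in bytes (UCC_DT_UINT8) and the "displacements"
    // carry absolute destination pointers, hence the 64-bit count and
    // displacement flags below.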
    AllgathervWorkData* data = new AllgathervWorkData(size_);
    for (int i = 0; i < size_; i++) {
      data->recv_lengths[i] = tensor.element_size() * tensor.numel();
      data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
    }
    ucc_coll_args_t coll;
    coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
    coll.flags =
        UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
    coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
    coll.src.info.buffer = tensor.data_ptr();
    coll.src.info.count = tensor.element_size() * tensor.numel();
    coll.src.info.datatype = UCC_DT_UINT8;
    coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
    coll.dst.info_v.buffer = nullptr;
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = UCC_DT_UINT8;
    coll.dst.info_v.mem_type =
        to_ucc_memType(outputTensors[0][0].device().type());
    SAVE_TENSORS(inputTensors, data->src);
    SAVE_TENSORS(outputTensors[0], data->dst);

    return collective_post(
        OpType::ALLGATHER,
        []() {},
        []() {},
        coll,
        std::unique_ptr<WorkData>(data),
        tensor.device(),
        outputTensors[0],
        "ucc:allgatherv");
  } else {
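    // The dense allgather needs a contiguous destination, so gather into flat
    // staging tensors built with c10d::newLikeFlat and copy back into the user
    // tensors once the collective completes.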
    WorkData* data = new WorkData();
    std::vector<at::Tensor> flat_output(outputTensors.size());
    for (size_t i = 0; i < outputTensors.size(); i++) {
      TORCH_CHECK(
          outputTensors[i].size() == outputTensors.size() * size_,
          "Tensor output list is not valid for the number of participants");
      flat_output[i] = c10d::newLikeFlat(outputTensors, i);
    }
    SAVE_TENSORS(flat_output, data->flat);
    ucc_coll_args_t coll;
    coll.mask = 0;
    coll.flags = 0;
    coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
    coll.src.info.buffer = tensor.data_ptr();
    coll.src.info.count = tensor.numel();
    coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
    coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
    coll.dst.info.buffer = flat_output[0].data_ptr();
    coll.dst.info.count = flat_output[0].numel();
    coll.dst.info.datatype = ucc_dtype_map.at(flat_output[0].scalar_type());
    coll.dst.info.mem_type =
        to_ucc_memType(outputTensors[0][0].device().type());

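    // Post-processing hook passed to collective_post: after the collective has
    // finished, scatter the flat staging buffer back into the user-provided
    // output tensors (non-blocking copies, with output storages recorded on
    // the internal CUDA stream when running on GPU).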
    auto copy_from_flat = [&] {
      bool asyncCopy = false;
#ifdef USE_CUDA
      bool isCuda = outputTensors[0][0].device().is_cuda();
#endif
      for (size_t i = 0; i < outputTensors.size(); i++) {
        auto inumel = inputTensors[i].numel();
        for (size_t j = 0; j < outputTensors[i].size(); j++) {
          TORCH_CHECK(
            (outputTensors[i][j].numel() == inumel),
            "Tensor operand counts must be same");
#ifdef USE_CUDA
          if (isCuda) {
            c10::cuda::CUDACachingAllocator::recordStream(
              outputTensors[i][j].storage().data_ptr(), (*stream));
            asyncCopy = true;
          }
#endif
          outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
        }
      }
    };
    return collective_post(
        OpType::ALLGATHER,
        []() {},
        copy_from_flat,
        coll,
        std::unique_ptr<WorkData>(data),
        tensor.device(),
        outputTensors[0],
        "ucc:allgather");
  }
}
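
For context, a minimal sketch of how this entry point is typically reached through the generic c10d ProcessGroup interface. The `pg` handle, tensor shapes, and dtype are illustrative assumptions, not part of the source above.

// Illustrative only: `pg` is assumed to be an already-initialized
// c10::intrusive_ptr to a ProcessGroupUCC instance.
std::vector<at::Tensor> inputs = {at::ones({4}, at::kFloat)};
std::vector<std::vector<at::Tensor>> outputs = {std::vector<at::Tensor>(
    pg->getSize(), at::empty({4}, at::kFloat))};
auto work = pg->allgather(outputs, inputs);  // AllgatherOptions left at default
work->wait();  // blocks until the UCC collective (and any copy-out) completes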