ReduceRangeFunction genLocalReduceFunction()

in gloo/allreduce.cc [44:81]


// Builds the callable that reduces the local buffers over a byte range.
//
// The returned function takes a byte `offset` and a byte `length` and writes
// its result into out[0] at that offset. The work it performs is selected
// once, from in.size() at the time this generator is called:
//   - no inputs:         every out[i], i >= 1, is combined into out[0] via fn;
//   - exactly one input: in[0] is copied into out[0] with memcpy (fn unused);
//   - two or more:       fn is first called with out[0] as destination and
//                        in[0], in[1] as operands, then each remaining in[i]
//                        is combined into out[0].
//
// fn is invoked as fn(destination, operandA, operandB, count), where count is
// length / elementSize.
//
// NOTE: the returned closure captures `in` and `out` by reference, so both
// vectors must remain alive for as long as the closure may be invoked.
ReduceRangeFunction genLocalReduceFunction(
    const BufferVector& in,
    const BufferVector& out,
    size_t elementSize,
    ReductionFunction fn) {
  const size_t numInputs = in.size();

  // No dedicated inputs: accumulate the remaining outputs into out[0].
  if (numInputs == 0) {
    return [&out, elementSize, fn](size_t offset, size_t length) {
      for (size_t idx = 1; idx < out.size(); ++idx) {
        // Looked up inside the loop on purpose: with fewer than two outputs
        // nothing is dereferenced and no division is performed.
        uint8_t* const dst = static_cast<uint8_t*>(out[0]->ptr) + offset;
        const uint8_t* const acc =
            static_cast<const uint8_t*>(out[0]->ptr) + offset;
        const uint8_t* const src =
            static_cast<const uint8_t*>(out[idx]->ptr) + offset;
        fn(dst, acc, src, length / elementSize);
      }
    };
  }

  // A single input needs no reduction; a plain byte copy is enough.
  if (numInputs == 1) {
    return [&in, &out](size_t offset, size_t length) {
      uint8_t* const dst = static_cast<uint8_t*>(out[0]->ptr) + offset;
      const uint8_t* const src =
          static_cast<const uint8_t*>(in[0]->ptr) + offset;
      memcpy(dst, src, length);
    };
  }

  // Two or more inputs: seed out[0] from the first pair, then fold the rest.
  return [&in, &out, elementSize, fn](size_t offset, size_t length) {
    const size_t count = length / elementSize;
    uint8_t* const dst = static_cast<uint8_t*>(out[0]->ptr) + offset;
    const uint8_t* const acc = dst;

    const uint8_t* const first =
        static_cast<const uint8_t*>(in[0]->ptr) + offset;
    const uint8_t* const second =
        static_cast<const uint8_t*>(in[1]->ptr) + offset;
    fn(dst, first, second, count);

    for (size_t idx = 2; idx < in.size(); ++idx) {
      const uint8_t* const next =
          static_cast<const uint8_t*>(in[idx]->ptr) + offset;
      fn(dst, acc, next, count);
    }
  };
}