in gloo/allreduce.cc [44:81]
// Builds the callable that performs the local reduction for one byte range.
//
// The returned function takes (offset, length), both in bytes, and always
// writes its result into out[0] starting at `offset`. Which strategy it uses
// is decided once, here, from the number of input buffers:
//
//   * no inputs   -> out[1..] are folded into out[0] in place;
//   * one input   -> the range is copied verbatim from in[0] to out[0];
//   * 2+ inputs   -> in[0] and in[1] are reduced into out[0], then every
//                    remaining input is folded into out[0].
//
// `fn` is invoked as fn(dst, srcA, srcB, count) with
// count = length / elementSize.
//
// NOTE: the returned callable holds references to `in` and `out` (they are
// captured by reference), so both vectors must outlive it.
ReduceRangeFunction genLocalReduceFunction(
    const BufferVector& in,
    const BufferVector& out,
    size_t elementSize,
    ReductionFunction fn) {
  // In-place mode: there are no dedicated inputs, so the outputs double as
  // the operands. out[0] is only touched when there is a second buffer.
  if (in.size() == 0) {
    return [&out, elementSize, fn](size_t offset, size_t length) {
      for (size_t idx = 1; idx < out.size(); ++idx) {
        uint8_t* const acc = static_cast<uint8_t*>(out[0]->ptr) + offset;
        const uint8_t* const rhs =
            static_cast<const uint8_t*>(out[idx]->ptr) + offset;
        fn(acc, static_cast<const uint8_t*>(acc), rhs, length / elementSize);
      }
    };
  }

  // Single input: nothing to reduce, just move the bytes across.
  if (in.size() == 1) {
    return [&in, &out](size_t offset, size_t length) {
      const uint8_t* const src =
          static_cast<const uint8_t*>(in[0]->ptr) + offset;
      uint8_t* const dst = static_cast<uint8_t*>(out[0]->ptr) + offset;
      memcpy(dst, src, length);
    };
  }

  // Two or more inputs: seed out[0] with in[0] (op) in[1], then accumulate
  // the remaining inputs on top of it.
  return [&in, &out, elementSize, fn](size_t offset, size_t length) {
    uint8_t* const acc = static_cast<uint8_t*>(out[0]->ptr) + offset;
    const size_t count = length / elementSize;

    fn(acc,
       static_cast<const uint8_t*>(in[0]->ptr) + offset,
       static_cast<const uint8_t*>(in[1]->ptr) + offset,
       count);

    for (size_t idx = 2; idx < in.size(); ++idx) {
      fn(acc,
         static_cast<const uint8_t*>(acc),
         static_cast<const uint8_t*>(in[idx]->ptr) + offset,
         count);
    }
  };
}