in src/accumulator.cc [855:962]
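// reduceImpl: folds this worker's locally computed gradients (or an explicit
// skip, when batchSize == 0) into the current ReduceGradientsContainer, then
// triggers the counting/reduction step if the sync generation is unchanged.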
void reduceImpl(int batchSize) {
if (!wantsGradients()) {
fatal("reduceGradients/skipGradients called while wantsGradients() is false");
}
log.debug("Reduce %d\n", batchSize);
glock l(h->mutex);
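// Pick the current slot in the ring of reduction containers. A slot is only
// reused once its previous reduce has fully completed.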
size_t index = nextGradientReductionIndex;
std::shared_ptr<ReduceGradientsContainer> target = gradientReductions[index];
if (target && target->reduceStarted && !target->reduceDone) {
fatal("reduceImpl internal error, reduce already started!");
}
if (!target || target->reduceDone) {
target = gradientReductions[index] = std::make_shared<ReduceGradientsContainer>(this);
}
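// Advance the ring index, wrapping around at the end.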
if (nextGradientReductionIndex == gradientReductions.size() - 1) {
nextGradientReductionIndex = 0;
} else {
++nextGradientReductionIndex;
}
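// Mark that gradient copies are in flight, and capture the current CUDA
// stream (if gradients live on the GPU) so those copies can be synchronized.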
isCopyingGradients = true;
std::optional<rpc::CUDAStream> stream;
if (gradsOnCuda) {
stream.emplace(rpc::getCurrentCUDAStream());
}
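// Tag the container with the current sync generation; if it no longer
// matches below, group membership changed and no count is started.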
target->syncId = h->syncId;
// Note: an earlier revision apparently dispatched the remainder of this
// function asynchronously:
// async.run([this, batchSize, target, index, stream = std::move(stream), syncId = h->syncId]() mutable noexcept { ... });
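// Scope guard: on every exit path from this function, zero the local
// parameter gradients and clear the copying flag.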
Dtor dtor = [&] {
actuallyZeroGradients();
isCopyingGradients = false;
};
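// Make the captured stream current for the tensor operations below.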
std::optional<rpc::CUDAStreamGuard> sg;
if (stream) {
sg.emplace(*stream);
}
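// Disable autograd; these copies and adds must not be recorded in any graph.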
rpc::AutoGradMode ng(false);
bool synchronize = false;
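// A non-zero batchSize means this worker contributes gradients; zero means
// it explicitly skips this reduction.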
if (batchSize) {
++target->data.numGradients;
target->data.batchSize += batchSize;
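// The first contribution allocates the target buffers and copies into them
// directly; subsequent contributions are accumulated with add_ below.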
bool add = true;
if (target->data.gradients.empty()) {
target->data.gradients = allocateGradients();
add = false;
}
auto& targetGradients = target->data.gradients;
size_t i = 0;
std::vector<rpc::Tensor> addGrads;
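// Walk the model parameters; targetGradients holds one tensor per parameter
// with a defined grad, in the same order.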
for (auto& v : h->modelParameters) {
auto grad = v.grad();
if (grad.defined()) {
if (i == targetGradients.size()) {
fatal("grads grew?");
}
if (add) {
addGrads.push_back(grad.to(targetGradients[i].device(), /*non_blocking=*/true));
} else {
targetGradients[i].copy_(grad, /*non_blocking=*/true);
}
++i;
}
}
if (i != targetGradients.size()) {
fatal("grads shrank?");
}
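// On CUDA, wait for the non-blocking copies issued above to complete.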
synchronize = gradsOnCuda;
if (synchronize && stream) {
stream->synchronize();
}
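// Accumulate the transferred gradients into the running sums.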
if (add) {
for (size_t i = 0; i != addGrads.size(); ++i) {
targetGradients[i].add_(addGrads[i]);
}
if (synchronize && stream) {
stream->synchronize();
}
}
} else {
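// Record an explicit skip so the participant count still adds up.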
++target->data.numSkipped;
}
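// Final barrier: ensure all queued copies/adds have completed before the
// container is handed off to the counting step.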
if (synchronize && stream) {
stream->synchronize();
}
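// If the sync generation is still current, either piggyback on an in-flight
// count or start a new one for this container.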
try {
if (target->syncId == h->syncId && target->syncId == group->syncId) {
if (target->isCounting) {
log.debug("Already counting!\n");
target->wantsMoreCounting = true;
} else {
log.debug("Start new count!\n");
startCount(index, std::move(target));
}
}
} catch (const std::exception& e) {
fatal("exception %s\n", e.what());
}
}
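
// For reference, a minimal sketch of a scope guard matching how `Dtor` is used
// above. This is an illustrative assumption only; the real Dtor is defined
// elsewhere in this codebase and may differ:
//
//   template <typename F>
//   struct Dtor {
//     F f;
//     Dtor(F func) : f(std::move(func)) {}
//     ~Dtor() { f(); }  // runs the callback when the guard leaves scope
//   };
//
// With this, `Dtor dtor = [&] { ... };` executes the lambda when dtor goes out
// of scope at the end of reduceImpl, or during stack unwinding.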