void reduceImpl(int batchSize)

in src/accumulator.cc [855:962]
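
Accumulates one batch of gradients into the next slot of the gradientReductions ring. The method requires wantsGradients() to be true; a batchSize of zero records a skipped batch instead of contributing gradients. The first contribution to a slot allocates the target buffers and copies each defined parameter gradient into them; subsequent contributions accumulate with add_, synchronizing the captured CUDA stream when the gradients live on the GPU. A scope guard (Dtor) zeroes the local gradients and clears isCopyingGradients afterwards, and if the container's syncId still matches the current handle and group syncId, startCount is invoked for the slot (or an in-progress count is flagged to continue).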


  void reduceImpl(int batchSize) {
    if (!wantsGradients()) {
      fatal("reduceGradients/skipGradients called while wantsGradients() is false");
    }

    log.debug("Reduce %d\n", batchSize);

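    // Serialize access to the handle's shared gradient-reduction state.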
    glock l(h->mutex);

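    // Pick the next reduction slot; allocate a fresh container if the slot is empty or its previous reduce finished.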
    size_t index = nextGradientReductionIndex;
    std::shared_ptr<ReduceGradientsContainer> target = gradientReductions[index];
    if (target && target->reduceStarted && !target->reduceDone) {
      fatal("reduceImpl internal error, reduce already started!");
    }
    if (!target || target->reduceDone) {
      target = gradientReductions[index] = std::make_shared<ReduceGradientsContainer>(this);
    }

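    // Advance the slot index, wrapping around the ring.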
    if (nextGradientReductionIndex == gradientReductions.size() - 1) {
      nextGradientReductionIndex = 0;
    } else {
      ++nextGradientReductionIndex;
    }

    isCopyingGradients = true;

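    // Capture the current CUDA stream when the gradients live on the GPU.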
    std::optional<rpc::CUDAStream> stream;
    if (gradsOnCuda) {
      stream.emplace(rpc::getCurrentCUDAStream());
    }

    target->syncId = h->syncId;

    // async.run([this, batchSize, target, index, stream = std::move(stream), syncId = h->syncId]() mutable noexcept {
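    // Dtor presumably runs this lambda on scope exit: zero the local gradients and clear the copying flag.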
    Dtor dtor = [&] {
      // std::lock_guard l(h->mutex);
      actuallyZeroGradients();
      isCopyingGradients = false;
    };
    std::optional<rpc::CUDAStreamGuard> sg;
    if (stream) {
      sg.emplace(*stream);
    }
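    // Disable autograd while copying and accumulating gradient tensors.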
    rpc::AutoGradMode ng(false);
    bool synchronize = false;
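    // A non-zero batchSize contributes gradients; zero marks this batch as skipped.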
    if (batchSize) {
      ++target->data.numGradients;
      target->data.batchSize += batchSize;
      bool add = true;
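      // The first contribution to this slot allocates buffers and copies; later contributions accumulate.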
      if (target->data.gradients.empty()) {
        target->data.gradients = allocateGradients();
        add = false;
      }
      auto& targetGradients = target->data.gradients;
      size_t i = 0;
      std::vector<rpc::Tensor> addGrads;
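      // Copy or stage each defined parameter gradient, keeping the order aligned with the target buffers.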
      for (auto& v : h->modelParameters) {
        auto grad = v.grad();
        if (grad.defined()) {
          if (i == targetGradients.size()) {
            fatal("grads grew?");
          }
          if (add) {
            addGrads.push_back(grad.to(targetGradients[i].device(), true));
          } else {
            targetGradients[i].copy_(grad, true);
          }
          ++i;
        }
      }
      if (i != targetGradients.size()) {
        fatal("grads shrank?");
      }
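      // The non-blocking copies above are asynchronous on CUDA; synchronize before touching the results.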
      synchronize = gradsOnCuda;
      if (synchronize && stream) {
        stream->synchronize();
      }
      if (add) {
        for (size_t i = 0; i != addGrads.size(); ++i) {
          targetGradients[i].add_(addGrads[i]);
        }
        if (synchronize && stream) {
          stream->synchronize();
        }
      }
    } else {
      ++target->data.numSkipped;
    }
    if (synchronize && stream) {
      stream->synchronize();
    }

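    // Only start (or extend) counting if the container's syncId still matches the current handle and group syncId.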
    try {
      // std::lock_guard l(h->mutex);
      if (target->syncId == h->syncId && target->syncId == group->syncId) {
        if (target->isCounting) {
          log.debug("Already counting!\n");
          target->wantsMoreCounting = true;
        } else {
          log.debug("Start new count!\n");
          startCount(index, std::move(target));
        }
      }
    } catch (const std::exception& e) {
      fatal("exception %s\n", e.what());
    }
    //});
  }