TensorDict get()

in rela/batcher.h [147:198]


  /// Blocks until a batch is ready, then hands it to the consumer.
  ///
  /// Waits on cvGetBatch_ until either exit_ is set, or at least one slot has
  /// been filled (nextSlot_ > 0) and no producer is still writing
  /// (numActiveWrite_ == 0). While holding mNextSlot_ it then:
  ///   - claims all filled slots;
  ///   - swaps the filling/filled buffers and the filling/filled replies;
  ///   - installs a fresh FutureReply for the next batch.
  /// After the lock is released, all waiters on cvNextSlot_ are woken.
  ///
  /// @return An empty TensorDict if exit_ was set. Otherwise, for every key of
  ///         the filled buffer, the first `batchSize` rows along dim 0.
  ///         NOTE(review): narrow() yields a view and contiguous() only copies
  ///         when the view is not already contiguous, so the returned tensors
  ///         presumably alias filledBuffer_ storage. Confirm that callers do
  ///         not hold them past the next get(), which swaps that buffer back
  ///         in for filling.
  TensorDict get() {
    std::unique_lock<std::mutex> guard(mNextSlot_);
    cvGetBatch_.wait(guard, [this] {
      if (exit_) {
        return true;
      }
      return nextSlot_ > 0 && numActiveWrite_ == 0;
    });

    if (exit_) {
      return TensorDict();
    }

    // Claim every filled slot and rewind the write cursor for producers.
    const int batchSize = nextSlot_;
    nextSlot_ = 0;

    // The reply for the previously handed-out batch must already have been
    // handled before another batch can be taken.
    assert(filledReply_ == nullptr);
    std::swap(filledBuffer_, fillingBuffer_);
    std::swap(filledReply_, fillingReply_);
    fillingReply_ = std::make_shared<FutureReply>();

    // Drop the lock before waking producers so they can proceed immediately.
    guard.unlock();
    cvNextSlot_.notify_all();

    // Slice the filled buffer outside the lock. filledBuffer_ was just swapped
    // out of the producers' path, so the slicing does not need mNextSlot_.
    TensorDict batch;
    for (const auto& entry : filledBuffer_) {
      batch[entry.first] = entry.second.narrow(0, 0, batchSize).contiguous();
    }

    // Rolling batch-size statistics. The periodic report that used to read
    // these counters is disabled, so each window is simply discarded.
    // NOTE(review): updated after the lock is released; this presumably
    // relies on get() having a single consumer thread. Confirm with callers.
    constexpr int kStatsWindow = 5000;
    sumBatchsize_ += batchSize;
    ++batchCount_;
    if (batchCount_ % kStatsWindow == 0) {
      batchCount_ = 0;
      sumBatchsize_ = 0;
    }

    return batch;
  }