void checkTensors()

in benchmark/benchmark.cc [163:201]


  void checkTensors() {
    for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) {
      // at::Tensor expected = at::arange(
      //                           buckets_[bucketIdx].numel(),
      //                           c10::TensorOptions()
      //                               .dtype(c10::kFloat)
      //                               .device(c10::Device(c10::kCUDA, 0))) *
      //     static_cast<int64_t>(numMachines_ * numDevicesPerMachine_);
      at::Tensor expected =
          torch::full(
              {1},
              static_cast<float>(numMachines_ * numDevicesPerMachine_),
              c10::TensorOptions()
                  .dtype(c10::kFloat)
                  .device(c10::Device(c10::kCUDA, 0)))
              .expand({buckets_[bucketIdx].numel()});

      if (!buckets_[bucketIdx].allclose(expected)) {
        throw std::runtime_error("Bad result");
      }
      // at::Tensor closeness =
      //     buckets_[bucketIdx].isclose(expected).logical_not();
      // at::Tensor nonCloseIndices = closeness.nonzero();
      // if (nonCloseIndices.size(0) > 0) {
      //   LOG(ERROR) << "In bucket " << bucketIdx << " which starts at 0x"
      //              << std::hex
      //              << reinterpret_cast<uintptr_t>(
      //                     buckets_[bucketIdx].data_ptr())
      //              << std::dec << " found non-close value at index "
      //              << nonCloseIndices[0].item<int64_t>() << " which has value "
      //              << buckets_[bucketIdx][nonCloseIndices[0].item<int64_t>()]
      //                     .item<float>()
      //              << " instead of "
      //              << expected[nonCloseIndices[0].item<int64_t>()].item<float>()
      //              << " and there are " << nonCloseIndices.size(0)
      //              << " non-close values in total ";
      // }
    }
  }