inline void compareShape()

in rela/tensor_dict.h [13:51]


inline void compareShape(const TensorDict& src, const TensorDict& dest) {
  if (src.size() != dest.size()) {
    std::cout << "src.size()[" << src.size() << "] != dest.size()[" << dest.size() << "]"
              << std::endl;
    std::cout << "src keys: ";
    for (const auto& p : src)
      std::cout << p.first << " ";
    std::cout << "dest keys: ";
    for (const auto& p : dest)
      std::cout << p.first << " ";
    std::cout << std::endl;
    assert(false);
  }

  for (const auto& name2tensor : src) {
    const auto& name = name2tensor.first;
    const auto& srcTensor = name2tensor.second;
    // std::cout << "in copy: trying to get: " << name << std::endl;
    // std::cout << "dest map keys" << std::endl;
    // printMapKey(dest);
    const auto& destTensor = dest.at(name);
    // if (destTensor.sizes() != srcTensor.sizes()) {
    //   std::cout << "copy size-mismatch: "
    //             << destTensor.sizes() << ", " << srcTensor.sizes() <<
    //             std::endl;
    // }
    if (destTensor.sizes() != srcTensor.sizes()) {
      std::cout << name << ", dstSize: " << destTensor.sizes()
                << ", srcSize: " << srcTensor.sizes() << std::endl;
      assert(false);
    }

    // if (destTensor.dtype() != srcTensor.dtype()) {
    //   std::cout << name << ", dstType: " << destTensor.dtype()
    //             << ", srcType: " << srcTensor.dtype() << std::endl;
    //   assert(false);
    // }
  }
}