inline bool is_compatible_weight()

in include/checks.h [30:44]


/// Checks whether `w` is a compatible 1-D weight for input `x`.
///
/// @param x Input tensor; its dimension 1 is compared against the length of `w`.
///          NOTE(review): presumably laid out as (N, C, ...) with C = channels —
///          confirm against callers.
/// @param w Candidate weight tensor.
/// @return true iff all of the following hold:
///         - `w` is 1-D;
///         - `x` has at least two dimensions;
///         - `w.size(0) == x.size(1)`;
///         - the dtypes are compatible: a Half `x` accepts a Half or Float `w`,
///           and any other `x` dtype requires `w` to have exactly the same dtype.
///
/// The rank checks run first and short-circuit. Malformed inputs therefore
/// yield `false` rather than letting `Tensor::size()` throw a
/// "dimension out of range" error from inside a predicate.
inline bool is_compatible_weight(const at::Tensor& x, const at::Tensor& w) {
  // Dimensions check. The ranks must be verified before indexing:
  // `x.size(1)` throws when x has < 2 dims, and `w.size(0)` throws when w is 0-dim.
  if (w.ndimension() != 1 || x.ndimension() < 2) {
    return false;
  }
  if (x.size(1) != w.size(0)) {
    return false;
  }

  // Typing check: Half inputs accept either Half or Float weights;
  // every other input dtype must match the weight dtype exactly.
  if (x.scalar_type() == at::ScalarType::Half) {
    return w.scalar_type() == at::ScalarType::Half ||
        w.scalar_type() == at::ScalarType::Float;
  }
  return x.scalar_type() == w.scalar_type();
}