in include/checks.h [30:44]
/// Checks whether `w` is a compatible weight vector for the input tensor `x`.
///
/// Conditions checked:
///  - shape: `w` is 1-D, `x` has at least 2 dimensions, and
///    `w.size(0) == x.size(1)` (one weight entry per index along dim 1 of `x`;
///    presumably the channel dimension — TODO confirm against callers);
///  - dtype: if `x` is Half, `w` may be Half or Float; for any other input
///    dtype, `w` must have exactly the same scalar type as `x`.
///
/// @param x Input tensor; its dim 1 is matched against the length of `w`.
/// @param w Weight tensor.
/// @return true if `w` is compatible with `x`, false otherwise. Rank mismatches
///         yield false rather than an exception, because ranks are validated
///         before `size()` is indexed.
inline bool is_compatible_weight(const at::Tensor& x, const at::Tensor& w) {
  // Dimensions check. Ranks must be validated first: at::Tensor::size(d)
  // throws for an out-of-range d, so x.size(1) / w.size(0) may only be read
  // once x is known to have >= 2 dims and w exactly 1 dim.
  if (w.ndimension() != 1 || x.ndimension() < 2) {
    return false;
  }
  if (x.size(1) != w.size(0)) {
    return false;
  }

  // Typing check: a Half input accepts a Half or Float weight; every other
  // input dtype requires an exact dtype match.
  const auto x_type = x.scalar_type();
  const auto w_type = w.scalar_type();
  if (x_type == at::ScalarType::Half) {
    return w_type == at::ScalarType::Half || w_type == at::ScalarType::Float;
  }
  return x_type == w_type;
}