include/checks.h (43 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. #pragma once #include <ATen/ATen.h> #ifdef TORCH_CHECK #define IABN_CHECK TORCH_CHECK #else #define IABN_CHECK AT_CHECK #endif #define CHECK_CUDA(x) IABN_CHECK((x).is_cuda(), #x " must be a CUDA tensor") #define CHECK_CPU(x) IABN_CHECK(!(x).is_cuda(), #x " must be a CPU tensor") #define CHECK_NOT_HALF(x) \ IABN_CHECK( \ (x).scalar_type() != at::ScalarType::Half, #x " can't have type Half") #define CHECK_SAME_TYPE(x, y) \ IABN_CHECK( \ (x).scalar_type() == (y).scalar_type(), \ #x " and " #y " must have the same scalar type") inline bool have_same_dims(const at::Tensor& x, const at::Tensor& y) { bool success = x.ndimension() == y.ndimension(); for (int64_t dim = 0; dim < x.ndimension(); ++dim) success &= x.size(dim) == y.size(dim); return success; } inline bool is_compatible_weight(const at::Tensor& x, const at::Tensor& w) { // Dimensions check bool success = w.ndimension() == 1; success &= x.size(1) == w.size(0); // Typing check if (x.scalar_type() == at::ScalarType::Half) { success &= (w.scalar_type() == at::ScalarType::Half) || (w.scalar_type() == at::ScalarType::Float); } else { success &= x.scalar_type() == w.scalar_type(); } return success; } inline bool is_compatible_stat(const at::Tensor& x, const at::Tensor& s) { // Dimensions check bool success = s.ndimension() == 1; success &= x.size(1) == s.size(0); // Typing check if (x.scalar_type() == at::ScalarType::Half) { success &= s.scalar_type() == at::ScalarType::Float; } else { success &= x.scalar_type() == s.scalar_type(); } return success; }