include/dispatch.h (48 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
// Emits a single `case` label for use inside a `switch` over a type tag.
// Inside the case both `scalar_t` and `prmscalar_t` name CPP_TYPE, and the
// callable passed as the trailing macro arguments is invoked with no
// arguments; its result is returned from the enclosing function or lambda.
#define NORMAL_CASE_TYPE(ENUM_VALUE, CPP_TYPE, ...) \
  case ENUM_VALUE: {                                \
    using scalar_t = CPP_TYPE;                      \
    using prmscalar_t = scalar_t;                   \
    return __VA_ARGS__();                           \
  }
// Emits a single `case` label for inputs whose element type is `x_type`,
// while the parameter type is picked at run time from `w_scalar_type`:
//   - at::ScalarType::Half  -> prmscalar_t = at::Half
//   - at::ScalarType::Float -> prmscalar_t = float
// `scalar_t` is always `x_type`. The callable passed as the trailing macro
// arguments is invoked with no arguments and its result is returned from the
// enclosing function or lambda.
//
// Any other `w_scalar_type` is reported through AT_ERROR. The message prints
// the run-time value via toString(), as the default branch of DOUBLE_DISPATCH
// does; stringifying the macro argument would only print the caller's
// expression text (for example a variable name) rather than the actual type.
//
// `w_scalar_type` is evaluated more than once, so it should be an expression
// without side effects.
#define HALF_CASE_TYPE(enum_type, x_type, w_scalar_type, ...)      \
  case enum_type: {                                                \
    using scalar_t = x_type;                                       \
    if (w_scalar_type == at::ScalarType::Half) {                   \
      using prmscalar_t = at::Half;                                \
      return __VA_ARGS__();                                        \
    } else if (w_scalar_type == at::ScalarType::Float) {           \
      using prmscalar_t = float;                                   \
      return __VA_ARGS__();                                        \
    } else {                                                       \
      AT_ERROR("Unsupported type combination '" #enum_type "', '", \
               toString(w_scalar_type), "'");                      \
    }                                                              \
  }
// Dispatches on a pair of run-time scalar types. Expands to an immediately
// invoked lambda (capturing by reference) that switches on X_SCALAR_TYPE:
//   - Double / Float: `scalar_t` and `prmscalar_t` both name that type
//     (NORMAL_CASE_TYPE);
//   - Half: `scalar_t` is at::Half and `prmscalar_t` is chosen from
//     W_SCALAR_TYPE (HALF_CASE_TYPE).
// The callable passed as the trailing macro arguments is invoked in the
// selected case and its result becomes the value of the whole expression.
// Any other X_SCALAR_TYPE is reported through AT_ERROR, with the stringified
// OP_NAME argument as the message prefix.
#define DOUBLE_DISPATCH(X_SCALAR_TYPE, W_SCALAR_TYPE, OP_NAME, ...)            \
  [&] {                                                                        \
    const auto& x_type = X_SCALAR_TYPE;                                        \
    const auto& w_type = W_SCALAR_TYPE;                                        \
    switch (x_type) {                                                          \
      NORMAL_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)            \
      NORMAL_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)              \
      HALF_CASE_TYPE(at::ScalarType::Half, at::Half, w_type, __VA_ARGS__)      \
      default:                                                                 \
        AT_ERROR(#OP_NAME, " not implemented for '", toString(x_type), "'");   \
    }                                                                          \
  }()
// Device dispatch helper: returns the result of METHOD_cpu(args...) when
// (TENSOR).is_cuda() is false. When it is true, a WITH_CUDA build returns
// METHOD_cuda(args...), and a build without WITH_CUDA raises AT_ERROR.
//
// The macro expands to a bare if/else statement containing `return`s, so it
// must be used inside a function body and needs no trailing semicolon.
#ifdef WITH_CUDA
#define CUDA_DISPATCH(TENSOR, METHOD, ...) \
  if (!(TENSOR).is_cuda()) {               \
    return METHOD##_cpu(__VA_ARGS__);      \
  } else {                                 \
    return METHOD##_cuda(__VA_ARGS__);     \
  }
#else
#define CUDA_DISPATCH(TENSOR, METHOD, ...)                    \
  if (!(TENSOR).is_cuda()) {                                  \
    return METHOD##_cpu(__VA_ARGS__);                         \
  } else {                                                    \
    AT_ERROR("CUDA support was not enabled at compile time"); \
  }
#endif