// include/utils.h (20 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <cmath>
#include <tuple>
#include <ATen/ATen.h>
/***********************************************************************************************************************
* General defines
**********************************************************************************************************************/
// Function-qualifier macros that let the same declaration be compiled both by
// a CUDA compiler and by a plain host C++ compiler.
//   HOST_DEVICE        - qualifier for functions callable from host and device.
//   INLINE_HOST_DEVICE - same, plus an inlining request.
#if defined(__CUDACC__)
// CUDA compilation: emit the execution-space specifiers.
#  define HOST_DEVICE __host__ __device__
#  define INLINE_HOST_DEVICE __host__ __device__ __forceinline__
#else
// Host-only compilation: the CUDA specifiers do not exist here, so
// HOST_DEVICE expands to nothing and the inline variant is plain `inline`.
#  define HOST_DEVICE
#  define INLINE_HOST_DEVICE inline
#endif  // defined(__CUDACC__)
/***********************************************************************************************************************
* Utility functions
**********************************************************************************************************************/
// Declaration only; the definition is not in this header.
// Takes a tensor by const reference and returns a tensor by value.
// NOTE(review): the name suggests the result is `x` rearranged into some
// canonical shape, but the exact layout is not visible here; confirm
// against the definition before relying on it.
at::Tensor normalize_shape(const at::Tensor& x);
// Produces a TensorAccessor for an optional tensor argument.
//
// Template parameters:
//   scalar_t - element type requested from the accessor.
//   dim      - number of dimensions requested from the accessor.
//
// Returns:
//   * if `t` holds a tensor: `t.value().accessor<scalar_t, dim>()`;
//   * if `t` is empty: a placeholder accessor built from three null pointers.
//     The placeholder refers to no storage, so callers must not index into
//     it; it only exists so call sites can pass an accessor unconditionally.
template <typename scalar_t, int64_t dim>
static at::TensorAccessor<scalar_t, dim> accessor_or_dummy(
    const c10::optional<at::Tensor>& t) {
  using accessor_t = at::TensorAccessor<scalar_t, dim>;

  if (t.has_value()) {
    // Real tensor supplied: hand back its accessor.
    return t.value().accessor<scalar_t, dim>();
  }

  // No tensor supplied: return the all-null placeholder.
  return accessor_t(nullptr, nullptr, nullptr);
}