// include/inplace_abn.h
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <tuple>
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include "utils.h"
#ifdef __CUDACC__
#include "cuda_utils.cuh"
#endif
/***********************************************************************************************************************
* Enums
**********************************************************************************************************************/
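
// Activation fused after the normalization. activation_param holds the
// negative slope for LeakyReLU and alpha for ELU; it is ignored by Identity
// (see the ActivationFn specializations at the bottom of this header).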
enum class Activation { LeakyReLU, ELU, Identity };
/***********************************************************************************************************************
 * CPU / CUDA methods
**********************************************************************************************************************/
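
// Per-channel statistics of x. Judging from the parameter names used
// further down (mean, var, count), the three returned tensors are the
// channel mean, the biased variance, and the per-channel element count.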
std::tuple<at::Tensor, at::Tensor, at::Tensor> statistics_cpu(
const at::Tensor& x);
std::tuple<at::Tensor, at::Tensor, at::Tensor> statistics_cuda(
const at::Tensor& x);
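
// Reduce statistics collected on several devices (e.g. for synchronized
// multi-GPU BN); all_mean / all_var / all_count presumably stack one entry
// per device.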
std::tuple<at::Tensor, at::Tensor, at::Tensor> reduce_statistics_cuda(
const at::Tensor& all_mean,
const at::Tensor& all_var,
const at::Tensor& all_count);
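
// In-place forward pass: normalize x with the given statistics, apply the
// optional per-channel affine transform (weight, bias), then the
// activation; x is overwritten with the result.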
void forward_cpu(
at::Tensor& x,
const at::Tensor& mean,
const at::Tensor& var,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
float eps,
Activation activation,
float activation_param);
void forward_cuda(
at::Tensor& x,
const at::Tensor& mean,
const at::Tensor& var,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
float eps,
Activation activation,
float activation_param);
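
// Example (sketch): single-device, training-mode use of the two calls
// above, assuming x is a contiguous NCHW CUDA tensor and weight / bias are
// optional per-channel tensors.
//
//   at::Tensor mean, var, count;
//   std::tie(mean, var, count) = statistics_cuda(x);
//   forward_cuda(x, mean, var, weight, bias, 1e-5f,
//                Activation::LeakyReLU, 0.01f);

// First backward step: starting from the activated output y_act and its
// gradient dy_act, invert the activation (see ActivationFn below) and
// reduce over each channel; judging from the parameters of backward_cpu /
// backward_cuda, the four returned tensors are
// (xhat, dy, sum_dy, sum_xhat_dy).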
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_cpu(
const at::Tensor& y_act,
const at::Tensor& dy_act,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
float eps,
Activation activation,
float activation_param);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_cuda(
const at::Tensor& y_act,
const at::Tensor& dy_act,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
float eps,
Activation activation,
float activation_param);
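
// Second backward step: given xhat and the reduced gradient sums, finish
// the batch norm gradient; dy is overwritten in place with the gradient
// w.r.t. the input.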
void backward_cpu(
const at::Tensor& xhat,
at::Tensor& dy,
const at::Tensor& var,
const at::Tensor& count,
const at::Tensor& sum_dy,
const at::Tensor& sum_xhat_dy,
const c10::optional<at::Tensor>& weight,
float eps);
void backward_cuda(
const at::Tensor& xhat,
at::Tensor& dy,
const at::Tensor& var,
const at::Tensor& count,
const at::Tensor& sum_dy,
const at::Tensor& sum_xhat_dy,
const c10::optional<at::Tensor>& weight,
float eps);
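
// Example (sketch): wiring the two backward stages together on CUDA,
// assuming y_act / dy_act come from autograd and var / count were saved
// from the forward pass.
//
//   at::Tensor xhat, dy, sum_dy, sum_xhat_dy;
//   std::tie(xhat, dy, sum_dy, sum_xhat_dy) = backward_reduce_cuda(
//       y_act, dy_act, weight, bias, 1e-5f, Activation::LeakyReLU, 0.01f);
//   backward_cuda(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, 1e-5f);
//   // dy should now hold the gradient w.r.t. the batch norm input.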
/***********************************************************************************************************************
* Handling of activation functions
**********************************************************************************************************************/
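
// Each ActivationFn specialization provides an in-place forward() and a
// backward() that reconstructs the pre-activation value y and the gradient
// dy from the activated output y_act and its gradient dy_act; inverting the
// activation is what allows the pre-activation tensor to go unsaved.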
template <typename scalar_t, Activation activation>
struct ActivationFn;
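
// LeakyReLU: y = x for x >= 0, y = param * x otherwise. The negative
// branch is inverted as x = y / param, with gradient dx = dy * param.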
template <typename scalar_t>
struct ActivationFn<scalar_t, Activation::LeakyReLU> {
static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) {
x = (x >= 0) ? x : static_cast<scalar_t>(x * activation_param);
}
static INLINE_HOST_DEVICE void backward(
scalar_t y_act,
scalar_t dy_act,
float activation_param,
scalar_t& y,
scalar_t& dy) {
if (y_act >= 0) {
y = y_act;
dy = dy_act;
} else {
y = static_cast<scalar_t>(y_act / activation_param);
dy = static_cast<scalar_t>(dy_act * activation_param);
}
}
};
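
// ELU: y = x for x >= 0, y = param * (exp(x) - 1) otherwise. Inverting the
// negative branch gives x = log1p(y / param), and since
// dy/dx = param * exp(x) = y + param, the gradient is dx = dy * (y + param).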
template <typename scalar_t>
struct ActivationFn<scalar_t, Activation::ELU> {
static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) {
x = (x >= 0) ? x : static_cast<scalar_t>(activation_param * (::exp(x) - 1));
}
static INLINE_HOST_DEVICE void backward(
scalar_t y_act,
scalar_t dy_act,
float activation_param,
scalar_t& y,
scalar_t& dy) {
if (y_act >= 0) {
y = y_act;
dy = dy_act;
} else {
y = static_cast<scalar_t>(::log1p(y_act / activation_param));
dy = static_cast<scalar_t>(dy_act * (y_act + activation_param));
}
}
};
template <typename scalar_t>
struct ActivationFn<scalar_t, Activation::Identity> {
static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) {}
static INLINE_HOST_DEVICE void backward(
scalar_t y_act,
scalar_t dy_act,
float activation_param,
scalar_t& y,
scalar_t& dy) {
y = y_act;
dy = dy_act;
}
};
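
// Example (sketch): compile-time dispatch at a hypothetical call site,
// assuming float values v and dv.
//
//   float v_act = v, dv_act = dv;
//   ActivationFn<float, Activation::ELU>::forward(v_act, 1.0f);  // in place
//   float y, dy;
//   ActivationFn<float, Activation::ELU>::backward(v_act, dv_act, 1.0f, y, dy);
//   // y recovers v, and dy is dv scaled by the ELU derivative at v.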