Tensor NestedTensor_batch_norm()

in nestedtensor/csrc/autograd_functions.cpp [85:218]
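
Inference-only batch normalization for NestedTensor inputs. When built with
CUDA, half-precision inputs with dense (non-nested) running statistics and
affine parameters are routed to fused kernels, with a dedicated variant for
channels-last buffers; everything else falls back to a composition of
ordinary broadcast tensor ops.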


Tensor NestedTensor_batch_norm(
    const Tensor& input,
    const c10::optional<Tensor>& weight /* optional */,
    const c10::optional<Tensor>& bias /* optional */,
    const c10::optional<Tensor>& running_mean /* optional */,
    const c10::optional<Tensor>& running_var /* optional */,
    bool training,
    double momentum,
    double eps,
    bool cudnn_enabled) {
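  // Inference-only batch norm over a NestedTensor. The channel dimension
  // (dim 1) must be regular across nested components so that per-channel
  // statistics are well-defined; `cudnn_enabled` is part of the aten
  // schema but is unused here.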
  auto opt_sizes = get_nested_tensor_impl(input)->opt_sizes();
  TORCH_CHECK(opt_sizes[1], "batch norm requires a regular second dimension.");
  TORCH_CHECK(!training, "batch norm does not support training.");
  int64_t n_input = *opt_sizes[1];
  TORCH_CHECK(running_mean, "running_mean must be defined in evaluation mode");
  TORCH_CHECK(running_var, "running_var must be defined in evaluation mode");
  if (weight) {
    check_dims_match_num_input_features("weight", n_input, get_numel(*weight));
  }
  if (bias) {
    check_dims_match_num_input_features("bias", n_input, get_numel(*bias));
  }

  at::Tensor mean = *running_mean;
  at::Tensor var = *running_var;
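  // The fused half-precision CUDA kernels below are compiled in only when
  // the extension is built with CUDA support.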
#ifdef WITH_CUDA
  if (weight &&
      bias &&
      (is_nested_tensor_impl(input)) &&
      (!is_nested_tensor_impl(mean)) &&
      (!is_nested_tensor_impl(var)) &&
      (!is_nested_tensor_impl(*bias)) &&
      (!is_nested_tensor_impl(*weight)) &&
      (input.dtype()   == torch::kHalf) &&
      (mean.dtype()    == torch::kHalf) &&
      (var.dtype()     == torch::kHalf) &&
      (bias->dtype()   == torch::kHalf) &&
      (weight->dtype() == torch::kHalf) &&
      get_is_cuda(input)
  )
  {
    // Fast path: custom fused CUDA kernels for half-precision inference.
    // Statistics and affine parameters are made contiguous so their raw
    // data pointers can be handed to the kernels.
    mean = mean.contiguous();
    Tensor bias_cont = (*bias).contiguous();
    Tensor weight_cont = (*weight).contiguous();
    Tensor running_var_cont = (*running_var).contiguous();

    c10::Half* mean_ptr = mean.data_ptr<c10::Half>();
    c10::Half* bias_ptr = bias_cont.data_ptr<c10::Half>();
    c10::Half* weight_ptr = weight_cont.data_ptr<c10::Half>();
    c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();

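    // Channels-last layout: a fused kernel normalizes the flat buffer in
    // place; only the channel count and the total element count are
    // needed, so no per-component offsets have to be built.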
    if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
      Tensor input_buffer = get_buffer(input);
      int64_t num_channel = weight_cont.size(0);
      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
      nested_tensor::cuda::batchnorm_inference_channels_last_kernelLauncher(
          input_buffer.data_ptr<c10::Half>(),
          mean_ptr,
          running_var_ptr,
          c10::Half((float)(eps)),
          weight_ptr,
          bias_ptr,
          input_buffer.data_ptr<c10::Half>(),
          num_channel,
          input_buffer.numel(),
          defaultStream);
      input_buffer = input_buffer.view(-1);
      return wrap_buffer(
          std::move(input_buffer),
          get_efficient_nested_size(input),
          get_efficient_nested_stride(input));
    }
  
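    // Contiguous (channels-first) layout: the kernel also needs the element
    // offset at which each (component, channel) slice begins in the flat
    // buffer; those offsets are built below and copied to the device as
    // int32.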
    Tensor output = input;
    output = NestedTensor_contiguous(output);
    Tensor input_buffer = get_buffer(output);
    // Tensor output_buffer = input_buffer.clone();
  
    auto self_opt_sizes = get_opt_sizes(input);
  
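    // Build an exclusive prefix sum of per-component H*W (columns 1 and 2
    // of the nested size rows), repeated once per channel: entry k is the
    // element offset at which the k-th (component, channel) slice begins.
    // E.g., per-component H*W of [6, 12] with 3 channels yields
    // [0, 6, 12, 18, 30, 42, 54].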
    Tensor nt_sizes_ =
        get_efficient_nested_size(input).sizes(); // .to(torch::kInt32);
    Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
    Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
    Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
    int64_t* nt_sizes_all_ptr = nt_sizes_all.data_ptr<int64_t>();
    at::Tensor numbers_t = at::empty(
        {1 + (nt_sizes_all.size(0) * *self_opt_sizes[1])}, torch::kInt64);
    int64_t* numbers_t_ptr = numbers_t.data_ptr<int64_t>();
    numbers_t_ptr[0] = 0;
    int64_t index = 1;
    for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
      for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
        numbers_t_ptr[index] = (numbers_t_ptr[index - 1] + nt_sizes_all_ptr[i]);
        index++;
      }
    }
    Tensor nt_sizes = numbers_t.to(at::Device(kCUDA), torch::kInt32, true, true);
  
    at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
    nested_tensor::cuda::batchnorm_inference_kernelLauncher(
        input_buffer.data_ptr<c10::Half>(),
        mean_ptr,
        running_var_ptr,
        c10::Half((float)(eps)),
        weight_ptr,
        bias_ptr,
        input_buffer.data_ptr<c10::Half>(),
        // output_buffer.data_ptr<c10::Half>(),
        (int)(*self_opt_sizes[0]),
        (int)(weight_cont.size(0)),
        (int)(*self_opt_sizes[0] *
              *self_opt_sizes[1] *
              *self_opt_sizes[2] *
              *self_opt_sizes[3]),
        nt_sizes.data_ptr<int>(),
        defaultStream
        );
    return wrap_buffer(
        std::move(input_buffer),
        get_efficient_nested_size(output),
        get_efficient_nested_stride(output));
  }
#endif
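  // Fallback: express inference batch norm with ordinary broadcast tensor
  // ops,
  //   output = (input - mean) / sqrt(var + eps) * weight + bias,
  // where the per-channel terms are reshaped so they broadcast along the
  // channel dimension.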
  auto scalar_shape = make_scalar_shape(get_dim(input), n_input);

  at::Tensor invstd = 1 / at::sqrt(*running_var + eps);

  Tensor output = input;
  output = output - mean.reshape(IntArrayRef(scalar_shape));
  output = output * invstd.reshape(IntArrayRef(scalar_shape));

  if (weight) {
    output = output * weight->reshape(IntArrayRef(scalar_shape));
  }
  if (bias) {
    output = output + bias->reshape(IntArrayRef(scalar_shape));
  }
  return output;
}
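
A minimal registration sketch (hypothetical wiring, not taken from the repo;
the project's actual glue may use its own helper macros):

#include <torch/library.h>

TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
  // Route aten::batch_norm to the override above so that, e.g., an
  // nn::BatchNorm2d module in eval mode dispatches here for nested inputs.
  m.impl("batch_norm", NestedTensor_batch_norm);
}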