in xformers/components/attention/csrc/cpu/sparse_softmax.cpp [128:186]
// Backward pass of a sparse softmax (sputnik CSR-style layout), CPU path.
//
// Inputs describe a batch of sparse (m x n) matrices sharing one sparsity
// pattern:
//   m, n            - logical matrix dimensions (forwarded to the kernel).
//   row_indices     - int32, 1-D; per-row index order used by the kernel.
//   values          - float32, [batch, nnz]; softmax outputs from the forward.
//   grad            - float32, [batch, nnz]; upstream gradient w.r.t. values.
//   row_offsets     - int32, 1-D; CSR row offsets of the shared pattern.
//   column_indices  - int32, 1-D; CSR column indices, length nnz.
//
// Returns a new [batch, nnz] float32 tensor holding the gradient w.r.t. the
// softmax input, computed by SparseSoftmaxBackwardKernel.
//
// NOTE(review): row_indices/row_offsets lengths are never validated against m
// (e.g. row_offsets.size(0) == m + 1) — presumably guaranteed by the caller;
// confirm before relying on this entry point with untrusted shapes.
at::Tensor sparse_softmax_backward_sputnik(
    int64_t m,
    int64_t n,
    const at::Tensor& row_indices,
    const at::Tensor& values,
    const at::Tensor& grad,
    const at::Tensor& row_offsets,
    const at::Tensor& column_indices) {
  // Shape checks: values/grad are [batch, nnz]; index tensors are 1-D.
  TORCH_CHECK(grad.dim() == 2);
  TORCH_CHECK(values.dim() == 2);
  TORCH_CHECK(row_indices.dim() == 1);
  TORCH_CHECK(row_offsets.dim() == 1);
  TORCH_CHECK(column_indices.dim() == 1);
  TORCH_CHECK(values.size(1) == column_indices.size(0));
  TORCH_CHECK(values.size(0) == grad.size(0));
  TORCH_CHECK(values.size(1) == grad.size(1));
  // This is the CPU implementation; reject CUDA tensors explicitly.
  TORCH_CHECK(!grad.is_cuda(), "grad must be a CPU tensor");
  TORCH_CHECK(!row_indices.is_cuda(), "row_indices must be a CPU tensor");
  TORCH_CHECK(!values.is_cuda(), "values must be a CPU tensor");
  TORCH_CHECK(!row_offsets.is_cuda(), "row_offsets must be a CPU tensor");
  // Fixed: message previously said "column_offsets" for the column_indices
  // tensor, pointing users at a nonexistent argument.
  TORCH_CHECK(!column_indices.is_cuda(), "column_indices must be a CPU tensor");
  // The kernel reads raw pointers, so all inputs must be contiguous.
  TORCH_CHECK(grad.is_contiguous(), "grad must be a contiguous tensor");
  TORCH_CHECK(
      row_indices.is_contiguous(), "row_indices must be a contiguous tensor");
  TORCH_CHECK(values.is_contiguous(), "values must be a contiguous tensor");
  TORCH_CHECK(
      row_offsets.is_contiguous(), "row_offsets must be a contiguous tensor");
  TORCH_CHECK(
      column_indices.is_contiguous(),
      "column_indices must be a contiguous tensor");
  TORCH_CHECK(!grad.is_sparse(), "grad must be a dense tensor");
  TORCH_CHECK(!row_indices.is_sparse(), "row_indices must be a dense tensor");
  TORCH_CHECK(!values.is_sparse(), "values must be a dense tensor");
  TORCH_CHECK(!row_offsets.is_sparse(), "row_offsets must be a dense tensor");
  TORCH_CHECK(
      !column_indices.is_sparse(), "column_indices must be a dense tensor");
  // Validate dtypes up front: the kernel call below uses data_ptr<float> /
  // data_ptr<int>, which would otherwise fail with a less descriptive error.
  TORCH_CHECK(
      values.scalar_type() == at::kFloat, "values must be a float32 tensor");
  TORCH_CHECK(
      grad.scalar_type() == at::kFloat, "grad must be a float32 tensor");
  TORCH_CHECK(
      row_indices.scalar_type() == at::kInt,
      "row_indices must be an int32 tensor");
  TORCH_CHECK(
      row_offsets.scalar_type() == at::kInt,
      "row_offsets must be an int32 tensor");
  TORCH_CHECK(
      column_indices.scalar_type() == at::kInt,
      "column_indices must be an int32 tensor");
  int batch = values.size(0);
  int nonzeros = column_indices.size(0);
  // Output gradient has the same [batch, nnz] layout as values.
  at::Tensor output = at::empty({batch, nonzeros}, values.options());
  SparseSoftmaxBackwardKernel(
      m,
      n,
      grad.data_ptr<float>(),
      values.data_ptr<float>(),
      row_indices.data_ptr<int>(),
      row_offsets.data_ptr<int>(),
      column_indices.data_ptr<int>(),
      output.data_ptr<float>(),
      nonzeros,
      batch);
  return output;
}