in theseus/optimizer/autograd/sparse_autograd.py [0:0]
def backward(ctx, grad_output):
    batch_size = grad_output.shape[0]
    targs = {"dtype": grad_output.dtype, "device": "cpu"}  # grad_output.device
    H = torch.empty(size=(batch_size, ctx.sparse_structure.num_cols), **targs)
    AH = torch.empty(size=(batch_size, ctx.sparse_structure.num_rows), **targs)
    b_Ax = ctx.b_cpu.clone()

    grad_output_cpu = grad_output.cpu()
    for i in range(batch_size):
        # solve for H[i] with the cached Cholesky decomposition of batch i
        H[i, :] = torch.Tensor(
            ctx.cholesky_decompositions[i](grad_output_cpu[i, :])
        )
        # rebuild batch i's A as a CSR matrix from the shared structure and
        # this batch's values (see the CSR sketch after this function)
        A_i = ctx.sparse_structure.csr_straight(ctx.At_val_cpu[i, :])
        AH[i, :] = torch.Tensor(A_i @ H[i, :])
        b_Ax[i, :] -= torch.Tensor(A_i @ ctx.x_cpu[i, :])  # residual b - A x
    # now fill the values of a matrix with structure identical to A's with
    # selected entries from the difference of tensor products:
    #   b_Ax (X) H - AH (X) x
    # (a dense check of this formula is sketched after the function)
    # NOTE: this row-wise manipulation can be much faster in C++ or Cython
    A_col_ind = ctx.sparse_structure.col_ind
    A_row_ptr = ctx.sparse_structure.row_ptr
    A_grad = torch.empty(
        size=(batch_size, len(A_col_ind)), **targs
    )  # return value, A's grad
    for r in range(len(A_row_ptr) - 1):
        start, end = A_row_ptr[r], A_row_ptr[r + 1]
        columns = A_col_ind[start:end]  # col indices, for this row
        A_grad[:, start:end] = (
            b_Ax[:, r].unsqueeze(1) * H[:, columns]
            - AH[:, r].unsqueeze(1) * ctx.x_cpu[:, columns]
        )

    dev = grad_output.device
    return A_grad.to(device=dev), AH.to(device=dev), None, None, None
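
csr_straight is not shown in this excerpt; the sketch below is only an assumption of what it does: rebuild a scipy CSR matrix from the shared sparsity pattern (col_ind, row_ptr) and one batch element's value vector. All names and numbers are illustrative, not theseus code.

# minimal sketch of rebuilding a CSR matrix with fixed structure and
# per-batch values (assumed behavior of csr_straight)
import numpy as np
from scipy.sparse import csr_matrix

row_ptr = np.array([0, 2, 3, 5])     # 3 rows
col_ind = np.array([0, 2, 1, 0, 2])  # 5 structural non-zeros
num_cols = 3

def csr_from_values(val):
    # same (row_ptr, col_ind) for every batch element; only `val` changes
    return csr_matrix((val, col_ind, row_ptr), shape=(len(row_ptr) - 1, num_cols))

A_0 = csr_from_values(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
print(A_0.toarray())
# [[1. 0. 2.]
#  [0. 3. 0.]
#  [4. 0. 5.]]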
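
The tensor-product fill above implements the gradient of a least-squares solve. The standalone sketch below checks the same closed-form expressions against PyTorch autograd on a small dense problem, under the assumption that the forward pass solves the normal equations A^T A x = A^T b; every name in it is local to the sketch.

# dense check of:  dL/dA[r, c] = (b - A x)[r] * H[c] - (A H)[r] * x[c],  dL/db = A H
# with H = (A^T A)^{-1} grad_output
import torch

torch.manual_seed(0)
m, n = 6, 4
A = torch.randn(m, n, dtype=torch.double, requires_grad=True)
b = torch.randn(m, dtype=torch.double, requires_grad=True)
grad_output = torch.randn(n, dtype=torch.double)

# forward: x = (A^T A)^{-1} A^T b, then backprop an arbitrary grad_output
AtA = A.t() @ A
x = torch.linalg.solve(AtA, A.t() @ b)
(x * grad_output).sum().backward()

# closed-form gradients matching the sparse code above
with torch.no_grad():
    H = torch.linalg.solve(AtA, grad_output)
    AH = A @ H
    b_Ax = b - A @ x
    A_grad = b_Ax.unsqueeze(1) * H.unsqueeze(0) - AH.unsqueeze(1) * x.unsqueeze(0)

assert torch.allclose(A.grad, A_grad)
assert torch.allclose(b.grad, AH)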