in theseus/optimizer/autograd/sparse_autograd.py [0:0]
def forward(ctx, *args, **kwargs):
    """Solve a batch of damped sparse least-squares problems on the CPU.

    For each batch element i, computes the numeric Cholesky factorization of
    A_i^T A_i (with damping) from a shared symbolic decomposition and solves
    the normal equations A_i^T A_i x_i = A_i^T b_i.

    Args (positional, via *args):
        At_val (torch.Tensor): batched nonzero values of A^T laid out per
            ``sparse_structure``; assumed shape (batch_size, nnz) — confirm
            with callers.
        b (torch.Tensor): batched right-hand sides; assumed shape
            (batch_size, num_rows).
        sparse_structure (SparseStructure): shared CSC sparsity pattern used
            to assemble each A_i^T.
        symbolic_decomposition (CholeskyDecomposition): precomputed symbolic
            factorization, reused for every batch element.
        damping (float): diagonal damping added to A^T A before factorizing.

    Returns:
        torch.Tensor: solutions x of shape (batch_size, num_cols), moved back
        to ``At_val``'s device.

    Side effects:
        Stores CPU copies of the inputs, the solution, the sparse structure,
        and the per-batch numeric factorizations on ``ctx`` for backward.
    """
    At_val: torch.Tensor = args[0]
    b: torch.Tensor = args[1]
    sparse_structure: SparseStructure = args[2]
    symbolic_decomposition: CholeskyDecomposition = args[3]
    damping: float = args[4]

    # The sparse factorization/solve runs on CPU regardless of input device;
    # results are moved back to the input device at the end.
    At_val_cpu = At_val.cpu()
    b_cpu = b.cpu()
    batch_size = At_val.shape[0]
    targs = {"dtype": At_val.dtype, "device": "cpu"}
    x_cpu = torch.empty(size=(batch_size, sparse_structure.num_cols), **targs)
    cholesky_decompositions = []
    for i in range(batch_size):
        # Numeric factorization of A_i^T A_i (+ damping) from the shared
        # symbolic decomposition.
        At_i = sparse_structure.csc_transpose(At_val_cpu[i, :])
        cholesky_decomposition = symbolic_decomposition.cholesky_AAt(At_i, damping)
        # Solve the normal equations: x_i = (A_i^T A_i + damping)^{-1} A_i^T b_i.
        Atb_i = At_i @ b_cpu[i, :]
        # torch.as_tensor preserves the solver output's dtype; the legacy
        # torch.Tensor(...) constructor silently cast float64 solutions to
        # the default dtype (float32) before the write into x_cpu.
        x_cpu[i, :] = torch.as_tensor(cholesky_decomposition(Atb_i))
        # Keep the numeric factorization for reuse in backward.
        cholesky_decompositions.append(cholesky_decomposition)
    ctx.b_cpu = b_cpu
    ctx.x_cpu = x_cpu
    ctx.At_val_cpu = At_val_cpu
    ctx.sparse_structure = sparse_structure
    ctx.cholesky_decompositions = cholesky_decompositions
    return x_cpu.to(device=At_val.device)