std::vector<int> CusolverLUSolver::factor(const torch::Tensor& A_val)

in theseus/extlib/cusolver_lu_solver.cpp [241:279]


/// Numerically (re)factorizes a batch of matrices through the cusolverRf
/// handle, reusing the sparsity pattern (A_rowPtr / A_colInd) and the
/// permutations (P / Q) held by this solver.
///
/// @param A_val  CUDA tensor of shape (batchSize, nnz) holding the nonzero
///               values of each matrix. It must be float64 and contiguous,
///               because matrix i is addressed as data_ptr + i * nnz.
/// @return       One entry per batch element, as filled in by
///               cusolverRfBatchZeroPivot. An entry >= 0 marks a matrix that
///               is not invertible (the value is the reported singularity
///               position); such entries are also reported on stderr.
std::vector<int> CusolverLUSolver::factor(const torch::Tensor& A_val) {

    TORCH_CHECK(A_val.device().is_cuda(), "factor: A_val must be a CUDA tensor");
    TORCH_CHECK(A_val.dim() == 2, "factor: A_val must be 2-dimensional (batch, nnz)");
    TORCH_CHECK(A_val.scalar_type() == torch::kDouble, "factor: A_val must be float64");
    // Per-matrix pointers below are obtained by plain offsetting into the
    // storage, which is only valid for a dense row-major layout.
    TORCH_CHECK(A_val.is_contiguous(), "factor: A_val must be contiguous");

    // we ideally would like to check "<=" and support irregular (smaller)
    // batch sizes, but (disappointingly) cuda fails unless "==" holds
    TORCH_CHECK(A_val.size(0) == batchSize,
                "factor: expected batch size ", batchSize, ", got ", A_val.size(0));
    TORCH_CHECK(A_val.size(1) == nnz,
                "factor: expected nnz ", nnz, ", got ", A_val.size(1));

    factorId++;
    factoredBatchSize = A_val.size(0);

    // Build the array of per-matrix value pointers on the host, then copy it
    // to the same device as A_val (not merely the current device), so the
    // array lives next to the data its pointers refer to.
    at::Tensor A_val_array_cpu = torch::empty(factoredBatchSize * sizeof(double*), torch::TensorOptions(torch::kByte));
    double* pA_val = A_val.data_ptr<double>();
    double** pA_val_array_cpu = reinterpret_cast<double**>(A_val_array_cpu.data_ptr<uint8_t>());
    for (int i = 0; i < factoredBatchSize; i++) {
        // 64-bit offset: nnz * i may exceed INT_MAX for large batches.
        pA_val_array_cpu[i] = pA_val + static_cast<int64_t>(nnz) * i;
    }
    at::Tensor A_val_array = A_val_array_cpu.to(A_val.device());

    CUSOLVER_CHECK(cusolverRfBatchResetValues(factoredBatchSize,
                                              numRows, nnz,
                                              A_rowPtr.data_ptr<int>(), A_colInd.data_ptr<int>(),
                                              reinterpret_cast<double**>(A_val_array.data_ptr<uint8_t>()),
                                              P.data_ptr<int>(), Q.data_ptr<int>(),
                                              cusolverRfH));

    CUSOLVER_CHECK(cusolverRfBatchRefactor(cusolverRfH));

    // One entry per matrix; entries >= 0 flag matrices found not invertible.
    std::vector<int> singularityPositions(factoredBatchSize);
    CUSOLVER_CHECK(cusolverRfBatchZeroPivot(cusolverRfH, singularityPositions.data()));
    for (int i = 0; i < factoredBatchSize; i++) {
        if (singularityPositions[i] >= 0) {
            fprintf(stderr, "Error: A[%d] is not invertible, singularity=%d\n", i, singularityPositions[i]);
        }
    }

    return singularityPositions;
}