def gather_from_rows_cols()

in theseus/utils/utils.py [0:0]


def gather_from_rows_cols(matrix: torch.Tensor, rows: torch.Tensor, cols: torch.Tensor):
    """Gather per-batch elements of ``matrix`` at paired (row, col) indices.

    For a ``matrix`` of shape ``(B, R, C)`` and index tensors ``rows`` / ``cols``
    of shape ``(B, K)``, returns a tensor of shape ``(B, K)`` with
    ``out[b, k] = matrix[b, rows[b, k], cols[b, k]]``.

    Args:
        matrix: 3-D tensor to gather from.
        rows: 2-D tensor of row indices, one row of indices per batch element.
        cols: 2-D tensor of column indices, same shape as ``rows``.

    Returns:
        Tensor of shape ``(B, K)`` on ``matrix``'s device.

    Raises:
        AssertionError: if the ranks/shapes are inconsistent or any index is
            out of range. (Checks use ``assert`` and are skipped under ``-O``.)
    """
    assert matrix.ndim == 3 and rows.ndim == 2 and cols.ndim == 2
    assert matrix.shape[0] == rows.shape[0] and matrix.shape[0] == cols.shape[0]
    assert rows.shape[1] == cols.shape[1]
    # min()/max() raise on empty tensors, so only range-check when there is
    # at least one index; an empty gather is valid and yields a (B, 0) result.
    if rows.numel() > 0:
        assert rows.min() >= 0 and rows.max() < matrix.shape[1]
        assert cols.min() >= 0 and cols.max() < matrix.shape[2]

    # (B, 1) batch index that broadcasts against the (B, K) rows/cols indices.
    aux_idx = torch.arange(matrix.shape[0], device=matrix.device).unsqueeze(-1)
    return matrix[aux_idx, rows, cols]