in theseus/utils/utils.py [0:0]
def gather_from_rows_cols(matrix: torch.Tensor, rows: torch.Tensor, cols: torch.Tensor):
    """Gather per-batch entries of ``matrix`` at paired (row, col) positions.

    For every batch index ``b`` and position ``k`` the result holds
    ``matrix[b, rows[b, k], cols[b, k]]``.

    Args:
        matrix: 3-D tensor of shape ``(batch, num_rows, num_cols)``.
        rows: 2-D integer tensor of shape ``(batch, k)`` with row indices in
            ``[0, num_rows)``.
        cols: 2-D integer tensor of shape ``(batch, k)`` with column indices in
            ``[0, num_cols)``.

    Returns:
        Tensor of shape ``(batch, k)`` with the gathered entries. If ``k == 0``
        (or ``batch == 0``), the result is an empty tensor of that shape.

    Raises:
        AssertionError: if the ranks or shapes are inconsistent, or if any index
            is out of range.
    """
    assert matrix.ndim == 3 and rows.ndim == 2 and cols.ndim == 2
    assert matrix.shape[0] == rows.shape[0] and matrix.shape[0] == cols.shape[0]
    assert rows.shape[1] == cols.shape[1]
    # Tensor.min()/max() raise on empty tensors, so only range-check when there
    # is something to gather. rows and cols have identical shapes at this point,
    # so rows being non-empty implies cols is non-empty too.
    if rows.numel() > 0:
        assert rows.min() >= 0 and rows.max() < matrix.shape[1]
        assert cols.min() >= 0 and cols.max() < matrix.shape[2]
    # Create the batch index directly on matrix's device (no host->device copy).
    # Its shape (batch, 1) broadcasts against the (batch, k) rows/cols.
    batch_idx = torch.arange(matrix.shape[0], device=matrix.device).unsqueeze(-1)
    return matrix[batch_idx, rows, cols]