fn __getitem__()

in candle-pyo3/src/lib.rs [506:648]


    fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
        let mut indexers: Vec<Indexer> = vec![];
        let dims = self.0.shape().dims();

        fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult<usize> {
            // Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0]
            let actual_index = if index < 0 {
                dims[current_dim] as isize + index
            } else {
                index
            };

            // Check that the index is in range
            if actual_index < 0 || actual_index >= dims[current_dim] as isize {
                return Err(PyValueError::new_err(format!(
                    "index out of range for dimension '{current_dim}' with indexer '{index}'"
                )));
            }
            Ok(actual_index as usize)
        }

        fn extract_indexer(
            py_indexer: &Bound<PyAny>,
            current_dim: usize,
            dims: &[usize],
            index_argument_count: usize,
        ) -> PyResult<(Indexer, usize)> {
            if let Ok(index) = py_indexer.extract() {
                // Handle a single index e.g. tensor[0] or tensor[-1]
                Ok((
                    Indexer::Index(to_absolute_index(index, current_dim, dims)?),
                    current_dim + 1,
                ))
            } else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
                // Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
                let index = slice.indices(dims[current_dim] as isize)?;
                Ok((
                    Indexer::Slice(index.start as usize, index.stop as usize),
                    current_dim + 1,
                ))
            } else if let Ok(tensor) = py_indexer.extract::<PyTensor>() {
                // Handle a tensor as indices e.g. tensor[tensor([0,1])]
                let t = tensor.0;
                if t.rank() != 1 {
                    return Err(PyTypeError::new_err(
                        "multi-dimensional tensor indexing is not supported",
                    ));
                }
                Ok((Indexer::IndexSelect(t), current_dim + 1))
            } else if let Ok(list) = py_indexer.downcast::<pyo3::types::PyList>() {
                // Handle a list of indices e.g. tensor[[0,1]]
                let mut indexes = vec![];
                for item in list.iter() {
                    let index = item.extract::<i64>()?;
                    indexes.push(index);
                }
                Ok((
                    Indexer::IndexSelect(
                        Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?,
                    ),
                    current_dim + 1,
                ))
            } else if py_indexer.is(&py_indexer.py().Ellipsis()) {
                // Handle '...' e.g. tensor[..., 0]
                if current_dim > 0 {
                    return Err(PyTypeError::new_err(
                        "Ellipsis ('...') can only be used at the start of an indexing operation",
                    ));
                }
                Ok((Indexer::Ellipsis, dims.len() - (index_argument_count - 1)))
            } else if py_indexer.is_none() {
                // Handle None e.g. tensor[None, 0]
                Ok((Indexer::Expand, current_dim))
            } else {
                Err(PyTypeError::new_err(format!(
                    "unsupported indexer {py_indexer}"
                )))
            }
        }

        if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
            let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();

            if not_none_count > dims.len() {
                return Err(PyValueError::new_err("provided too many indices"));
            }

            let mut current_dim = 0;
            for item in tuple.iter() {
                let (indexer, new_current_dim) =
                    extract_indexer(&item, current_dim, dims, not_none_count)?;
                current_dim = new_current_dim;
                indexers.push(indexer);
            }
        } else {
            let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
            indexers.push(indexer);
        }

        let mut x = self.0.clone();
        let mut current_dim = 0;
        // Apply the indexers
        for indexer in indexers.iter() {
            x = match indexer {
                Indexer::Index(n) => x
                    .narrow(current_dim, *n, 1)
                    .map_err(wrap_err)?
                    .squeeze(current_dim)
                    .map_err(wrap_err)?,
                Indexer::Slice(start, stop) => {
                    let out = x
                        .narrow(current_dim, *start, stop.saturating_sub(*start))
                        .map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
                Indexer::Ellipsis => {
                    // Ellipsis is a special case, it means that all remaining dimensions should be
                    // selected => advance the current_dim to the last dimension we have indexers for
                    current_dim += dims.len() - (indexers.len() - 1);
                    x
                }
                Indexer::Expand => {
                    // Expand is a special case, it means that a new dimension should be added => unsqueeze and advance the current_dim
                    let out = x.unsqueeze(current_dim).map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
                Indexer::IndexSelect(indexes) => {
                    let out = x
                        .index_select(
                            &indexes.to_device(x.device()).map_err(wrap_err)?,
                            current_dim,
                        )
                        .map_err(wrap_err)?;
                    current_dim += 1;
                    out
                }
            }
        }

        Ok(Self(x))
    }