in candle-pyo3/src/lib.rs [506:648]
fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
let mut indexers: Vec<Indexer> = vec![];
let dims = self.0.shape().dims();
fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult<usize> {
// Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0]
let actual_index = if index < 0 {
dims[current_dim] as isize + index
} else {
index
};
// Check that the index is in range
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
return Err(PyValueError::new_err(format!(
"index out of range for dimension '{current_dim}' with indexer '{index}'"
)));
}
Ok(actual_index as usize)
}
fn extract_indexer(
py_indexer: &Bound<PyAny>,
current_dim: usize,
dims: &[usize],
index_argument_count: usize,
) -> PyResult<(Indexer, usize)> {
if let Ok(index) = py_indexer.extract() {
// Handle a single index e.g. tensor[0] or tensor[-1]
Ok((
Indexer::Index(to_absolute_index(index, current_dim, dims)?),
current_dim + 1,
))
} else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
let index = slice.indices(dims[current_dim] as isize)?;
Ok((
Indexer::Slice(index.start as usize, index.stop as usize),
current_dim + 1,
))
} else if let Ok(tensor) = py_indexer.extract::<PyTensor>() {
// Handle a tensor as indices e.g. tensor[tensor([0,1])]
let t = tensor.0;
if t.rank() != 1 {
return Err(PyTypeError::new_err(
"multi-dimensional tensor indexing is not supported",
));
}
Ok((Indexer::IndexSelect(t), current_dim + 1))
} else if let Ok(list) = py_indexer.downcast::<pyo3::types::PyList>() {
// Handle a list of indices e.g. tensor[[0,1]]
let mut indexes = vec![];
for item in list.iter() {
let index = item.extract::<i64>()?;
indexes.push(index);
}
Ok((
Indexer::IndexSelect(
Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?,
),
current_dim + 1,
))
} else if py_indexer.is(&py_indexer.py().Ellipsis()) {
// Handle '...' e.g. tensor[..., 0]
if current_dim > 0 {
return Err(PyTypeError::new_err(
"Ellipsis ('...') can only be used at the start of an indexing operation",
));
}
Ok((Indexer::Ellipsis, dims.len() - (index_argument_count - 1)))
} else if py_indexer.is_none() {
// Handle None e.g. tensor[None, 0]
Ok((Indexer::Expand, current_dim))
} else {
Err(PyTypeError::new_err(format!(
"unsupported indexer {py_indexer}"
)))
}
}
if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
if not_none_count > dims.len() {
return Err(PyValueError::new_err("provided too many indices"));
}
let mut current_dim = 0;
for item in tuple.iter() {
let (indexer, new_current_dim) =
extract_indexer(&item, current_dim, dims, not_none_count)?;
current_dim = new_current_dim;
indexers.push(indexer);
}
} else {
let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
indexers.push(indexer);
}
let mut x = self.0.clone();
let mut current_dim = 0;
// Apply the indexers
for indexer in indexers.iter() {
x = match indexer {
Indexer::Index(n) => x
.narrow(current_dim, *n, 1)
.map_err(wrap_err)?
.squeeze(current_dim)
.map_err(wrap_err)?,
Indexer::Slice(start, stop) => {
let out = x
.narrow(current_dim, *start, stop.saturating_sub(*start))
.map_err(wrap_err)?;
current_dim += 1;
out
}
Indexer::Ellipsis => {
// Ellipsis is a special case, it means that all remaining dimensions should be
// selected => advance the current_dim to the last dimension we have indexers for
current_dim += dims.len() - (indexers.len() - 1);
x
}
Indexer::Expand => {
// Expand is a special case, it means that a new dimension should be added => unsqueeze and advance the current_dim
let out = x.unsqueeze(current_dim).map_err(wrap_err)?;
current_dim += 1;
out
}
Indexer::IndexSelect(indexes) => {
let out = x
.index_select(
&indexes.to_device(x.device()).map_err(wrap_err)?,
current_dim,
)
.map_err(wrap_err)?;
current_dim += 1;
out
}
}
}
Ok(Self(x))
}