in candle-pyo3/src/lib.rs [957:1056]
fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
let mut device: Option<PyDevice> = None;
let mut dtype: Option<PyDType> = None;
let mut other: Option<PyTensor> = None;
fn handle_duplicates<T>(
opt: &mut Option<T>,
extraction_result: PyResult<T>,
err_msg: &'static str,
) -> PyResult<()> {
if let Ok(successful_extraction) = extraction_result {
if opt.is_some() {
return Err(PyValueError::new_err(err_msg));
}
*opt = Some(successful_extraction);
}
Ok(())
}
//handle args
for arg in args.iter() {
if arg.extract::<PyDevice>().is_ok() {
handle_duplicates(
&mut device,
arg.extract::<PyDevice>(),
"cannot specify multiple devices",
)?;
} else if arg.extract::<PyDType>().is_ok() {
handle_duplicates(
&mut dtype,
arg.extract::<PyDType>(),
"cannot specify multiple dtypes",
)?;
} else if arg.extract::<PyTensor>().is_ok() {
handle_duplicates(
&mut other,
arg.extract::<PyTensor>(),
"cannot specify multiple output tensors",
)?;
} else {
return Err(PyTypeError::new_err(format!(
"unsupported argument type `{:#?}`",
arg.get_type().name()
)));
}
}
if let Some(kwargs) = kwargs {
if let Ok(Some(any)) = kwargs.get_item("dtype") {
handle_duplicates(
&mut dtype,
any.extract::<PyDType>(),
"cannot specify multiple dtypes",
)?;
}
if let Ok(Some(any)) = kwargs.get_item("device") {
handle_duplicates(
&mut device,
any.extract::<PyDevice>(),
"cannot specify multiple devices",
)?;
}
if let Ok(Some(any)) = kwargs.get_item("other") {
handle_duplicates(
&mut other,
any.extract::<PyTensor>(),
"cannot specify multiple output tensors",
)?;
}
}
if let Some(other) = other {
if device.is_some() {
return Err(PyValueError::new_err(
"cannot specify both an output tensor and a device",
));
}
if dtype.is_some() {
return Err(PyValueError::new_err(
"cannot specify both an output tensor and a dtype",
));
}
dtype = Some(other.dtype());
device = Some(PyDevice::from_device(other.0.device()));
}
let result = match (device, dtype) {
(Some(device), Some(dtype)) => self
.0
.to_device(&device.as_device()?)
.map_err(wrap_err)?
.to_dtype(dtype.0)
.map_err(wrap_err)?,
(Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
(None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
(None, None) => return Err(PyTypeError::new_err("No valid dtype or device specified")),
};
Ok(PyTensor(result))
}