fn to()

in candle-pyo3/src/lib.rs [957:1056]


    fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
        let mut device: Option<PyDevice> = None;
        let mut dtype: Option<PyDType> = None;
        let mut other: Option<PyTensor> = None;

        fn handle_duplicates<T>(
            opt: &mut Option<T>,
            extraction_result: PyResult<T>,
            err_msg: &'static str,
        ) -> PyResult<()> {
            if let Ok(successful_extraction) = extraction_result {
                if opt.is_some() {
                    return Err(PyValueError::new_err(err_msg));
                }
                *opt = Some(successful_extraction);
            }
            Ok(())
        }

        //handle args
        for arg in args.iter() {
            if arg.extract::<PyDevice>().is_ok() {
                handle_duplicates(
                    &mut device,
                    arg.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            } else if arg.extract::<PyDType>().is_ok() {
                handle_duplicates(
                    &mut dtype,
                    arg.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            } else if arg.extract::<PyTensor>().is_ok() {
                handle_duplicates(
                    &mut other,
                    arg.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            } else {
                return Err(PyTypeError::new_err(format!(
                    "unsupported argument type `{:#?}`",
                    arg.get_type().name()
                )));
            }
        }

        if let Some(kwargs) = kwargs {
            if let Ok(Some(any)) = kwargs.get_item("dtype") {
                handle_duplicates(
                    &mut dtype,
                    any.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            }
            if let Ok(Some(any)) = kwargs.get_item("device") {
                handle_duplicates(
                    &mut device,
                    any.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            }
            if let Ok(Some(any)) = kwargs.get_item("other") {
                handle_duplicates(
                    &mut other,
                    any.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            }
        }

        if let Some(other) = other {
            if device.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a device",
                ));
            }
            if dtype.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a dtype",
                ));
            }
            dtype = Some(other.dtype());
            device = Some(PyDevice::from_device(other.0.device()));
        }

        let result = match (device, dtype) {
            (Some(device), Some(dtype)) => self
                .0
                .to_device(&device.as_device()?)
                .map_err(wrap_err)?
                .to_dtype(dtype.0)
                .map_err(wrap_err)?,
            (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
            (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
            (None, None) => return Err(PyTypeError::new_err("No valid dtype or device specified")),
        };

        Ok(PyTensor(result))
    }