def to()

in candle-pyo3/py_src/candle/nn/module.py [0:0]


    def to(self, *args, **kwargs):
        r"""Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None)
           :noindex:

        .. function:: to(dtype)
           :noindex:

        .. function:: to(device)
           :noindex:

        Positional arguments are classified one at a time: a string equal to
        ``"cpu"`` or starting with ``"cuda"`` (case-insensitive) is taken as the
        device, any other string or a ``DType`` is taken as the dtype. A keyword
        argument that is given a non-``None`` value overrides the positional one.

        .. note::
            This method modifies the module in-place.

        Args:
            device (str): the desired device of the parameters and buffers in
                this module, e.g. ``"cpu"`` or ``"cuda"``.
            dtype (str or ``DType``): the desired floating point dtype of the
                parameters and buffers in this module; must be one of
                ``"f16"``, ``"f32"`` or ``"f64"`` (case-insensitive).

        Returns:
            Module: self

        Raises:
            TypeError: if a positional argument is neither a string nor a
                ``DType``, or if the requested dtype is not floating point.

        Examples::

            module.to("cuda")
            module.to(dtype="f16")
            module.to(device="cpu", dtype="f32")
        """

        device = None
        dtype = None

        for arg in args:
            if isinstance(arg, str):
                lowered = arg.lower()
                # Only "cpu" and "cuda*" strings are recognised as devices;
                # every other string is treated as a dtype name.
                if lowered.startswith("cuda") or lowered == "cpu":
                    device = lowered
                else:
                    dtype = arg
            elif isinstance(arg, DType):
                dtype = str(arg)
            else:
                raise TypeError("Module.to() received an invalid combination of arguments. Got: {}".format(args))

        # Keyword arguments override positional ones only when actually given a
        # value. Converting a missing dtype with str() would produce the string
        # "None", which then fails validation and breaks e.g. to(device="cuda").
        kw_device = kwargs.get("device")
        if kw_device is not None:
            device = kw_device
        kw_dtype = kwargs.get("dtype")
        if kw_dtype is not None:
            dtype = str(kw_dtype)

        if device:
            device = device.lower()

        if dtype:
            dtype = dtype.lower()
            if dtype not in ("f32", "f16", "f64"):
                raise TypeError(
                    "candle.Module.to only accepts floating point dtypes, but got desired dtype={}".format(dtype)
                )

        def convert(t):
            # Cast first, then move, applied to every tensor visited by _apply.
            if dtype:
                t = self.__cast_tensor(t, dtype)
            if device:
                t = self.__move_tensor_to_device(t, device)
            return t

        return self._apply(convert)