in candle-pyo3/py_src/candle/nn/module.py [0:0]
def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None)
:noindex:
.. function:: to(dtype)
:noindex:
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`candle.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`candle.dtype`): the desired floating point dtype of
the parameters and buffers in this module
Returns:
Module: self
"""
device = None
dtype = None
if args:
for arg in args:
# Assuming arg can be a string representing a device or a dtype
if isinstance(arg, str):
lower_arg = str(arg).lower()
if lower_arg.startswith("cuda") or lower_arg == "cpu":
device = lower_arg
else:
dtype = arg
elif isinstance(arg, DType):
dtype = str(arg)
else:
raise TypeError("Module.to() received an invalid combination of arguments. Got: {}".format(args))
if kwargs:
device = kwargs.get("device", device)
dtype = str(kwargs.get("dtype", dtype))
if device:
device = device.lower()
if dtype:
dtype = dtype.lower()
if dtype not in ["f32", "f16", "f64"]:
raise TypeError(
"candle.Module.to only accepts floating point" "dtypes, but got desired dtype={}".format(dtype)
)
def convert(t):
if dtype:
t = self.__cast_tensor(t, dtype)
if device:
t = self.__move_tensor_to_device(t, device)
return t
return self._apply(convert)