in candle-pyo3/py_src/candle/nn/module.py [0:0]
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~candle.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
"""
# TODO: Remove `args` and the parsing logic when BC allows.
if len(args) > 0:
if destination is None:
destination = args[0]
if len(args) > 1 and prefix == "":
prefix = args[1]
if len(args) > 2 and keep_vars is False:
keep_vars = args[2]
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(
destination=destination,
prefix=prefix + name + ".",
keep_vars=keep_vars,
)
return destination