in lingvo/core/hyperparams.py [0:0]
def Visit(self,
          visit_fn: Callable[[str, Any], None],
          enter_fn: Optional[Callable[[str, Any], bool]] = None,
          exit_fn: Optional[Callable[[str, Any], None]] = None) -> None:
  """Recursively visits objects within this Params instance.

  Visit can traverse Params, dicts, lists, tuples, ranges, dataclass
  instances, and namedtuples.

  By default, visit_fn is called on any object we don't know how to
  traverse into, like an integer or a string. enter_fn and exit_fn are
  called on objects we can traverse into. We call enter_fn before
  traversing the object, and exit_fn when we are finished.

  If enter_fn returns false, that means we shouldn't traverse into the
  object; we call visit_fn on it instead and never call exit_fn.

  Keys are of the form::

    key.subkey when traversing Params objects or dicts
    key[1] when traversing lists/tuples/ranges
    key[subkey] when traversing dataclasses or namedtuples

  Lists or tuples whose elements are all (str, Params) pairs are treated
  like namedtuples: each Params value is visited under ``key[name]``.
  Dataclass *classes* (as opposed to instances) are not traversed; they are
  passed to visit_fn like any other leaf.

  Args:
    visit_fn: Called on every object that can't be entered, or when enter_fn
      returns false.
    enter_fn: Called on every enter-able object. If this function returns
      false, we call visit_fn and do not enter the object. Defaults to
      always entering.
    exit_fn: Called after an enter-able object has been traversed. Defaults
      to a no-op.
  """
  if enter_fn is None:
    enter_fn = lambda key, val: True
  if exit_fn is None:
    exit_fn = lambda key, val: None

  def _SubKey(key, subkey):
    # The root has an empty key, so its direct children get bare names.
    if key:
      return f'{key}.{subkey}'
    return subkey

  def _DataclassItems(obj):
    # Dataclasses declared with slots=True have no __dict__; fall back to the
    # declared fields that are actually set on the instance.
    if hasattr(obj, '__dict__'):
      return obj.__dict__.items()
    return [(f.name, getattr(obj, f.name))
            for f in dataclasses.fields(obj)
            if hasattr(obj, f.name)]

  def _IsParamsPairs(val):
    # True for a list/tuple whose every element is a (str, Params) pair.
    return isinstance(val, (list, tuple)) and all(
        isinstance(x, tuple) and len(x) == 2 and isinstance(x[0], str) and
        isinstance(x[1], Params) for x in val)

  def _Traverse(key, val, children):
    # Applies the enter/exit/visit contract to one enter-able object.
    # `children` is a zero-arg callable producing (child_key, child_val)
    # pairs; it is only invoked after enter_fn has agreed to enter `val`.
    if enter_fn(key, val):
      for child_key, child_val in children():
        _Visit(child_key, child_val)
      exit_fn(key, val)
    else:
      visit_fn(key, val)

  def _Visit(key: str, val: Any) -> None:
    if isinstance(val, Params):
      _Traverse(key, val,
                lambda: ((_SubKey(key, k), v) for k, v in val.IterParams()))
    elif isinstance(val, dict):
      _Traverse(key, val,
                lambda: ((_SubKey(key, k), v) for k, v in val.items()))
    elif dataclasses.is_dataclass(val) and not isinstance(val, type):
      # is_dataclass() is also true for dataclass classes; only instances
      # are traversed, classes fall through to visit_fn below.
      _Traverse(key, val,
                lambda: ((f'{key}[{k}]', v) for k, v in _DataclassItems(val)))
    elif _IsNamedTuple(val):
      _Traverse(key, val,
                lambda: ((f'{key}[{k}]', v) for k, v in val._asdict().items()))
    elif _IsParamsPairs(val):
      _Traverse(key, val,
                lambda: ((f'{key}[{name}]', p) for name, p in val))
    elif isinstance(val, (list, tuple, range)):
      _Traverse(key, val,
                lambda: ((f'{key}[{i}]', v) for i, v in enumerate(val)))
    else:
      visit_fn(key, val)

  _Visit('', self)