def Visit()

in lingvo/core/hyperparams.py [0:0]


  def Visit(self,
            visit_fn: Callable[[str, Any], None],
            enter_fn: Optional[Callable[[str, Any], bool]] = None,
            exit_fn: Optional[Callable[[str, Any], None]] = None):
    """Recursively visits objects within this Params instance.

    Visit can traverse Params, dicts, lists, tuples, ranges, dataclasses, and
    namedtuples.

    By default, visit_fn is called on any object we don't know how to
    traverse into, like an integer or a string. enter_fn and exit_fn are
    called on objects we can traverse into. We call enter_fn before traversing
    the object, and exit_fn when we are finished.

    If enter_fn returns false, that means we shouldn't traverse into the
    object; we call visit_fn on it instead and never call exit_fn. This holds
    for every traversable kind of object, including lists of (key, value)
    tuples.

    Keys are of the form::

      key.subkey when traversing Params objects or dicts
      key[1] when traversing lists/tuples/ranges
      key[subkey] when traversing dataclasses or namedtuples

    Lists or tuples in which every element is a (str, Params) 2-tuple are
    treated like namedtuples: each element is visited under key[name].

    Args:
      visit_fn: Called on every object that can't be entered, or when enter_fn
        returns false.
      enter_fn: Called on every enter-able object. If this function returns
        false, we call visit_fn and do not enter the object. Defaults to
        always entering.
      exit_fn: Called after an enter-able object has been traversed. Defaults
        to a no-op.
    """
    if not enter_fn:
      enter_fn = lambda key, val: True
    if not exit_fn:
      exit_fn = lambda key, val: None

    def _SubKey(key, subkey):
      # The root is visited with an empty key, so its children are not
      # prefixed with a leading '.'.
      if key:
        return f'{key}.{subkey}'
      return subkey

    def _Visit(key: str, val: Any):
      if isinstance(val, Params):
        if enter_fn(key, val):
          for k, v in val.IterParams():
            _Visit(_SubKey(key, k), v)
          exit_fn(key, val)
        else:
          visit_fn(key, val)
      elif isinstance(val, dict):
        if enter_fn(key, val):
          for k, v in val.items():
            _Visit(_SubKey(key, k), v)
          exit_fn(key, val)
        else:
          visit_fn(key, val)
      elif dataclasses.is_dataclass(val) or _IsNamedTuple(val):
        # This check must precede the generic tuple handling below, since a
        # namedtuple is also a tuple.
        items = val.__dict__.items() if dataclasses.is_dataclass(
            val) else val._asdict().items()
        if enter_fn(key, val):
          for k, v in items:
            _Visit(f'{key}[{k}]', v)
          exit_fn(key, val)
        else:
          visit_fn(key, val)
      elif (isinstance(val, (list, tuple)) and all(
          isinstance(x, tuple) and len(x) == 2 and isinstance(x[0], str) and
          isinstance(x[1], Params) for x in val)):
        # A sequence of (name, Params) pairs. Note that all() is vacuously
        # true for an empty list/tuple, so empty sequences land here too.
        if enter_fn(key, val):
          for subtuple in val:
            _Visit(f'{key}[{subtuple[0]}]', subtuple[1])
          exit_fn(key, val)
        else:
          # Previously missing: without this, an object rejected by enter_fn
          # was neither entered nor visited.
          visit_fn(key, val)
      elif isinstance(val, (list, range, tuple)):
        if enter_fn(key, val):
          for i, v in enumerate(val):
            _Visit(f'{key}[{i}]', v)
          exit_fn(key, val)
        else:
          visit_fn(key, val)
      else:
        visit_fn(key, val)

    _Visit('', self)