# prune_extra_keys()
#
# in tf_agents/utils/nest_utils.py [0:0]


def prune_extra_keys(narrow, wide):
  """Recursively prunes keys from `wide` if they don't appear in `narrow`.

  Often used as preprocessing prior to calling `tf.nest.flatten`
  or `tf.nest.map_structure`.

  This function is more forgiving than the ones in `nest`; if two
  substructures' types or structures don't agree, we consider it invalid and
  `prune_extra_keys` will return the `wide` substructure as is.  Typically,
  additional checking is needed: you will also want to use
  `nest.assert_same_structure(narrow, prune_extra_keys(narrow, wide))`
  to ensure the result of pruning is still a correct structure.

  Examples:
  ```python
  wide = [{"a": "a", "b": "b"}]
  # Narrows 'wide'
  assert prune_extra_keys([{"a": 1}], wide) == [{"a": "a"}]
  # 'wide' lacks "c", is considered invalid.
  assert prune_extra_keys([{"c": 1}], wide) == wide
  # 'wide' contains a different type from 'narrow', is considered invalid
  assert prune_extra_keys("scalar", wide) == wide
  # 'wide' substructure for key "d" does not match the one in 'narrow' and
  # therefore is returned unmodified.
  assert (prune_extra_keys({"a": {"b": 1}, "d": None},
                           {"a": {"b": "b", "c": "c"}, "d": [1, 2]})
          == {"a": {"b": "b"}, "d": [1, 2]})
  # An empty tuple in 'narrow' narrows the matching 'wide' substructure,
  # whatever it is, to an empty tuple.
  assert prune_extra_keys((), wide) == ()
  assert prune_extra_keys([{"a": ()}], wide) == [{"a": ()}]
  ```

  Args:
    narrow: A nested structure.
    wide: A nested structure that may contain dicts with more fields than
      `narrow`.

  Returns:
    A structure with the same nested substructures as `wide`, but with
    dicts whose entries are limited to the keys found in the associated
    substructures of `narrow`.

    In case of substructure or size mismatches, the returned substructures
    will be returned as is.  Note that ObjectProxy-wrapped objects are
    considered equivalent to their non-ObjectProxy types.
  """
  narrow_raw = (narrow.__wrapped__ if isinstance(narrow, wrapt.ObjectProxy)
                else narrow)

  # If `narrow` is `()`, then `narrow` (the empty tuple) is returned: we
  # narrow any object w.r.t. an empty tuple to an empty tuple.  The exact type
  # check is deliberate: it excludes tuple subclasses such as zero-field
  # namedtuples, and it short-circuits before any truthiness test, so numpy
  # arrays never hit the elementwise comparison (and its warnings) that
  # `x == ()` would trigger.  Unlike `id(x) == id(())`, this does not rely on
  # CPython's empty-tuple singleton.
  if type(narrow_raw) is tuple and not narrow_raw:  # pylint: disable=unidiomatic-typecheck
    return narrow

  # Proxy-wrapped `wide`: prune the wrapped value, then re-wrap the result in
  # the same proxy type.
  if isinstance(wide, wrapt.ObjectProxy):
    return type(wide)(prune_extra_keys(narrow, wide.__wrapped__))

  # From here on `wide` is known not to be an ObjectProxy, so no unwrapping
  # is needed.
  if ((type(narrow_raw) != type(wide))  # pylint: disable=unidiomatic-typecheck
      and not (isinstance(narrow_raw, list) and isinstance(wide, list))
      and not (isinstance(narrow_raw, collections_abc.Mapping)
               and isinstance(wide, collections_abc.Mapping))):
    # We return early if the types are different; but we make some exceptions:
    #  list subtypes are considered the same (e.g. ListWrapper and list())
    #  Mapping subtypes are considered the same (e.g. DictWrapper and dict())
    #  (TupleWrapper subtypes are handled by unwrapping ObjectProxy above).
    return wide

  if isinstance(narrow, collections_abc.Mapping):
    if len(narrow) > len(wide):
      # Cheap pre-check: wide cannot contain every key of narrow.
      return wide

    narrow_keys = set(narrow.keys())
    wide_keys = set(wide.keys())
    if not wide_keys.issuperset(narrow_keys):
      # wide lacks a required key from narrow; return early.
      return wide
    # Keep only narrow's keys, in narrow's iteration order, and recurse into
    # each value.
    ordered_items = [
        (k, prune_extra_keys(v, wide[k]))
        for k, v in narrow.items()]
    if isinstance(wide, collections.defaultdict):
      # defaultdict's constructor takes default_factory as its first argument.
      subset = type(wide)(wide.default_factory, ordered_items)
    else:
      subset = type(wide)(ordered_items)
    return subset

  if nest.is_sequence(narrow):
    if _is_attrs(wide):
      # The type check above guarantees narrow is the same attrs class.
      # NOTE(review): confirm `_attr_items` yields per-field values in
      # declaration order; the pruned results are passed positionally to the
      # attrs constructor below.
      items = [prune_extra_keys(n, w)
               for n, w in zip(_attr_items(narrow), _attr_items(wide))]
      return type(wide)(*items)

    # Not an attrs, so can treat as lists or tuples from here on.
    if len(narrow) != len(wide):
      # wide's size is different than narrow; return early.
      return wide

    items = [prune_extra_keys(n, w) for n, w in zip(narrow, wide)]
    if _is_namedtuple(wide):
      # namedtuple constructors take fields positionally, not an iterable.
      return type(wide)(*items)
    return type(wide)(items)

  # narrow is a leaf, just return wide
  return wide