in tf_agents/utils/nest_utils.py [0:0]
def prune_extra_keys(narrow, wide):
  """Recursively prunes keys from `wide` if they don't appear in `narrow`.

  Often used as preprocessing prior to calling `tf.nest.flatten`
  or `tf.nest.map_structure`.

  This function is more forgiving than the ones in `nest`; if two substructures'
  types or structures don't agree, we consider it invalid and `prune_extra_keys`
  will return the `wide` substructure as is.  Typically, additional checking is
  needed: you will also want to use
  `nest.assert_same_structure(narrow, prune_extra_keys(narrow, wide))`
  to ensure the result of pruning is still a correct structure.

  Examples:
  ```python
  wide = [{"a": "a", "b": "b"}]
  # Narrows 'wide'
  assert prune_extra_keys([{"a": 1}], wide) == [{"a": "a"}]
  # 'wide' lacks "c", is considered invalid.
  assert prune_extra_keys([{"c": 1}], wide) == wide
  # 'wide' contains a different type from 'narrow', is considered invalid
  assert prune_extra_keys("scalar", wide) == wide
  # 'wide' substructure for key "d" does not match the one in 'narrow' and
  # therefore is returned unmodified.
  assert (prune_extra_keys({"a": {"b": 1}, "d": None},
                           {"a": {"b": "b", "c": "c"}, "d": [1, 2]})
          == {"a": {"b": "b"}, "d": [1, 2]})
  # An empty tuple in 'narrow' prunes the matching 'wide' substructure,
  # whatever it is, down to an empty tuple.
  assert prune_extra_keys((), wide) == ()
  assert prune_extra_keys({"a": ()}, wide[0]) == {"a": ()}
  ```

  Args:
    narrow: A nested structure.
    wide: A nested structure that may contain dicts with more fields than
      `narrow`.

  Returns:
    A structure with the same nested substructures as `wide`, but with
    dicts whose entries are limited to the keys found in the associated
    substructures of `narrow`. Wherever `narrow` holds an empty tuple, the
    result holds `narrow`'s empty tuple.
    On substructure type or size mismatches, the corresponding `wide`
    substructure is returned unmodified. Note that ObjectProxy-wrapped
    objects are considered equivalent to their non-ObjectProxy types.
  """
  # If `narrow` is `()`, then `()` is returned: any object narrowed w.r.t. an
  # empty tuple becomes an empty tuple. We check the exact type and emptiness
  # rather than writing `x is ()` (a SyntaxWarning), `x == ()` (elementwise
  # comparison warnings for numpy arrays) or `id(x) == id(())` (relies on the
  # empty tuple being a CPython singleton). The `type(...) is tuple` test runs
  # first, so `not narrow_raw` is only ever evaluated on a real tuple.
  narrow_raw = (narrow.__wrapped__ if isinstance(narrow, wrapt.ObjectProxy)
                else narrow)
  if type(narrow_raw) is tuple and not narrow_raw:  # pylint: disable=unidiomatic-typecheck
    return narrow

  # Prune the object wrapped by a `wide` ObjectProxy, then re-wrap the result
  # in the same proxy type. Past this point `wide` is never an ObjectProxy.
  if isinstance(wide, wrapt.ObjectProxy):
    return type(wide)(prune_extra_keys(narrow, wide.__wrapped__))

  if ((type(narrow_raw) != type(wide))  # pylint: disable=unidiomatic-typecheck
      and not (isinstance(narrow_raw, list) and isinstance(wide, list))
      and not (isinstance(narrow_raw, collections_abc.Mapping)
               and isinstance(wide, collections_abc.Mapping))):
    # We return early if the types are different; but we make some exceptions:
    #  list subtypes are considered the same (e.g. ListWrapper and list())
    #  Mapping subtypes are considered the same (e.g. DictWrapper and dict())
    #  (TupleWrapper subtypes are handled by unwrapping ObjectProxy above).
    return wide

  if isinstance(narrow, collections_abc.Mapping):
    if not set(wide.keys()).issuperset(narrow.keys()):
      # `wide` lacks a key required by `narrow`; treat as a mismatch.
      return wide
    # Keep only `narrow`'s keys (in `narrow`'s order), recursing into values,
    # and rebuild using `wide`'s container type.
    ordered_items = [(k, prune_extra_keys(v, wide[k]))
                     for k, v in narrow.items()]
    if isinstance(wide, collections.defaultdict):
      # defaultdict's constructor takes the default_factory first.
      return type(wide)(wide.default_factory, ordered_items)
    return type(wide)(ordered_items)

  if nest.is_sequence(narrow):
    if _is_attrs(wide):
      # Pair up attribute values and rebuild the attrs instance positionally.
      items = [prune_extra_keys(n, w)
               for n, w in zip(_attr_items(narrow), _attr_items(wide))]
      return type(wide)(*items)

    # Not an attrs, so can treat as lists or tuples from here on.
    if len(narrow) != len(wide):
      # wide's size is different than narrow; return early.
      return wide

    items = [prune_extra_keys(n, w) for n, w in zip(narrow, wide)]
    if _is_namedtuple(wide):
      # namedtuple constructors take fields positionally, not an iterable.
      return type(wide)(*items)
    return type(wide)(items)

  # narrow is a leaf, just return wide.
  return wide