in tensorflow_probability/python/internal/auto_composite_tensor.py [0:0]
def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None,
                         prefer_static_value=()):
  """Extract constructor kwargs to reconstruct `obj`.

  For every named parameter of `type(obj).__init__` (other than `self`), a
  value is looked up on `obj`, trying in order: an attribute with the same
  name, an attribute with a leading underscore, and finally an entry in
  `obj.parameters`.

  Args:
    obj: The object whose constructor arguments should be recovered.
    omit_kwargs: Collection of constructor parameter names to skip.
    limit_to: Optional collection of parameter names. When not `None`, only
      parameters whose names appear in it are extracted.
    prefer_static_value: Collection of parameter names for which a Tensor
      value is replaced by `tf.get_static_value(...)` whenever that returns
      something other than `None`.

  Returns:
    A `dict` mapping constructor parameter names to the recovered values.
    NumPy arrays and NumPy scalars are converted to Python lists/scalars via
    `.tolist()`. An empty `dict` is returned when the class does not define
    its own constructor (i.e. it uses `AutoCompositeTensor.__init__`).

  Raises:
    ValueError: If the constructor signature uses `*args` or `**kwargs`, or
      if no value can be found for one of the selected parameters.
  """
  cls = type(obj)
  # A class that does not override `__init__` inherits it from
  # `AutoCompositeTensor` (and ultimately from `object`), whose signature
  # carries *args/**kwargs; bail out early so the check below doesn't reject it.
  if cls.__init__ is AutoCompositeTensor.__init__:
    return {}

  signature = _cached_signature(cls.__init__)
  variadic_kinds = (tf_inspect.Parameter.VAR_KEYWORD,
                    tf_inspect.Parameter.VAR_POSITIONAL)
  for param in signature.parameters.values():
    if param.kind in variadic_kinds:
      raise ValueError(
          '*args and **kwargs are not supported. Found `{}`'.format(signature))

  field_names = [
      name for name in signature.parameters
      if name != 'self' and name not in omit_kwargs
      and (limit_to is None or name in limit_to)
  ]

  # Private sentinel so that attributes legitimately set to `None` are still
  # treated as "found".
  missing = object()
  result = {}
  for name in field_names:
    value = getattr(obj, name, missing)
    if value is missing:
      value = getattr(obj, '_' + name, missing)
    if value is missing:
      # Only touch `obj.parameters` when both attribute lookups failed.
      value = getattr(obj, 'parameters', {}).get(name, missing)
    if value is missing:
      raise ValueError(
          f'Could not determine an appropriate value for field `{name}` in'
          f' object `{obj}`. Looked for \n'
          f' 1. an attr called `{name}`,\n'
          f' 2. an attr called `_{name}`,\n'
          f' 3. an entry in `obj.parameters` with key "{name}".')

    if (name in prefer_static_value and value is not None
        and tf.is_tensor(value)):
      static_value = tf.get_static_value(value)
      if static_value is not None:
        value = static_value

    # Generally, these are shapes or int, but may be other parameters such as
    # `power` for `tfb.PowerTransform`.
    if isinstance(value, (np.ndarray, np.generic)):
      value = value.tolist()

    result[name] = value
  return result