in tensorflow_federated/python/core/impl/types/type_analysis.py [0:0]
def check_concrete_instance_of(concrete_type: computation_types.Type,
                               generic_type: computation_types.Type):
  """Checks whether `concrete_type` is a valid substitution of `generic_type`.

  This function determines whether `generic_type`'s type parameters can be
  substituted such that it is equivalent to `concrete_type`.

  Note that passing through argument-position of function type swaps the
  variance of abstract types. Argument-position types can be assigned *from*
  other instances of the same type, but are not equivalent to it.

  Due to this variance issue, abstract types must include at least one
  "defining" usage. "Defining" uses are those which are encased in function
  parameter position an odd number of times. These usages must all be
  equivalent. Non-defining usages need not compare equal but must be assignable
  *from* defining usages. As an exception, an abstract type that is used
  exactly once (in a non-defining position) does not need a defining usage.

  Args:
    concrete_type: A type containing no `computation_types.AbstractType`s to
      check against `generic_type`'s shape.
    generic_type: A type which may contain `computation_types.AbstractType`s.

  Raises:
    NotConcreteTypeError: If `concrete_type` contains an abstract type.
    MismatchedStructureError: If the two types differ in kind, tensor type,
      placement, struct length or element names, function parameter presence,
      or federated placement / `all_equal`.
    MismatchedConcreteTypesError: If two defining usages of the same type
      parameter correspond to non-equivalent concrete types.
    MissingDefiningUsageError: If a type parameter used more than once has no
      defining usage.
    UnassignableConcreteTypesError: If a non-defining usage is not assignable
      from the concrete type bound by the defining usages.
    TypeError: If `concrete_type` is not a valid substitution of
      `generic_type`, including when `generic_type` contains an unexpected
      type kind.
  """
  py_typecheck.check_type(concrete_type, computation_types.Type)
  py_typecheck.check_type(generic_type, computation_types.Type)

  for t in _preorder_types(concrete_type):
    if t.is_abstract():
      raise NotConcreteTypeError(concrete_type, t)

  # Abstract-type label -> concrete type established by its defining usages.
  type_bindings = {}
  # Abstract-type label -> concrete types seen at its non-defining usages.
  # These are validated only after the full traversal, once every binding
  # from a defining usage is known.
  non_defining_usages = collections.defaultdict(list)

  def _check_helper(generic_type_member: computation_types.Type,
                    concrete_type_member: computation_types.Type,
                    defining: bool):
    """Recursively matches a generic sub-type against a concrete sub-type.

    Args:
      generic_type_member: A (sub)type of `generic_type`.
      concrete_type_member: The corresponding (sub)type of `concrete_type`.
      defining: Whether this position is a defining usage, i.e. nested inside
        function-parameter position an odd number of times.
    """

    def _raise_structural(mismatch):
      raise MismatchedStructureError(concrete_type, generic_type,
                                     concrete_type_member, generic_type_member,
                                     mismatch)

    def _both_are(predicate):
      # False when the generic member is not of this kind; raises a structural
      # 'kind' mismatch when only the generic member is of this kind.
      if not predicate(generic_type_member):
        return False
      if not predicate(concrete_type_member):
        _raise_structural('kind')
      return True

    if generic_type_member.is_abstract():
      label = str(generic_type_member.label)
      if not defining:
        # Deferred: checked for assignability after all bindings are known.
        non_defining_usages[label].append(concrete_type_member)
      else:
        bound_type = type_bindings.get(label)
        if bound_type is None:
          type_bindings[label] = concrete_type_member
        elif not concrete_type_member.is_equivalent_to(bound_type):
          raise MismatchedConcreteTypesError(concrete_type, generic_type,
                                             label, bound_type,
                                             concrete_type_member)
    elif _both_are(lambda t: t.is_tensor()):
      if generic_type_member != concrete_type_member:
        _raise_structural('tensor types')
    elif _both_are(lambda t: t.is_placement()):
      if generic_type_member != concrete_type_member:
        _raise_structural('placements')
    elif _both_are(lambda t: t.is_struct()):
      generic_elements = structure.to_elements(generic_type_member)
      concrete_elements = structure.to_elements(concrete_type_member)
      if len(generic_elements) != len(concrete_elements):
        _raise_structural('length')
      for generic_element, concrete_element in zip(generic_elements,
                                                   concrete_elements):
        generic_name, generic_element_type = generic_element
        concrete_name, concrete_element_type = concrete_element
        if generic_name != concrete_name:
          _raise_structural('element names')
        _check_helper(generic_element_type, concrete_element_type, defining)
    elif _both_are(lambda t: t.is_sequence()):
      _check_helper(generic_type_member.element, concrete_type_member.element,
                    defining)
    elif _both_are(lambda t: t.is_function()):
      generic_parameter = generic_type_member.parameter
      concrete_parameter = concrete_type_member.parameter
      # Either both functions take a parameter or neither does. Checking both
      # directions keeps a `None` concrete parameter from reaching the
      # recursive call below, where it would fail with an AttributeError.
      if (generic_parameter is None) != (concrete_parameter is None):
        _raise_structural('parameter')
      if generic_parameter is not None:
        # Parameter position flips the variance of nested abstract types.
        _check_helper(generic_parameter, concrete_parameter, not defining)
      _check_helper(generic_type_member.result, concrete_type_member.result,
                    defining)
    elif _both_are(lambda t: t.is_federated()):
      if generic_type_member.placement != concrete_type_member.placement:
        _raise_structural('placement')
      if generic_type_member.all_equal != concrete_type_member.all_equal:
        _raise_structural('all equal')
      _check_helper(generic_type_member.member, concrete_type_member.member,
                    defining)
    else:
      raise TypeError(f'Unexpected type kind {generic_type_member}.')

  # The top level is inside function-parameter position zero times, so it is
  # a non-defining position.
  _check_helper(generic_type, concrete_type, False)

  for label, usages in non_defining_usages.items():
    bound_type = type_bindings.get(label)
    if bound_type is None:
      if len(usages) == 1:
        # Single-use abstract types can't be wrong.
        # Note: we could also add an exception here for cases where every usage
        # is equivalent to the first usage. However, that's not currently
        # needed since the only intrinsic that doesn't have a defining use is
        # GENERIC_ZERO, which has only a single-use type parameter.
        continue
      raise MissingDefiningUsageError(generic_type, label)
    for usage in usages:
      if not usage.is_assignable_from(bound_type):
        raise UnassignableConcreteTypesError(concrete_type, generic_type,
                                             label, bound_type, usage)