in kfac/python/ops/tensormatch/graph_search.py [0:0]
def filter_subgraph_records(record_list_dict):
    """Remove any matches that correspond to strict subgraphs of other matches.

    A record is dropped when some other record that shares at least one of its
    variables has a `tensor_set` that compares strictly greater (`<`) than the
    record's own `tensor_set`.

    Args:
      record_list_dict: Mapping whose values are iterables of match records.
        Each record must be hashable and expose `params` (hashable, passed
        through `ensure_sequence` to enumerate its variables) and `tensor_set`.

    Returns:
      A plain `dict` mapping each surviving record's `params` to the list of
      surviving records with those params, in their original relative order.
      Params whose records were all removed do not appear as keys; empty input
      yields an empty dict.
    """
    # Flatten the records dict to compare records with different parameters.
    flat_record_list = [
        record for records in record_list_dict.values() for record in records
    ]

    # Index records by variable so that only records sharing at least one
    # variable are ever compared against each other.
    records_by_variable = collections.defaultdict(list)
    for record in flat_record_list:
        for variable in ensure_sequence(record.params):
            records_by_variable[variable].append(record)

    # We perform two passes, first marking records for deletion by adding them
    # to a set and then removing all marked records, in order to avoid
    # traversing flat_record_list on every removal while still maintaining
    # record order.
    records_to_remove = set()
    for record in flat_record_list:
        # any() stops at the first strict superset; one is enough to drop the
        # record. A record never removes itself since `s < s` is False.
        if any(
            record.tensor_set < other_record.tensor_set
            for variable in ensure_sequence(record.params)
            for other_record in records_by_variable[variable]
        ):
            records_to_remove.add(record)

    # Unflatten the surviving records, keyed by their own params.
    filtered_records = collections.defaultdict(list)
    for record in flat_record_list:
        if record not in records_to_remove:
            filtered_records[record.params].append(record)
    return dict(filtered_records)