def filter_subgraph_records()

in kfac/python/ops/tensormatch/graph_search.py [0:0]


def filter_subgraph_records(record_list_dict):
  """Remove any matches that correspond to strict subgraphs of other matches.

  A record is dropped when it shares at least one variable (an element of
  `ensure_sequence(record.params)`) with another record whose `tensor_set` is
  a strict superset of its own `tensor_set`. Surviving records keep their
  original relative order.

  Args:
    record_list_dict: A dict whose values are lists of records. Each record
      must expose `.params` (passed through `ensure_sequence` to obtain the
      individual variables, and usable as a dict key) and `.tensor_set`
      (supporting the strict-subset `<` operator, e.g. a set/frozenset).

  Returns:
    A plain dict mapping each surviving record's `params` to the list of
    surviving records with those params. Keys whose records were all removed
    are absent.
  """
  # Flatten the records dict so records stored under different params keys
  # can be compared against each other.
  flat_record_list = [
      record for records in record_list_dict.values() for record in records
  ]

  # Index every record by each variable it uses, so a record is only compared
  # against records that share at least one variable with it. The index is
  # built from ALL records before any filtering; since strict-subset is
  # transitive, a removed record can never be the sole witness that hides a
  # larger match.
  records_by_variable = collections.defaultdict(list)
  for record in flat_record_list:
    for variable in ensure_sequence(record.params):
      records_by_variable[variable].append(record)

  def _is_strict_subgraph(record):
    # True as soon as any record sharing a variable with `record` covers a
    # strict superset of its tensors; `any` stops scanning at the first hit.
    return any(
        record.tensor_set < other_record.tensor_set
        for variable in ensure_sequence(record.params)
        for other_record in records_by_variable[variable])

  # The predicate only reads the precomputed index, so filtering in a single
  # ordered pass is equivalent to "mark then remove" and preserves order.
  kept_records = [
      record for record in flat_record_list if not _is_strict_subgraph(record)
  ]

  # Unflatten the surviving records, keyed by their params.
  filtered_record_list_dict = collections.defaultdict(list)
  for record in kept_records:
    filtered_record_list_dict[record.params].append(record)
  return dict(filtered_record_list_dict)