def _get_final_crystal_lattices()

in tensorflow_lattice/python/premade_lib.py [0:0]


def _get_final_crystal_lattices(model_config, prefitting_model_config,
                                prefitting_model, feature_names):
  """Extracts the lattice ensemble structure from the prefitting model.

  This is the final stage of the Crystals algorithm:

  1. Each feature gets an importance score of
     `_LAPLACIAN_WEIGHT_IN_IMPORTANCE * laplacian` plus the sum of all pairwise
     torsions it participates in, using the values returned by
     `_get_torsions_and_laplacians`.
  2. Every feature gets one slot; the remaining
     `num_lattices * lattice_rank - num_features` slots are distributed in
     proportion to the importance scores (evenly across the remaining features
     once no positive score is left), and no feature gets more than
     `num_lattices` slots in total.
  3. Slots are placed greedily, in round-robin feature order, into the lattice
     with the highest addition score (mainly the discounted torsion the
     feature would add to that lattice).
  4. Features are swapped between pairs of lattices while a swap removes a
     repeated feature or increases the discounted within-lattice torsion, for
     at most `_MAX_CRYSTALS_SWAPS + 1` passes.

  Args:
    model_config: Model config; only its `num_lattices` and `lattice_rank`
      attributes are read here.
    prefitting_model_config: Forwarded to `_get_torsions_and_laplacians`.
    prefitting_model: Forwarded to `_get_torsions_and_laplacians`.
    feature_names: Feature names, forwarded to `_get_torsions_and_laplacians`;
      its length determines the number of features.

  Returns:
    A list with `model_config.num_lattices` entries, each a list of feature
    indices (positions in `feature_names`) assigned to that lattice.
  """
  torsions, laplacians = _get_torsions_and_laplacians(
      prefitting_model_config=prefitting_model_config,
      prefitting_model=prefitting_model,
      feature_names=feature_names)

  # Calculate features' importance_score = lambda * laplacians + torsion.
  # Used to allocate slots to useful features with more non-linear interactions.
  num_features = len(feature_names)
  importance_scores = np.array(laplacians) * _LAPLACIAN_WEIGHT_IN_IMPORTANCE
  for feature_0, feature_1 in itertools.combinations(range(num_features), 2):
    importance_scores[feature_0] += torsions[feature_0][feature_1]
    importance_scores[feature_1] += torsions[feature_0][feature_1]

  # Each feature is used at least once, and the remaining slots are distributed
  # proportional to the importance_scores.
  features_uses = [1] * num_features
  total_feature_use = model_config.num_lattices * model_config.lattice_rank
  remaining_uses = total_feature_use - num_features
  remaining_scores = np.sum(importance_scores)
  for position, feature in enumerate(np.argsort(-importance_scores)):
    if remaining_scores > 0:
      target_uses = (
          remaining_uses * importance_scores[feature] / remaining_scores)
    else:
      # No positive importance is left (e.g. all remaining features score
      # zero). The proportional rule would compute 0 / 0 = NaN and int() would
      # raise, so split the leftover slots evenly over the features not yet
      # processed; the last one receives whatever remains (subject to the cap
      # below).
      target_uses = remaining_uses / (num_features - position)
    added_uses = int(round(target_uses))
    # Each feature cannot be used more than once in a finalized lattice.
    added_uses = min(added_uses, model_config.num_lattices - 1)
    features_uses[feature] += added_uses
    remaining_uses -= added_uses
    remaining_scores -= importance_scores[feature]
  assert np.sum(features_uses) == total_feature_use

  # Add features to add list in round-robin order.
  add_list = []
  for use in range(1, max(features_uses) + 1):
    for feature_index, feature_use in enumerate(features_uses):
      if use <= feature_use:
        add_list.append(feature_index)
  assert len(add_list) == total_feature_use

  # Setup initial lattices that will be optimized by swapping later.
  lattices = [[] for _ in range(model_config.num_lattices)]
  cooccurrence_counts = [[0] * num_features for _ in range(num_features)]
  # Adding to an empty lattice roughly earns an "average" lattice score. It
  # does not depend on the loop variables, so compute it once.
  new_lattice_score = np.mean(torsions) * model_config.lattice_rank**2 / 2
  for feature_to_be_added in add_list:
    # List of pairs of (addition_score, candidate_lattice_to_add_to).
    score_candidates_pairs = []
    for candidate_lattice_to_add_to in range(model_config.num_lattices):
      candidate_lattice = lattices[candidate_lattice_to_add_to]
      # addition_score indicates the priority of an addition.
      if len(candidate_lattice) >= model_config.lattice_rank:
        # Going out of bounds on the lattice.
        addition_score = -2.0
      elif feature_to_be_added in candidate_lattice:
        # Repeats (fixed later by swapping).
        addition_score = -1.0
      elif not candidate_lattice:
        addition_score = new_lattice_score
      else:
        # All other cases: change in total discounted torsion after addition.
        addition_score = 0.0
        for other_feature in candidate_lattice:
          addition_score += (
              torsions[feature_to_be_added][other_feature] *
              _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE
              **(cooccurrence_counts[feature_to_be_added][other_feature]))

      score_candidates_pairs.append(
          (addition_score, candidate_lattice_to_add_to))

    # Use the highest scoring addition.
    score_candidates_pairs.sort(reverse=True)
    best_candidate_lattice_to_add_to = score_candidates_pairs[0][1]
    for other_feature in lattices[best_candidate_lattice_to_add_to]:
      cooccurrence_counts[feature_to_be_added][other_feature] += 1
      cooccurrence_counts[other_feature][feature_to_be_added] += 1
    lattices[best_candidate_lattice_to_add_to].append(feature_to_be_added)

  # Apply swapping operations to increase within-lattice torsion.
  changed = True
  iteration = 0
  while changed:
    if iteration > _MAX_CRYSTALS_SWAPS:
      logging.info('Crystals algorithm did not fully converge.')
      break
    changed = False
    iteration += 1
    for lattice_0, lattice_1 in itertools.combinations(lattices, 2):
      # For every pair of lattices: lattice_0, lattice_1
      for index_0, index_1 in itertools.product(
          range(len(lattice_0)), range(len(lattice_1))):
        # Consider swapping lattice_0[index_0] with lattice_1[index_1]
        rest_lattice_0 = list(lattice_0)
        rest_lattice_1 = list(lattice_1)
        feature_0 = rest_lattice_0.pop(index_0)
        feature_1 = rest_lattice_1.pop(index_1)
        # Skip swaps that would move a feature into a lattice that already
        # contains it (this also skips feature_0 == feature_1). Checking this
        # first avoids computing the torsion change for swaps that can never
        # be applied.
        if feature_0 in lattice_1 or feature_1 in lattice_0:
          continue

        # Calculate the change in the overall discounted sum of torsion terms.
        added_cooccurrence = set(
            [tuple(sorted((feature_1, other))) for other in rest_lattice_0] +
            [tuple(sorted((feature_0, other))) for other in rest_lattice_1])
        removed_cooccurrence = set(
            [tuple(sorted((feature_0, other))) for other in rest_lattice_0] +
            [tuple(sorted((feature_1, other))) for other in rest_lattice_1])
        wash = added_cooccurrence.intersection(removed_cooccurrence)
        added_cooccurrence = added_cooccurrence.difference(wash)
        removed_cooccurrence = removed_cooccurrence.difference(wash)
        swap_diff_torsion = (
            sum(torsions[i][j] * _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE**
                cooccurrence_counts[i][j] for (i, j) in added_cooccurrence) -
            sum(torsions[i][j] * _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE**
                (cooccurrence_counts[i][j] - 1)
                for (i, j) in removed_cooccurrence))

        # Swap if it removes a repeated feature or if the score change is
        # positive.
        if (lattice_0.count(feature_0) > 1 or lattice_1.count(feature_1) > 1 or
            swap_diff_torsion > 0):
          for (i, j) in added_cooccurrence:
            cooccurrence_counts[i][j] += 1
            cooccurrence_counts[j][i] += 1
          for (i, j) in removed_cooccurrence:
            cooccurrence_counts[i][j] -= 1
            cooccurrence_counts[j][i] -= 1
          lattice_0[index_0], lattice_1[index_1] = (lattice_1[index_1],
                                                    lattice_0[index_0])
          changed = True
  # Return the extracted lattice structure.
  return lattices