in tensorflow_lattice/python/premade_lib.py [0:0]
def _get_final_crystal_lattices(model_config, prefitting_model_config,
                                prefitting_model, feature_names):
  """Extracts the lattice ensemble structure from the prefitting model.

  Steps performed:

  1. Each feature gets an importance score: its laplacian times
     `_LAPLACIAN_WEIGHT_IN_IMPORTANCE` plus the sum of its pairwise torsions
     with every other feature.
  2. Each feature is used at least once. The remaining
     `num_lattices * lattice_rank - num_features` slots are distributed in
     proportion to the importance scores, visiting features from most to
     least important, with no feature used more than `num_lattices` times.
  3. Feature uses are added in round-robin order, each to the lattice with the
     highest addition score. The score favours lattices whose members have
     high (discounted) torsion with the added feature.
  4. Features in different lattices are swapped when a swap creates no new
     repeat and either removes a repeated feature or increases the total
     discounted torsion. This repeats until no swap applies or the
     `_MAX_CRYSTALS_SWAPS` pass budget is exhausted.

  Args:
    model_config: Model config; its `num_lattices` and `lattice_rank` fields
      are read.
    prefitting_model_config: Config of the prefitting model, forwarded to
      `_get_torsions_and_laplacians`.
    prefitting_model: The prefitting model, forwarded to
      `_get_torsions_and_laplacians`.
    feature_names: Names of the features to distribute over the lattices.

  Returns:
    A list of `model_config.num_lattices` lists. Each holds the indices (into
    `feature_names`) of the features assigned to that lattice.
  """
  torsions, laplacians = _get_torsions_and_laplacians(
      prefitting_model_config=prefitting_model_config,
      prefitting_model=prefitting_model,
      feature_names=feature_names)

  # Calculate features' importance_score = lambda * laplacians + torsion.
  # Used to allocate slots to useful features with more non-linear interactions.
  num_features = len(feature_names)
  importance_scores = np.array(laplacians) * _LAPLACIAN_WEIGHT_IN_IMPORTANCE
  for feature_0, feature_1 in itertools.combinations(range(num_features), 2):
    importance_scores[feature_0] += torsions[feature_0][feature_1]
    importance_scores[feature_1] += torsions[feature_0][feature_1]

  # Each feature is used at least once, and the remaining slots are distributed
  # proportional to the importance_scores, from most to least important.
  features_uses = [1] * num_features
  total_feature_use = model_config.num_lattices * model_config.lattice_rank
  remaining_uses = total_feature_use - num_features
  feature_order = np.argsort(-importance_scores)
  # unprocessed_scores[p] is the summed importance of the features at sorted
  # positions p, p + 1, ... . Accumulating from the tail (instead of repeatedly
  # subtracting from the grand total) keeps it exactly 0 for a tail of
  # zero-importance features.
  unprocessed_scores = np.cumsum(importance_scores[feature_order][::-1])[::-1]
  for position, feature in enumerate(feature_order):
    if unprocessed_scores[position] > 0:
      share = (remaining_uses * importance_scores[feature] /
               unprocessed_scores[position])
    else:
      # The features left carry no positive importance, so a proportional
      # split is undefined; dividing by this total would yield NaN and make
      # int(round(...)) raise. Split the leftover slots evenly among them.
      share = remaining_uses / (num_features - position)
    # Each feature cannot be used more than once in a finalized lattice.
    added_uses = min(int(round(share)), model_config.num_lattices - 1)
    features_uses[feature] += added_uses
    remaining_uses -= added_uses
  assert np.sum(features_uses) == total_feature_use

  # Add features to the add list in round-robin order: every feature's first
  # use, then every feature's second use, and so on.
  add_list = []
  for use in range(1, max(features_uses) + 1):
    for feature_index, feature_use in enumerate(features_uses):
      if use <= feature_use:
        add_list.append(feature_index)
  assert len(add_list) == total_feature_use

  # Set up initial lattices that will be optimized by swapping later.
  lattices = [[] for _ in range(model_config.num_lattices)]
  # cooccurrence_counts[i][j] counts how often features i and j currently share
  # a lattice; the torsion of a repeated pair is discounted by
  # _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE per earlier co-occurrence.
  cooccurrence_counts = [[0] * num_features for _ in range(num_features)]
  # Adding a feature to an empty lattice roughly has an "average" lattice
  # score. It does not depend on the loop below, so compute it once.
  empty_lattice_score = np.mean(torsions) * model_config.lattice_rank**2 / 2
  for feature_to_be_added in add_list:
    # List of pairs of (addition_score, candidate_lattice_to_add_to).
    score_candidates_pairs = []
    for candidate_lattice_to_add_to in range(model_config.num_lattices):
      # addition_score indicates the priority of an addition.
      if len(
          lattices[candidate_lattice_to_add_to]) >= model_config.lattice_rank:
        # Going out of bound on the lattice.
        addition_score = -2.0
      elif feature_to_be_added in lattices[candidate_lattice_to_add_to]:
        # Repeats (fixed later by swapping).
        addition_score = -1.0
      elif not lattices[candidate_lattice_to_add_to]:
        addition_score = empty_lattice_score
      else:
        # All other cases: change in total discounted torsion after addition.
        addition_score = 0.0
        for other_feature in lattices[candidate_lattice_to_add_to]:
          addition_score += (
              torsions[feature_to_be_added][other_feature] *
              _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE
              **(cooccurrence_counts[feature_to_be_added][other_feature]))
      score_candidates_pairs.append(
          (addition_score, candidate_lattice_to_add_to))

    # Use the highest scoring addition. Pairs sort by (score, index), so ties
    # go to the higher lattice index.
    score_candidates_pairs.sort(reverse=True)
    best_candidate_lattice_to_add_to = score_candidates_pairs[0][1]
    for other_feature in lattices[best_candidate_lattice_to_add_to]:
      cooccurrence_counts[feature_to_be_added][other_feature] += 1
      cooccurrence_counts[other_feature][feature_to_be_added] += 1
    lattices[best_candidate_lattice_to_add_to].append(feature_to_be_added)

  # Apply swapping operations to increase within-lattice torsion.
  changed = True
  iteration = 0
  while changed:
    if iteration > _MAX_CRYSTALS_SWAPS:
      logging.info('Crystals algorithm did not fully converge.')
      break
    changed = False
    iteration += 1
    for lattice_0, lattice_1 in itertools.combinations(lattices, 2):
      # For every pair of lattices: lattice_0, lattice_1
      for index_0, index_1 in itertools.product(
          range(len(lattice_0)), range(len(lattice_1))):
        # Consider swapping lattice_0[index_0] with lattice_1[index_1].
        rest_lattice_0 = list(lattice_0)
        rest_lattice_1 = list(lattice_1)
        feature_0 = rest_lattice_0.pop(index_0)
        feature_1 = rest_lattice_1.pop(index_1)
        if feature_0 == feature_1:
          continue

        # Calculate the change in the overall discounted sum of torsion terms.
        added_cooccurrence = set(
            [tuple(sorted((feature_1, other))) for other in rest_lattice_0] +
            [tuple(sorted((feature_0, other))) for other in rest_lattice_1])
        removed_cooccurrence = set(
            [tuple(sorted((feature_0, other))) for other in rest_lattice_0] +
            [tuple(sorted((feature_1, other))) for other in rest_lattice_1])
        # Pairs that are both added and removed cancel out.
        wash = added_cooccurrence.intersection(removed_cooccurrence)
        added_cooccurrence = added_cooccurrence.difference(wash)
        removed_cooccurrence = removed_cooccurrence.difference(wash)
        # A newly added pair is discounted by its current count; a removed pair
        # loses the term it contributed, discounted by (count - 1).
        swap_diff_torsion = (
            sum(torsions[i][j] * _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE**
                cooccurrence_counts[i][j] for (i, j) in added_cooccurrence) -
            sum(torsions[i][j] * _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE**
                (cooccurrence_counts[i][j] - 1)
                for (i, j) in removed_cooccurrence))

        # Swap only if it creates no new repeat, and either removes an
        # existing repeat or increases the discounted torsion.
        if (feature_0 not in lattice_1 and feature_1 not in lattice_0 and
            (lattice_0.count(feature_0) > 1 or lattice_1.count(feature_1) > 1 or
             swap_diff_torsion > 0)):
          for (i, j) in added_cooccurrence:
            cooccurrence_counts[i][j] += 1
            cooccurrence_counts[j][i] += 1
          for (i, j) in removed_cooccurrence:
            cooccurrence_counts[i][j] -= 1
            cooccurrence_counts[j][i] -= 1
          lattice_0[index_0], lattice_1[index_1] = (lattice_1[index_1],
                                                    lattice_0[index_0])
          changed = True

  # Return the extracted lattice structure.
  return lattices