in src/util/shorten_topic_length.py [0:0]
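# Assumed module-level context, not shown in this excerpt: numpy and
# scikit-learn supply the array and similarity utilities used below, and
# FIXED_TOPICS / PRESERVE_WORDS are constants defined elsewhere in the file.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity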
def shorten_single_topic(self, row):
    initial = row.output
    # Topics on the fixed whitelist are returned untouched.
    if initial in FIXED_TOPICS:
        return initial
    try:
        words = initial.split()
        # A single-word topic cannot be shortened any further.
        if len(words) == 1:
            return initial
        initial_embedding = self.hint_embedder(initial)[0][0]
        # Candidate pool; index 0 is always the unmodified topic.
        alt_embeddings = [initial_embedding]
        alt_phrases = [initial]
        document_words = set(
            self.get_document_for_row(row).translate(self.translator).lower().split()
        )
        # Build candidates with a single word removed.
        for missing_word_index in range(len(words)):
            word_getting_removed = words[missing_word_index]
            # Never drop a protected word.
            if word_getting_removed.lower() in PRESERVE_WORDS:
                continue
            # Never drop a word the spellchecker does not recognize;
            # rare or domain-specific terms likely carry the topic's meaning.
            if len(self.spell.unknown([word_getting_removed])) == 1:
                continue
            shortened_phrase = " ".join(words[:missing_word_index] + words[missing_word_index + 1:])
            embed = self.hint_embedder(shortened_phrase)[0][0]
            alt_phrases.append(shortened_phrase)
            alt_embeddings.append(embed)
        if len(words) > 2:
            # Build candidates with two adjacent words removed.
            for missing_word_index in range(len(words) - 1):
                words_to_remove = words[missing_word_index:missing_word_index + 2]
                # Skip the pair if either word is protected.
                if any(w.lower() in PRESERVE_WORDS for w in words_to_remove):
                    continue
                # Skip the pair if either word is unknown to the spellchecker.
                if len(self.spell.unknown(words_to_remove)) > 0:
                    continue
                shortened_phrase = " ".join(words[:missing_word_index] + words[missing_word_index + 2:])
                embed = self.hint_embedder(shortened_phrase)[0][0]
                alt_phrases.append(shortened_phrase)
                alt_embeddings.append(embed)
        # Score the original topic and every shortened candidate by cosine
        # similarity to the document embedding.
        document_embedding = self.get_embedding_for_document(row).reshape(1, -1)
        similarity = cosine_similarity(document_embedding, np.array(alt_embeddings)).squeeze()
        # Handicap the unmodified topic (index 0) so a shortened candidate
        # wins unless it scores worse by more than boost_threshold.
        similarity[0] -= self.boost_threshold
        best_phrase = alt_phrases[int(np.argmax(similarity))]
        best_phrase = self.pluralize_single_word(best_phrase, document_words)
        print(f"{initial} -> {best_phrase}")
        return best_phrase
    except Exception:
        # On any failure (e.g. an embedding or spellchecker error), fall
        # back to the original topic unchanged.
        return initial
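# A minimal usage sketch under stated assumptions: `TopicShortener` is a
# hypothetical name for the class this method belongs to, and rows come
# from a pandas DataFrame with an `output` column (matching the
# `row.output` access above):
#
#     import pandas as pd
#
#     shortener = TopicShortener(...)
#     df = pd.DataFrame({"output": ["the quarterly financial report summary"]})
#     df["short_topic"] = df.apply(shortener.shorten_single_topic, axis=1)
#
# Whitelisted and single-word topics pass through unchanged; anything else
# is replaced by the shortened candidate closest to the document embedding,
# subject to the boost_threshold handicap on the original phrase.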