in parlai/tasks/blended_skill_talk/agents.py [0:0]
def get_context(self) -> dict:
    """
    Build the context shown to both speakers when a conversation starts.

    A single draw from self.rng.random() picks the source dataset: below 1/3 is
    ConvAI2, below 2/3 is EmpatheticDialogues, anything else is Wizard of Wikipedia.

    Keys of the returned dict:
    - context_dataset: 'convai2', 'empathetic_dialogues', or 'wizard_of_wikipedia',
      naming the dataset the seed utterances come from.
    - persona_1_strings, persona_2_strings: two persona strings per speaker, taken
      from one randomly chosen ConvAI2 episode. For 'wizard_of_wikipedia', the
      persona string bound to the returned topic is always among them.
    - additional_context: None for 'convai2'; the 'situation' field of the chosen
      entry for 'empathetic_dialogues'; the chosen WoW topic for
      'wizard_of_wikipedia'.
    - person1_seed_utterance, person2_seed_utterance: the 'text' and the single
      label of one entry from the chosen dataset. They "seed" the conversation, and
      the speakers continue from where these lines leave off.
    """
    rng = self.rng

    def draw_personas():
        """Pick a random ConvAI2 episode and sample two persona strings per speaker."""
        persona_episode = rng.randrange(self.convai2_teacher.num_episodes())
        first, second = self._extract_personas(persona_episode)
        return persona_episode, rng.sample(first, 2), rng.sample(second, 2)

    def pin_and_pair(strings, pinned):
        """Return `pinned` plus one other string from `strings`, in shuffled order."""
        others = [candidate for candidate in strings if candidate != pinned]
        pair = [pinned, rng.sample(others, k=1)[0]]
        rng.shuffle(pair)
        return pair

    draw = rng.random()
    if draw < 1 / 3:
        source = 'convai2'
        extra = None
        episode_idx, persona_1, persona_2 = draw_personas()
        # Never seed from the first entry, which often doesn't include an
        # apprentice utterance
        turn_count = len(self.convai2_teacher.data.data[episode_idx])
        entry_idx = rng.randrange(1, turn_count)
        entry = self.convai2_teacher.get(episode_idx, entry_idx=entry_idx)
    elif draw < 2 / 3:
        source = 'empathetic_dialogues'
        # Personas still come from ConvAI2; only the seed turn comes from ED
        _, persona_1, persona_2 = draw_personas()
        episode_idx = rng.randrange(self.ed_teacher.num_episodes())
        # Only the first pair of utterances is used
        entry = self.ed_teacher.get(episode_idx, entry_idx=0)
        extra = entry['situation']
    else:
        source = 'wizard_of_wikipedia'
        topics_by_string = self.persona_strings_to_wow_topics
        # Keep drawing (matched) persona pairs until at least one of their strings
        # has a WoW topic bound to it
        attempts = 0
        hits = []
        while not hits:
            attempts += 1
            persona_episode = rng.randrange(self.convai2_teacher.num_episodes())
            first, second = self._extract_personas(persona_episode)
            personas = {1: first, 2: second}
            # (speaker number, index within that speaker's strings) of each match
            hits = [
                (speaker, position)
                for speaker, strings in personas.items()
                for position, text in enumerate(strings)
                if len(topics_by_string[text]) > 0
            ]
        print(
            f'{attempts:d} try/tries needed to find a pair of personas with an '
            f'associated WoW topic.'
        )
        # Choose which matching persona string, and which of its topics, to use
        speaker, position = rng.sample(hits, k=1)[0]
        pinned = personas[speaker][position]
        extra = rng.sample(topics_by_string[pinned], k=1)[0]
        # The speaker owning the matching string must keep it; the other speaker
        # gets two freely sampled strings
        if speaker == 1:
            persona_1 = pin_and_pair(personas[1], pinned)
            persona_2 = rng.sample(personas[2], 2)
        else:
            persona_1 = rng.sample(personas[1], 2)
            persona_2 = pin_and_pair(personas[2], pinned)
        # Sample a WoW episode for the topic and take its second entry, which
        # (unlike the first entry) will always have two valid utterances and which
        # will not usually be so far along in the conversation that the new
        # Turkers will be confused
        episode_idx = rng.sample(self.wow_topics_to_episode_idxes[extra], k=1)[0]
        entry = self.wow_teacher.get(episode_idx, entry_idx=1)

    # Every branch ends with one entry whose text/label pair seeds the conversation
    seed_1 = entry['text']
    assert len(entry['labels']) == 1
    seed_2 = entry['labels'][0]
    return {
        'context_dataset': source,
        'persona_1_strings': persona_1,
        'persona_2_strings': persona_2,
        'additional_context': extra,
        'person1_seed_utterance': seed_1,
        'person2_seed_utterance': seed_2,
    }