def get_context()

in parlai/tasks/blended_skill_talk/agents.py [0:0]


    def get_context(self, *, max_wow_persona_tries: int = 1000) -> dict:
        """
        Get context information to be shown at the beginning of one conversation.

        Values in return dict:
        - context_dataset: the dataset (ConvAI2, EmpatheticDialogues, or Wizard of
            Wikipedia) used to generate the context information. Each dataset is
            chosen with probability 1/3 based on self.rng.random().
        - persona_1_strings, persona_2_strings: 2 persona strings each for the two
            speakers, chosen randomly from the ConvAI2 dataset. If context_dataset ==
            "wizard_of_wikipedia", these persona strings will be matched to the WoW
            topic returned in the "additional_context" field.
        - additional_context: provides additional bits of information to give context
            for the speakers. If context_dataset == "empathetic_dialogues", this is a
            situation from the start of an ED conversation. If context_dataset ==
            "wizard_of_wikipedia", this is a topic from the WoW dataset that matches the
            persona strings. If context_dataset == "convai2", this is None.
        - person1_seed_utterance, person2_seed_utterance: two lines of a conversation
            from the dataset specified by "context_dataset". They will be shown to the
            speakers to "seed" the conversation, and the speakers continue from where
            the lines left off.

        :param max_wow_persona_tries:
            only used when context_dataset == "wizard_of_wikipedia": the maximum
            number of random ConvAI2 persona pairs to draw while looking for one in
            which at least one persona string has a WoW topic bound to it.
        :raises RuntimeError:
            if context_dataset == "wizard_of_wikipedia" and no persona pair with an
            associated WoW topic is found within max_wow_persona_tries draws.
        """

        # Determine which dataset we will show context for
        rand_value = self.rng.random()
        if rand_value < 1 / 3:
            context_dataset = 'convai2'
        elif rand_value < 2 / 3:
            context_dataset = 'empathetic_dialogues'
        else:
            context_dataset = 'wizard_of_wikipedia'

        def sample_personas(persona_episode_idx):
            # Return 2 randomly sampled persona strings for each of the two speakers
            # of the given ConvAI2 episode (speaker 1 is sampled first).
            persona_1_strings, persona_2_strings = self._extract_personas(
                persona_episode_idx
            )
            return (
                self.rng.sample(persona_1_strings, 2),
                self.rng.sample(persona_2_strings, 2),
            )

        if context_dataset == 'convai2':
            # Personas and seed utterances come from the same ConvAI2 episode
            episode_idx = self.rng.randrange(self.convai2_teacher.num_episodes())
            persona_1_strings, persona_2_strings = sample_personas(episode_idx)

            # Don't select the first entry (index 0): per the original authors it
            # often doesn't contain a usable pair of seed utterances
            num_entries = len(self.convai2_teacher.data.data[episode_idx])
            entry_idx = self.rng.randrange(1, num_entries)
            chosen_entry = self.convai2_teacher.get(episode_idx, entry_idx=entry_idx)
            person1_seed_utterance = chosen_entry['text']
            assert len(chosen_entry['labels']) == 1
            person2_seed_utterance = chosen_entry['labels'][0]

            return {
                'context_dataset': context_dataset,
                'persona_1_strings': persona_1_strings,
                'persona_2_strings': persona_2_strings,
                'additional_context': None,
                'person1_seed_utterance': person1_seed_utterance,
                'person2_seed_utterance': person2_seed_utterance,
            }

        if context_dataset == 'empathetic_dialogues':
            # Personas come from a random ConvAI2 episode; the situation and seed
            # utterances come from an independently chosen ED episode
            persona_episode_idx = self.rng.randrange(
                self.convai2_teacher.num_episodes()
            )
            persona_1_strings, persona_2_strings = sample_personas(persona_episode_idx)

            episode_idx = self.rng.randrange(self.ed_teacher.num_episodes())
            # We'll only use the first pair of utterances
            entry = self.ed_teacher.get(episode_idx, entry_idx=0)
            situation = entry['situation']
            speaker_utterance = entry['text']
            assert len(entry['labels']) == 1
            listener_response = entry['labels'][0]

            return {
                'context_dataset': context_dataset,
                'persona_1_strings': persona_1_strings,
                'persona_2_strings': persona_2_strings,
                'additional_context': situation,
                'person1_seed_utterance': speaker_utterance,
                'person2_seed_utterance': listener_response,
            }

        # context_dataset == 'wizard_of_wikipedia'

        # Pull different personas until you get a pair for which at least one
        # sentence has a WoW topic bound to it. The search is bounded so that a
        # mapping with no usable entries raises instead of hanging forever.
        for num_tries in range(1, max_wow_persona_tries + 1):

            # Extract a random (matched) pair of personas
            persona_episode_idx = self.rng.randrange(
                self.convai2_teacher.num_episodes()
            )
            all_persona_strings = {}
            all_persona_strings[1], all_persona_strings[2] = self._extract_personas(
                persona_episode_idx
            )

            # (persona_idx, str_idx) of every persona string with at least one WoW
            # topic. .get() avoids inserting empty entries into the mapping on a miss
            # (relevant if it is a defaultdict).
            matching_persona_string_idxes = [
                (persona_idx, str_idx)
                for persona_idx, persona_strings in all_persona_strings.items()
                for str_idx, str_ in enumerate(persona_strings)
                if self.persona_strings_to_wow_topics.get(str_, ())
            ]
            if matching_persona_string_idxes:
                break
        else:
            raise RuntimeError(
                f'No pair of personas with an associated WoW topic was found after '
                f'{max_wow_persona_tries:d} tries.'
            )

        print(
            f'{num_tries:d} try/tries needed to find a pair of personas with an '
            f'associated WoW topic.'
        )

        # Pick out the WoW topic and matching persona string
        matching_persona_idx, matching_persona_string_idx = self.rng.choice(
            matching_persona_string_idxes
        )
        matching_persona_string = all_persona_strings[matching_persona_idx][
            matching_persona_string_idx
        ]
        wow_topic = self.rng.choice(
            self.persona_strings_to_wow_topics[matching_persona_string]
        )

        # Sample 2 persona strings per speaker, making sure that we keep the one
        # connected to the WoW topic. Speakers are handled in the fixed order 1, 2
        # so the sequence of RNG draws does not depend on which speaker matched.
        selected_persona_strings = {}
        for persona_idx in (1, 2):
            persona_strings = all_persona_strings[persona_idx]
            if persona_idx == matching_persona_idx:
                remaining_strings = [
                    str_ for str_ in persona_strings if str_ != matching_persona_string
                ]
                selected = [matching_persona_string, self.rng.choice(remaining_strings)]
                self.rng.shuffle(selected)
            else:
                selected = self.rng.sample(persona_strings, 2)
            selected_persona_strings[persona_idx] = selected

        # Sample WoW previous utterances, given the topic
        episode_idx = self.rng.choice(self.wow_topics_to_episode_idxes[wow_topic])
        # Select the second entry, which (unlike the first entry) will always have
        # two valid utterances and which will not usually be so far along in the
        # conversation that the new Turkers will be confused
        entry = self.wow_teacher.get(episode_idx, entry_idx=1)
        apprentice_utterance = entry['text']
        assert len(entry['labels']) == 1
        wizard_utterance = entry['labels'][0]

        return {
            'context_dataset': context_dataset,
            'persona_1_strings': selected_persona_strings[1],
            'persona_2_strings': selected_persona_strings[2],
            'additional_context': wow_topic,
            'person1_seed_utterance': apprentice_utterance,
            'person2_seed_utterance': wizard_utterance,
        }