def setup_data()

in parlai/tasks/msc/agents.py [0:0]


    def setup_data(self, datafile):
        """
        Load one split of the MSC data and yield it as ParlAI messages.

        :param datafile:
            directory containing ``train.txt``, ``valid.txt`` and ``test.txt``;
            the file is chosen from ``self.datatype`` (anything that does not
            start with ``train``/``valid`` reads ``test.txt``). Each non-blank
            line is one JSON record read for ``metadata.initial_data_id``,
            ``dialog``, ``previous_dialogs``, ``personas`` and
            ``init_personas``.

        :yield:
            ``(Message, new_episode)`` pairs; ``new_episode`` is True only for
            the first turn of each episode. One episode is produced per record
            and per label speaker (``['their', 'self']`` when
            ``self.label_speaker_id == 'both'``). Records too short to form a
            single (text, label) pair produce no episode.
        """
        print('loading: ' + datafile)
        if self.datatype.startswith('train'):
            split_name = 'train.txt'
        elif self.datatype.startswith('valid'):
            split_name = 'valid.txt'
        else:
            split_name = 'test.txt'
        path_to_open = os.path.join(datafile, split_name)

        with PathManager.open(path_to_open) as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.
            raw_data = [json.loads(line.strip()) for line in f if line.strip()]

        predicted_summary_dict = {}
        if self.use_predicted_summary:
            is_session_level = 'utt_' not in self.previous_persona_type
            predsum_path = get_predicted_summary_path(self.msc_dpath, is_session_level)
            logger.warning(f"use the predicted summary from {predsum_path}")
            with PathManager.open(predsum_path) as jsonfile:
                predicted_summary_dict = json.load(jsonfile)

        # Speaker setting(s) to build episodes for; identical for every record.
        if self.label_speaker_id == 'both':
            label_speaker_id_range = ['their', 'self']
        else:
            label_speaker_id_range = [self.label_speaker_id]

        def _get_time_gap(time_num, time_unit, time_token=""):
            # "<num> <unit>", prefixed by "<time_token> " when a token is given.
            time_gap = str(time_num) + ' ' + time_unit
            return f'{time_token} {time_gap}' if len(time_token) > 0 else time_gap

        def _compile_persona_dialog_input(
            dialog, personas, previous_dialogs, label_speaker_id
        ):
            """
            Build persona strings and speaker-aligned copies of the dialogues.

            Returns ``(your_persona, partner_persona, new_dialog,
            new_previous_dialogs)``. Inputs are deep-copied, so the raw record
            is never mutated. For ``'self'``, "your persona" lines come from
            ``personas[1]`` and "partner's persona" lines from ``personas[0]``.
            For ``'their'`` the roles are swapped and a ``DUMMY_TEXT`` turn is
            prepended to the current dialog, so its first real utterance ends
            up as a label. Any other value gives empty persona strings and
            unmodified copies.
            """
            new_dialog = copy.deepcopy(dialog)
            new_previous_dialogs = copy.deepcopy(previous_dialogs)
            your_persona = ""
            partner_persona = ""
            if label_speaker_id == 'self':
                your_persona = '\n'.join([f'your persona: {x}' for x in personas[1]])
                partner_persona = '\n'.join(
                    [f"partner's persona: {x}" for x in personas[0]]
                )
            elif label_speaker_id == 'their':
                your_persona = '\n'.join([f'your persona: {x}' for x in personas[0]])
                partner_persona = '\n'.join(
                    [f"partner's persona: {x}" for x in personas[1]]
                )
                for prev_dialog in new_previous_dialogs:
                    # Shift earlier sessions by one turn too. Without speaker
                    # tokens, pad to an even length — presumably so turn
                    # parity keeps identifying the speaker in raw history.
                    prev_dialog['dialog'].insert(0, {"text": DUMMY_TEXT})
                    if len(prev_dialog['dialog']) % 2 == 1 and (
                        self.history_person_tokens is None
                    ):
                        prev_dialog['dialog'].append({"text": DUMMY_TEXT})
                new_dialog.insert(0, {"text": DUMMY_TEXT})

            return your_persona, partner_persona, new_dialog, new_previous_dialogs

        def _format_previous_sessions(previous_dialogs):
            """
            Render earlier sessions as text for the ``'raw_history'`` context.

            Each session becomes its turns joined by newlines. Speaker tokens
            are added when ``self.history_person_tokens`` is set, and in that
            case ``DUMMY_TEXT`` turns are dropped. A time gap is appended as
            the last line when ``self.history_time_gaps_token`` is set. When
            ``self.previous_session_delimiter`` is not None, it follows every
            session.
            """
            sessions = []
            for prev_dialog in previous_dialogs:
                turns = [x['text'] for x in prev_dialog['dialog']]
                if self.history_person_tokens:
                    # Enumerate before filtering so the speaker token still
                    # follows each turn's position in the padded list.
                    turns = [
                        self.history_person_tokens[i % 2] + ' ' + text
                        for i, text in enumerate(turns)
                        if text != DUMMY_TEXT
                    ]
                if self.history_time_gaps_token:
                    turns = turns + [
                        _get_time_gap(
                            prev_dialog['time_num'],
                            prev_dialog['time_unit'],
                            time_token=self.history_time_gaps_token,
                        )
                    ]
                sessions.append('\n'.join(turns))
            if self.previous_session_delimiter is not None:
                sessions = [
                    part
                    for session in sessions
                    for part in (session, self.previous_session_delimiter)
                ]
            return '\n'.join(sessions)

        data = []
        for dialog_dict in raw_data:
            initial_data_id = dialog_dict['metadata']['initial_data_id']
            previous_dialogs = dialog_dict['previous_dialogs']
            # The latest earlier session supplies the time gap to this one;
            # records without an earlier session simply get no time gap.
            last_session = previous_dialogs[-1] if previous_dialogs else None

            # The persona source does not depend on the label speaker.
            if self.use_predicted_summary:
                personas_to_compile = predicted_summary_dict[
                    str(self.session_id - 1)
                ][initial_data_id]
            elif self.previous_persona_type.startswith('init'):
                personas_to_compile = dialog_dict['init_personas']
            else:
                personas_to_compile = dialog_dict['personas']

            for label_speaker_id in label_speaker_id_range:
                (
                    your_persona,
                    partner_persona,
                    new_dialog,
                    new_previous_dialogs,
                ) = _compile_persona_dialog_input(
                    dialog_dict['dialog'],
                    personas_to_compile,
                    previous_dialogs,
                    label_speaker_id,
                )
                previous_sessions_msgs = (
                    _format_previous_sessions(new_previous_dialogs)
                    if self.previous_persona_type == 'raw_history'
                    else ''
                )

                # Single-line persona summary, identical for every turn.
                # NOTE(review): split() leaves a leading '' element, so the
                # result has a double space after each prefix; kept as-is to
                # avoid changing the emitted text.
                partner_persona_one_line = ' '.join(
                    partner_persona.replace('\n', '').split("partner's persona: ")
                )
                your_persona_one_line = ' '.join(
                    your_persona.replace('\n', '').split("your persona: ")
                )
                personas_one_line = (
                    f"partner's persona: {partner_persona_one_line}\n"
                    f"your persona: {your_persona_one_line}"
                )

                episode = []
                # Pair turns (0, 1), (2, 3), ...; a trailing unpaired turn is dropped.
                for i in range(0, len(new_dialog) - 1, 2):
                    action = {
                        'id': self.id,
                        'text': self.normalize_replies(new_dialog[i]['text']),
                        'labels': [self.normalize_replies(new_dialog[i + 1]['text'])],
                        'session_id': self.session_id,
                        'initial_data_id': initial_data_id,
                        'personas': f'{partner_persona}\n{your_persona}',
                        'personas_one_line': personas_one_line,
                    }
                    if i == 0 and last_session is not None:
                        action['time_num'] = last_session['time_num']
                        action['time_unit'] = last_session['time_unit']
                    episode.append(action)
                    if self.session_openning:
                        # Only the opening exchange is kept.
                        break

                if not episode:
                    # Too short for a single (text, label) pair: nothing to
                    # yield, and prefixing context onto episode[0] would fail.
                    continue

                if 'self' in self.previous_persona_type:
                    persona_context_str = your_persona
                elif 'their' in self.previous_persona_type:
                    persona_context_str = partner_persona
                elif 'both' in self.previous_persona_type:
                    if self.your_persona_first:
                        first, second = your_persona, partner_persona
                    else:
                        first, second = partner_persona, your_persona
                    persona_context_str = ((first + '\n') if first else "") + second
                elif self.previous_persona_type == 'raw_history':
                    persona_context_str = previous_sessions_msgs
                else:
                    persona_context_str = ""

                if self.include_last_time_gap and last_session is not None:
                    time_gap = _get_time_gap(
                        last_session['time_num'], last_session['time_unit']
                    )
                    persona_context_str = (
                        (persona_context_str + '\n') if persona_context_str else ""
                    ) + f'[{time_gap}]'

                if persona_context_str:
                    episode[0]['text'] = persona_context_str + '\n' + episode[0]['text']

                data.append(episode)

        for episode in data:
            for i, turn in enumerate(episode):
                yield Message(turn), i == 0