in parlai/tasks/msc/agents.py [0:0]
def setup_data(self, datafile):
    """
    Build episodes from one data split and yield them turn by turn.

    ``datafile`` is used as a directory: depending on ``self.datatype`` the
    file ``train.txt``, ``valid.txt`` or ``test.txt`` inside it is read, one
    JSON record per line. Each record is expected to provide ``metadata``
    (with ``initial_data_id``), ``dialog``, ``previous_dialogs`` and
    ``personas`` / ``init_personas``.

    For every record and every requested label speaker one episode is built.
    Turns are paired as (text, label); the first turn of an episode is
    prefixed with a context string selected by ``self.previous_persona_type``
    (persona lines or the raw previous sessions), optionally followed by the
    last time gap.

    :param datafile:
        directory holding the split files.

    :return:
        generator of ``(Message, new_episode)`` tuples, where ``new_episode``
        is True only for the first turn of each episode.
    """
    print('loading: ' + datafile)
    if self.datatype.startswith('train'):
        split_file = 'train.txt'
    elif self.datatype.startswith('valid'):
        split_file = 'valid.txt'
    else:
        split_file = 'test.txt'
    path_to_open = os.path.join(datafile, split_file)
    with PathManager.open(path_to_open) as f:
        raw_data = [json.loads(line.strip()) for line in f]

    predicted_summary_dict = {}
    if self.use_predicted_summary:
        is_session_level = 'utt_' not in self.previous_persona_type
        predsum_path = get_predicted_summary_path(self.msc_dpath, is_session_level)
        logger.warning(f"use the predicted summary from {predsum_path}")
        with PathManager.open(predsum_path) as jsonfile:
            predicted_summary_dict = json.load(jsonfile)

    def _get_time_gap(time_num, time_unit, time_token=""):
        # "<num> <unit>", prefixed with the marker token when one is given.
        time_gap = str(time_num) + ' ' + time_unit
        return f'{time_token} {time_gap}' if time_token else time_gap

    def _compile_persona_dialog_input(
        dialog, personas, previous_dialogs, label_speaker_id
    ):
        # Work on copies so the raw record is never mutated; the same record
        # is compiled once per label speaker.
        new_dialog = copy.deepcopy(dialog)
        new_previous_dialogs = copy.deepcopy(previous_dialogs)
        your_persona = ""
        partner_persona = ""
        if label_speaker_id == 'self':
            your_persona = '\n'.join(f'your persona: {x}' for x in personas[1])
            partner_persona = '\n'.join(
                f"partner's persona: {x}" for x in personas[0]
            )
        elif label_speaker_id == 'their':
            your_persona = '\n'.join(f'your persona: {x}' for x in personas[0])
            partner_persona = '\n'.join(
                f"partner's persona: {x}" for x in personas[1]
            )
            # Shift every dialog by one dummy turn so that the first
            # speaker's utterances land on the label (odd) positions.
            # NOTE(review): this padding is applied for 'their' only;
            # confirm against the intended nesting.
            for prev_dialog in new_previous_dialogs:
                prev_dialog['dialog'].insert(0, {"text": DUMMY_TEXT})
                if (
                    len(prev_dialog['dialog']) % 2 == 1
                    and self.history_person_tokens is None
                ):
                    prev_dialog['dialog'].append({"text": DUMMY_TEXT})
            new_dialog.insert(0, {"text": DUMMY_TEXT})
        return your_persona, partner_persona, new_dialog, new_previous_dialogs

    # The set of speakers to produce labels for does not depend on the record.
    if self.label_speaker_id == 'both':
        label_speaker_id_range = ['their', 'self']
    else:
        label_speaker_id_range = [self.label_speaker_id]

    data = []
    for dialog_dict in raw_data:
        initial_data_id = dialog_dict['metadata']['initial_data_id']

        # The persona source is the same for both label speakers.
        if self.use_predicted_summary:
            personas_to_compile = predicted_summary_dict[str(self.session_id - 1)][
                initial_data_id
            ]
        elif self.previous_persona_type.startswith('init'):
            personas_to_compile = dialog_dict['init_personas']
        else:
            personas_to_compile = dialog_dict['personas']

        for label_speaker_id in label_speaker_id_range:
            (
                your_persona,
                partner_persona,
                new_dialog,
                new_previous_dialogs,
            ) = _compile_persona_dialog_input(
                dialog_dict['dialog'],
                personas_to_compile,
                dialog_dict['previous_dialogs'],
                label_speaker_id,
            )

            # Flatten each previous session into one newline-joined string.
            previous_sessions_msgs = []
            if self.previous_persona_type == 'raw_history':
                for prev_session in new_previous_dialogs:
                    session_lines = [x['text'] for x in prev_session['dialog']]
                    if self.history_person_tokens:
                        # Index parity is taken before dummy turns are
                        # dropped, so speaker tokens stay aligned.
                        session_lines = [
                            self.history_person_tokens[i % 2] + ' ' + text
                            for i, text in enumerate(session_lines)
                            if text != DUMMY_TEXT
                        ]
                    if self.history_time_gaps_token:
                        session_lines.append(
                            _get_time_gap(
                                prev_session['time_num'],
                                prev_session['time_unit'],
                                time_token=self.history_time_gaps_token,
                            )
                        )
                    previous_sessions_msgs.append('\n'.join(session_lines))
            if self.previous_session_delimiter is not None:
                # Follow every session with the delimiter.
                previous_sessions_msgs = [
                    val
                    for msg in previous_sessions_msgs
                    for val in (msg, self.previous_session_delimiter)
                ]
            previous_sessions_str = '\n'.join(previous_sessions_msgs)

            # Identical for every turn of the episode, so build them once.
            partner_persona_one_line = ' '.join(
                partner_persona.replace('\n', '').split("partner's persona: ")
            )
            your_persona_one_line = ' '.join(
                your_persona.replace('\n', '').split("your persona: ")
            )
            personas_str = f'{partner_persona}\n{your_persona}'
            personas_one_line = f"partner's persona: {partner_persona_one_line}\nyour persona: {your_persona_one_line}"

            episode = []
            for i in range(0, len(new_dialog) - 1, 2):
                action = {
                    'id': self.id,
                    'text': self.normalize_replies(new_dialog[i]['text']),
                    'labels': [self.normalize_replies(new_dialog[i + 1]['text'])],
                    'session_id': self.session_id,
                    'initial_data_id': initial_data_id,
                    'personas': personas_str,
                    'personas_one_line': personas_one_line,
                }
                if i == 0:
                    # Raises IndexError if the record has no previous
                    # session; records are assumed to always carry one.
                    last_session = dialog_dict['previous_dialogs'][-1]
                    action['time_num'] = last_session['time_num']
                    action['time_unit'] = last_session['time_unit']
                episode.append(action)
                if self.session_openning:
                    # Only the opening turn of the session is wanted.
                    break

            if not episode:
                # Too few turns to form a (text, label) pair: nothing to
                # yield, and there is no first turn to attach context to.
                continue

            persona_context_str = ""
            if 'self' in self.previous_persona_type:
                persona_context_str = your_persona
            elif 'their' in self.previous_persona_type:
                persona_context_str = partner_persona
            elif 'both' in self.previous_persona_type:
                if self.your_persona_first:
                    first, second = your_persona, partner_persona
                else:
                    first, second = partner_persona, your_persona
                persona_context_str = ((first + '\n') if first else "") + second
            elif self.previous_persona_type == 'raw_history':
                persona_context_str = previous_sessions_str

            if self.include_last_time_gap:
                last_session = dialog_dict['previous_dialogs'][-1]
                time_gap = _get_time_gap(
                    last_session['time_num'], last_session['time_unit']
                )
                persona_context_str = (
                    (persona_context_str + '\n') if persona_context_str else ""
                ) + f'[{time_gap}]'

            if persona_context_str:
                episode[0]['text'] = persona_context_str + '\n' + episode[0]['text']

            data.append(episode)

    for episode in data:
        for i, turn in enumerate(episode):
            # The second element flags the first turn of each episode.
            yield Message(turn), i == 0