in curiosity/baseline_reader.py [0:0]
def text_to_instance(self, dialog: Dict, ignore_fact: bool = False) -> Instance:
    """Convert one raw dialog dict into an AllenNLP ``Instance``.

    Parameters
    ----------
    dialog : Dict
        One dialog record; must contain non-empty ``'messages'`` plus
        ``'known_entities'``, ``'focus_entity'`` and ``'dialog_id'``.
    ignore_fact : bool
        When True, every assistant fact's text is replaced with the
        placeholder ``'dummy fact'`` (an ablation switch); fact *labels*
        are still computed from ``f['used']``.

    Returns
    -------
    Instance
        Per-message list fields (texts, facts, fact labels, likes,
        dialog acts + mask, senders), entity fields, and a metadata
        field mirroring the fact labels for evaluation.

    Raises
    ------
    ValueError
        If the dialog has no messages.
    """
    if not dialog['messages']:
        raise ValueError('There are no dialog messages')

    known_entities_field = self._known_entities_field(dialog['known_entities'])
    focus_entity = dialog['focus_entity']
    focus_entity_field = TextField(
        [Token(text='ENTITY/' + focus_entity.replace(' ', '_'), idx=0)],
        self._mention_indexers
    )

    msg_texts = []
    msg_senders = []
    msg_likes = []
    msg_acts = []
    msg_act_mask = []
    msg_facts = []
    msg_fact_labels = []
    metadata_fact_labels = []

    prev_msg = ''
    for msg in dialog['messages']:
        cur_message, prev_msg = self._accumulate_message(prev_msg, msg['message'])
        tokenized_msg = self._tokenizer.tokenize(cur_message)
        msg_texts.append(TextField(tokenized_msg, self._token_indexers))
        msg_senders.append(0 if msg['sender'] == USER else 1)
        msg_likes.append(LabelField(
            'liked' if msg['liked'] else 'not_liked',
            label_namespace='like_labels'
        ))

        acts_field, act_mask = self._dialog_acts_field(msg['dialog_acts'])
        msg_acts.append(acts_field)
        msg_act_mask.append(act_mask)

        facts_field, fact_label_field, meta_fact_labels = self._fact_fields(
            msg, ignore_fact)
        msg_facts.append(facts_field)
        msg_fact_labels.append(fact_label_field)
        metadata_fact_labels.append(meta_fact_labels)

    return Instance({
        'messages': ListField(msg_texts),
        'facts': ListField(msg_facts),
        'fact_labels': ListField(msg_fact_labels),
        'likes': ListField(msg_likes),
        'dialog_acts': ListField(msg_acts),
        'dialog_acts_mask': to_long_field(msg_act_mask),
        'senders': to_long_field(msg_senders),
        'focus_entity': focus_entity_field,
        'known_entities': known_entities_field,
        'metadata': MetadataField({
            'dialog_id': dialog['dialog_id'],
            'n_message': len(msg_texts),
            'fact_labels': metadata_fact_labels,
            'known_entities': dialog['known_entities'],
            'focus_entity': dialog['focus_entity']
        })
    })

def _known_entities_field(self, known_entities) -> TextField:
    """Build the known-entities TextField, with a sentinel token when empty."""
    tokens = [
        Token(text='ENTITY/' + t.replace(' ', '_'), idx=idx)
        for idx, t in enumerate(known_entities)
    ]
    if not tokens:
        tokens.append(Token(text='@@YOUKNOWNOTHING@@', idx=0))
    return TextField(tokens, self._mention_indexers)

@staticmethod
def _accumulate_message(prev_msg: str, message: str):
    """Return ``(text_to_tokenize, new_history)`` for one message.

    When MESSAGE_CUMULATIVE is set, each message is prefixed with the
    running history; the history is trimmed once it exceeds
    DIALOG_MAX_LENGTH. Otherwise the message stands alone and the
    history is left untouched.
    """
    if not MESSAGE_CUMULATIVE:
        return message, prev_msg
    if prev_msg == '':
        return message, message
    if len(prev_msg) > DIALOG_MAX_LENGTH:
        # NOTE(review): this trims by *characters* then drops the first
        # (possibly partial) whitespace token, while facts below are
        # truncated by *tokens* with the same constant — presumably
        # intentional, but worth confirming.
        prev_msg = ' '.join(prev_msg[-DIALOG_MAX_LENGTH:].split(' ')[1:])
    cur_message = prev_msg + ' ' + message
    return cur_message, cur_message

@staticmethod
def _dialog_acts_field(dialog_acts):
    """Return ``(acts_field, mask)``; mask is 0 when acts are missing."""
    if dialog_acts is None:
        # No annotation: emit a sentinel label and mask the message out.
        return (
            MultiLabelFieldListCompat(
                ['@@NODA@@'], label_namespace=DIALOG_ACT_LABELS),
            0,
        )
    return (
        MultiLabelFieldListCompat(
            dialog_acts, label_namespace=DIALOG_ACT_LABELS),
        1,
    )

def _fact_fields(self, msg: Dict, ignore_fact: bool):
    """Build the fact fields for one message.

    Returns ``(facts_field, fact_label_field, metadata_fact_labels)``:
    a ListField of tokenized fact texts, a multi-hot float array over
    those facts marking the ones the assistant used, and the raw used
    indices for metadata (``[-1]`` when none were used).
    """
    curr_facts_text = []
    curr_facts_labels = []
    curr_metadata_fact_labels = []
    if msg['sender'] == ASSISTANT:
        for idx, f in enumerate(msg['facts']):
            if ignore_fact:
                fact_text = 'dummy fact'
            else:
                fact_text = self._fact_lookup[f['fid']].text
            # TODO: These are already space tokenized
            # 99% of text length is 77
            tokenized_fact = self._tokenizer.tokenize(fact_text)[:DIALOG_MAX_LENGTH]
            curr_facts_text.append(
                TextField(tokenized_fact, self._token_indexers)
            )
            if f['used']:
                curr_facts_labels.append(idx)
                curr_metadata_fact_labels.append(idx)
    else:
        # Users don't have facts, but lets avoid divide by zero
        curr_facts_text.append(TextField(
            [Token(text='@@NOFACT@@', idx=0)],
            self._token_indexers
        ))
    # Add in a label if there are no correct indices
    if not curr_facts_labels:
        curr_metadata_fact_labels.append(-1)
    fact_label_arr = np.zeros(len(curr_facts_text), dtype=np.float32)
    # Fancy-index assignment with an empty list is a harmless no-op,
    # so no emptiness guard is needed here.
    fact_label_arr[curr_facts_labels] = 1
    return (
        ListField(curr_facts_text),
        ArrayField(fact_label_arr, dtype=np.float32),
        curr_metadata_fact_labels,
    )