in curiosity/reader.py [0:0]
def text_to_instance(self, dialog: Dict, ignore_fact: bool = False):
    """Convert one dialog dictionary into an ``Instance``.

    Args:
        dialog: Dialog record. This method reads the keys ``dialog_id``,
            ``known_entities``, ``focus_entity`` and ``messages``. Each
            message supplies ``message``, ``sender``, ``liked`` and
            ``dialog_acts``; assistant messages additionally supply
            ``facts``, whose entries carry ``fid`` and ``used``.
        ignore_fact: When true, every fact's text is replaced by the literal
            ``"dummy fact"`` and ``self._fact_lookup`` is not consulted.

    Returns:
        An ``Instance`` holding the per-message list fields ``messages``,
        ``facts``, ``fact_labels``, ``likes`` and ``dialog_acts``, the
        ``dialog_acts_mask`` and ``senders`` fields, the ``focus_entity``
        and ``known_entities`` text fields, and a ``metadata`` field.

    Raises:
        ValueError: If the dialog contains no messages.
    """
    msg_texts = []
    msg_senders = []
    msg_likes = []
    msg_acts = []
    msg_act_mask = []
    msg_facts = []
    msg_fact_labels = []
    metadata_fact_labels = []

    if len(dialog["messages"]) == 0:
        raise ValueError("There are no dialog messages")

    # Entities become single tokens of the form ENTITY/Some_Name.
    known_entities = [
        Token(text="ENTITY/" + t.replace(" ", "_"), idx=idx)
        for idx, t in enumerate(dialog["known_entities"])
    ]
    if not known_entities:
        # Placeholder token so the text field is never empty.
        known_entities.append(Token(text="@@YOUKNOWNOTHING@@", idx=0))
    known_entities_field = TextField(known_entities, self._mention_indexers)

    focus_entity = dialog["focus_entity"]
    focus_entity_field = TextField(
        [Token(text="ENTITY/" + focus_entity.replace(" ", "_"), idx=0)],
        self._mention_indexers,
    )

    for msg in dialog["messages"]:
        tokenized_msg = self._tokenizer.tokenize(msg["message"])
        msg_texts.append(TextField(tokenized_msg, self._token_indexers))
        # Sender is encoded as 0 for USER and 1 for anything else.
        msg_senders.append(0 if msg["sender"] == USER else 1)
        msg_likes.append(
            LabelField(
                "liked" if msg["liked"] else "not_liked",
                label_namespace="like_labels",
            )
        )

        # Messages without dialog acts get a placeholder label and a zero
        # mask entry; the mask is emitted as ``dialog_acts_mask``.
        if msg["dialog_acts"] is None:
            dialog_acts = ["@@NODA@@"]
            act_mask = 0
        else:
            dialog_acts = msg["dialog_acts"]
            act_mask = 1
        msg_acts.append(
            MultiLabelFieldListCompat(
                dialog_acts, label_namespace=DIALOG_ACT_LABELS
            )
        )
        msg_act_mask.append(act_mask)

        curr_facts_text = []
        curr_facts_labels = []
        curr_metadata_fact_labels = []
        if msg["sender"] == ASSISTANT:
            for idx, f in enumerate(msg["facts"]):
                if ignore_fact:
                    fact_text = "dummy fact"
                else:
                    fact_text = self._fact_lookup[f["fid"]].text
                # TODO: These are already space tokenized
                tokenized_fact = self._tokenizer.tokenize(fact_text)
                # Cap facts at 80 tokens (original note: 99% of text
                # length is 77).
                tokenized_fact = tokenized_fact[:80]
                curr_facts_text.append(
                    TextField(tokenized_fact, self._token_indexers)
                )
                if f["used"]:
                    curr_facts_labels.append(idx)
                    curr_metadata_fact_labels.append(idx)

        if not curr_facts_text:
            # Users don't have facts, and an assistant message may list
            # none. Always supply one placeholder fact so the ListField
            # below is never built from an empty list and the label array
            # is never zero-length (avoids divide by zero).
            curr_facts_text.append(
                TextField([Token(text="@@NOFACT@@", idx=0)], self._token_indexers)
            )

        msg_facts.append(ListField(curr_facts_text))

        # One 0/1 float per fact, set to 1 at the indices of used facts.
        fact_label_arr = np.zeros(len(curr_facts_text), dtype=np.float32)
        if curr_facts_labels:
            fact_label_arr[curr_facts_labels] = 1
        else:
            # Add in a label if there are no correct indices
            curr_metadata_fact_labels.append(-1)
        msg_fact_labels.append(ArrayField(fact_label_arr, dtype=np.float32))
        metadata_fact_labels.append(curr_metadata_fact_labels)

    return Instance(
        {
            "messages": ListField(msg_texts),
            "facts": ListField(msg_facts),
            "fact_labels": ListField(msg_fact_labels),
            "likes": ListField(msg_likes),
            "dialog_acts": ListField(msg_acts),
            "dialog_acts_mask": to_long_field(msg_act_mask),
            "senders": to_long_field(msg_senders),
            "focus_entity": focus_entity_field,
            "known_entities": known_entities_field,
            "metadata": MetadataField(
                {
                    "dialog_id": dialog["dialog_id"],
                    "n_message": len(msg_texts),
                    "fact_labels": metadata_fact_labels,
                    "known_entities": dialog["known_entities"],
                    "focus_entity": dialog["focus_entity"],
                }
            ),
        }
    )