in data.py [0:0]
def extract_interactions(self):
    """
    Run the configured agent over the task and collect (context, response) rows.

    The world returned by ``self.create_agent_task`` is driven with
    ``parley()`` until ``world.epoch_done()``. Within each step, every act in
    ``world.acts`` is classified:

    * acts whose ``"id"`` equals ``self.args.data_name`` (the data/teacher
      side) extend the running ground-truth dialog history: first the previous
      turn's gold label, then the act's own text;
    * every other act (the evaluated agent's reply) produces one row pairing
      the history so far with that reply.

    For ``convai2`` data, persona lines are stripped by keeping only the last
    newline-separated line of any text containing ``"your persona"``. An act
    whose text contains ``"__SILENCE__"`` stops processing of the remaining
    acts in that step.

    :return: ``pandas.DataFrame`` with one row per agent reply and columns
        ``dialog_id``, ``context_id``, ``context`` (history turns joined with a
        space-newline separator), a response column named after
        ``self.args.agent`` (``"true_response"`` when the agent is
        ``"repeat"``), and ``context_hash`` (MD5 hex digest of ``context``).
    """
    opt = self.prepare_args()
    _agent, world = self.create_agent_task(opt)

    # Name of the response column. The "repeat" agent's output is stored as
    # "true_response" (presumably because it echoes the gold label -- confirm).
    # This does not change per act, so it is computed once.
    agent_name = self.args.agent
    if agent_name == "repeat":
        agent_name = "true_response"

    true_dial = []  # ground-truth history of the current episode
    last_true = ""  # previous teacher turn's gold label, committed on the next teacher turn
    dialog_id = 0
    context_id = 0
    data_rows = []

    pb = tqdm(total=world.num_episodes())
    try:
        while True:
            world.parley()
            for a in world.acts:
                text = a["text"]
                # If personachat, remove the persona and keep only the last line.
                if self.args.data_name == "convai2" and "your persona" in text:
                    text = text.split("\n")[-1]
                if "__SILENCE__" in text:
                    break
                if "id" in a and a["id"] == self.args.data_name:
                    # Teacher turn: commit the previous gold response (if any),
                    # then the teacher's own utterance, to the history.
                    if last_true:
                        true_dial.append(last_true)
                    true_dial.append(text)
                    labels = a["eval_labels"] if "eval_labels" in a else a["labels"]
                    last_true = labels[0]
                else:
                    # Agent turn: record the reply against the history so far.
                    context = " \n".join(true_dial)
                    context_hash = hashlib.md5(context.encode("utf-8")).hexdigest()
                    data_rows.append(
                        {
                            "dialog_id": dialog_id,
                            "context_id": context_id,
                            "context": context,
                            agent_name: text,
                            "context_hash": context_hash,
                        }
                    )
                    context_id += 1
            if world.episode_done():
                # Reset per-episode state before the next dialog starts.
                true_dial = []
                last_true = ""
                pb.update(1)
                dialog_id += 1
                context_id = 0
            if world.epoch_done():
                print("Epoch done")
                break
    finally:
        # Close the progress bar even if the world raises mid-epoch.
        pb.close()
    return pd.DataFrame(data_rows)