in data.py [0:0]
def extract_dialogs(self):
    """Extract all dialogs from the task world and fingerprint each one.

    DEPRECATED.

    Sets ``self.args.agent`` to ``"repeat"``, builds the world through
    ``self.prepare_args()`` / ``self.create_agent_task()``, and then calls
    ``world.parley()`` repeatedly until ``world.epoch_done()`` is true.

    Every act's ``"text"`` is split into sentences with ``sent_tokenize``
    and stored as a single string of the form
    ``"[CLS] <s1> [SEP] <s2> [SEP]"``. When ``self.args.data_name`` is
    ``"personachat"`` or ``"convai2"`` and the text contains
    ``"your persona"``, only the last newline-separated line of the text
    is kept, which drops the persona lines in front of the utterance.

    Whenever ``world.episode_done()`` is true the utterances collected so
    far are stored as one dialog, together with an MD5 hex digest computed
    over their UTF-8 encodings in order.

    Side effects:
        * ``self.args.agent`` is overwritten.
        * ``self.dialogs`` is set to a list of dialogs (each a list of
          formatted utterance strings).
        * ``self.dialog_hashes`` is set to the parallel list of digests.
        * Two summary messages are written via
          ``self.logbook.write_message_logs``.

    Returns:
        None.
    """
    # ParlAI-specific setup.
    self.args.agent = "repeat"
    opt = self.prepare_args()
    # Only the world is stepped below; the returned agent is not used here.
    _, world = self.create_agent_task(opt)

    # The dataset name does not change while iterating, so evaluate the
    # persona check once instead of once per act.
    strip_persona = self.args.data_name in ("personachat", "convai2")

    dialogs = []
    dialog_hashes = []
    cur_dial = []

    # The context manager closes the progress bar even if stepping the
    # world or tokenizing raises part-way through the epoch.
    with tqdm(total=world.num_episodes()) as pb:
        while True:
            world.parley()
            for act in world.acts:
                text = act["text"]
                if strip_persona and "your persona" in text:
                    # Keep only the final line; the preceding lines carry
                    # the persona description.
                    text = text.split("\n")[-1]
                cur_sents = sent_tokenize(text)
                cur_dial.append("[CLS] " + " [SEP] ".join(cur_sents) + " [SEP]")

            if world.episode_done():
                dialogs.append(cur_dial)
                # Hash the utterances in order so identical dialogs map to
                # the same digest.
                digest = hashlib.md5()
                for utterance in cur_dial:
                    digest.update(utterance.encode("utf-8"))
                dialog_hashes.append(digest.hexdigest())
                # Rebind rather than clear: the list just appended to
                # ``dialogs`` must stay intact.
                cur_dial = []
                pb.update(1)

            if world.epoch_done():
                print("Epoch done")
                break

    self.dialogs = dialogs
    self.dialog_hashes = dialog_hashes
    self.logbook.write_message_logs("{} dialogs extracted".format(len(dialogs)))
    self.logbook.write_message_logs(
        "{} unique hashes".format(len(set(dialog_hashes)))
    )