in gen_predict.py [0:0]
def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when there are none."""
    return sum(values) / len(values) if values else float("nan")


def _f1(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 when both are 0)."""
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * precision * recall / total


def _extract_span(text, start_token, end_token):
    """Return the stripped text between the two tokens, or "" if either is missing."""
    parts = text.split(start_token)
    if len(parts) >= 2 and end_token in parts[1]:
        return parts[1].split(end_token)[0].strip()
    return ""


def _parse_belief(belief, ordered_slots, known_slots):
    """Parse a generated belief string into ``{"domain slot": [values...]}``.

    *belief* is split on ", ".  A piece that begins with a known
    "domain slot" key starts a new entry; any other piece is appended as an
    extra value to the entry before it.  *ordered_slots* holds the known keys
    longest first and *known_slots* is the same collection as a set.
    """
    chunks = []
    for piece in belief.split(", "):
        for ds in ordered_slots:
            rest = piece[len(ds):]
            # Require a token boundary after the key so that a short slot
            # name (e.g. "to") never swallows a longer one ("to_station").
            if piece.startswith(ds) and (not rest or rest[0].isspace()):
                chunks.append([ds, rest.strip()])
                break
        else:
            chunks.append([piece])

    merged = []
    for chunk in chunks:
        if len(chunk) == 1 and merged:
            # Unrecognised piece: treat it as a continuation of the previous entry.
            merged[-1] += chunk
        else:
            merged.append(chunk)

    # A leading piece that never matched a known slot has nothing to attach to.
    if merged and merged[0][0] not in known_slots:
        del merged[0]
    return {chunk[0]: chunk[1:] for chunk in merged}


def _parse_actions(action):
    """Return the set of predicted actions with their first token dropped.

    *action* is split on ", "; only pieces of two or three whitespace
    separated tokens are kept.
    """
    predicted = set()
    for piece in action.split(", "):
        tokens = piece.split()
        if 2 <= len(tokens) <= 3:
            predicted.add(" ".join(tokens[1:]))
    return predicted


def main():
    """Score generated dialogue predictions and write them back as dialogue files.

    Command-line options:
        --inference         predictions file: ``.txt`` (one prediction per
                            line) or JSON (a list of strings).
        --datafolder        folder holding ``train/schema.json`` and the
                            evaluated split.
        --predictionfolder  output folder; ``<split>/<file>`` is written for
                            every dialogue file read.
        --split             name of the split sub-folder to evaluate.

    Predictions are consumed in order, one per odd-indexed turn of every
    dialogue in the sorted ``dialogue*`` files of the split.  For each such
    turn the generated response is stored on the turn, the parsed belief
    state replaces ``state["slot_values"]`` on the frames of the preceding
    turn, and action precision/recall plus BLEU (against ``delex`` and, when
    present, ``delexaug``) are accumulated overall and separately for turns
    whose services all appear / do not all appear in the training schema.
    Averages are printed at the end; an empty bucket is reported as ``nan``
    instead of raising ``ZeroDivisionError``.

    Raises:
        ValueError: if a slot name in the split schema contains a space
            (belief keys are later split on whitespace).
    """
    import sys  # local import: only needed for the prediction-count warning

    parser = argparse.ArgumentParser()
    parser.add_argument("--inference", default="dev.inference.gpt2_10epoch_1e-3_fp16.json", type=str, required=False, help='inference file')
    parser.add_argument("--datafolder", default="./simpletod/", type=str, required=False, help='data folder')
    parser.add_argument("--predictionfolder", default="./prediction/", type=str, required=False, help='prediction folder')
    parser.add_argument("--split", default="dev", type=str, required=False, help="[dev,test]")
    args = parser.parse_args()

    # os.path.join works whether or not the folder arguments end with a slash.
    split_dir = os.path.join(args.datafolder, args.split)
    output_dir = os.path.join(args.predictionfolder, args.split)

    with open(args.inference, "r", encoding="utf-8") as f:
        if args.inference.endswith(".txt"):
            predict = [line.strip() for line in f.read().strip().split("\n")]
        else:
            predict = json.load(f)

    with open(os.path.join(args.datafolder, "train", "schema.json"), "r", encoding="utf-8") as f:
        seen_services = {service["service_name"] for service in json.load(f)}

    domain_slots = set()
    with open(os.path.join(split_dir, "schema.json"), "r", encoding="utf-8") as f:
        schema = json.load(f)
    for service in schema:
        domain = service["service_name"].split("_")[0].lower()
        for slot in service["slots"]:
            if " " in slot["name"]:
                raise ValueError(
                    "slot name %r of service %r contains a space"
                    % (slot["name"], service["service_name"])
                )
            domain_slots.add(domain + " " + slot["name"].lower())
    # Longest key first, ties broken alphabetically, so matching is deterministic.
    ordered_slots = sorted(domain_slots, key=lambda s: (-len(s), s))

    act_precision = []
    act_recall = []
    seen_act_precision = []
    seen_act_recall = []
    unseen_act_precision = []
    unseen_act_recall = []
    bleu = []
    bleua = []
    seenbleu = []
    seenbleua = []
    unseenbleu = []
    unseenbleua = []

    idx = 0  # index of the next prediction to consume
    cnt = 0  # number of turns that expect a prediction

    for fn in sorted(os.listdir(split_dir)):
        if not fn.startswith("dialogue") or fn.startswith("dialogues_and_metrics.json"):
            continue
        with open(os.path.join(split_dir, fn), "r", encoding="utf-8") as f:
            data = json.load(f)

        for dialogue in data:
            turns = dialogue["turns"]
            for j in range(1, len(turns), 2):
                cnt += 1
                if idx >= len(predict):
                    # Out of predictions: the turn is left untouched (reported below).
                    continue
                prediction = predict[idx]
                idx += 1
                previous_turn = turns[j - 1]
                turn = turns[j]

                belief = _extract_span(prediction, "<|belief|>", "<|endofbelief|>")
                action = _extract_span(prediction, "<|action|>", "<|endofaction|>")
                pieces = prediction.split("<|response|>")
                response = pieces[1].split("<|")[0].strip() if len(pieces) >= 2 else ""
                turn["response"] = response

                seen = all(frame["service"] in seen_services for frame in previous_turn["frames"])
                parsedbelief = _parse_belief(belief, ordered_slots, domain_slots)

                # Only the first frame of the turn supplies reference actions.
                act_gt = {
                    (ref["act"].lower() + " " + ref["slot"]).strip()
                    for ref in turn["frames"][0]["actions"]
                }
                act_p = _parse_actions(action)
                matched = len(act_p & act_gt)
                precision = matched / len(act_p) if act_p else 1
                recall = matched / len(act_gt) if act_gt else 0
                act_precision.append(precision)
                act_recall.append(recall)
                if seen:
                    seen_act_precision.append(precision)
                    seen_act_recall.append(recall)
                else:
                    unseen_act_precision.append(precision)
                    unseen_act_recall.append(recall)

                # Each BLEU score is computed once and shared between the
                # overall bucket and the seen/unseen bucket.
                hypothesis = [response.lower()]
                score = bleuscorer(hypothesis, [[turn["delex"].lower()]])
                bleu.append(score)
                if seen:
                    seenbleu.append(score)
                else:
                    unseenbleu.append(score)
                if len(turn["delexaug"]) > 0:
                    score_aug = bleuscorer(hypothesis, [[a.lower() for a in turn["delexaug"]]])
                    bleua.append(score_aug)
                    if seen:
                        seenbleua.append(score_aug)
                    else:
                        unseenbleua.append(score_aug)

                # Replace the belief state of the preceding turn with the prediction.
                for frame in previous_turn["frames"]:
                    frame["state"]["slot_values"] = {}
                    prefix = frame["service"].split("_")[0].lower()
                    for ds in parsedbelief:
                        tokens = ds.split()
                        if tokens[0].lower() == prefix:
                            frame["state"]["slot_values"][tokens[1]] = parsedbelief[ds]

        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, fn), "w") as f:
            json.dump(data, f, indent=1)

    if cnt != len(predict):
        print(
            "warning: %d turns expect a prediction but %d predictions were loaded"
            % (cnt, len(predict)),
            file=sys.stderr,
        )

    act_precision = _mean(act_precision)
    act_recall = _mean(act_recall)
    print("act", act_precision, act_recall, _f1(act_precision, act_recall))
    print('bleu:', _mean(bleu))  # BLEU-4_{orig}
    print('bleua:', _mean(bleua))  # BLEU-4_{aug}
    seen_act_precision = _mean(seen_act_precision)
    seen_act_recall = _mean(seen_act_recall)
    print("act (seen):", seen_act_precision, seen_act_recall, _f1(seen_act_precision, seen_act_recall))
    unseen_act_precision = _mean(unseen_act_precision)
    unseen_act_recall = _mean(unseen_act_recall)
    print("act (unseen):", unseen_act_precision, unseen_act_recall, _f1(unseen_act_precision, unseen_act_recall))
    print('bleu (seen):', _mean(seenbleu))
    print('bleua (seen):', _mean(seenbleua))
    print('bleu (unseen):', _mean(unseenbleu))
    print('bleua (unseen):', _mean(unseenbleua))