in gen_delex.py [0:0]
def _parse_bool_flag(text):
    """Parse the textual value of a boolean command-line option.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    ``--all False`` would evaluate to ``True``.  This parser maps the usual
    "false" spellings to ``False`` and otherwise keeps the historical rule
    (any other non-empty string is ``True``, the empty string is ``False``).
    """
    if text.strip().lower() in ("false", "0", "no", "n", "f", "off"):
        return False
    return bool(text)


def _one_line(text):
    """Return *text* with newlines turned into spaces and carriage returns removed."""
    return text.replace("\n", " ").replace("\r", "")


def _annotate_selected_db(turns):
    """Attach ``selecteddbslots`` and ``selecteddb`` to every SYSTEM turn's first frame.

    State accumulates over the dialogue: service results and the
    value <-> canonical-value maps only grow, while the tracked slot texts are
    reset whenever a SYSTEM turn carries its own non-empty ``slots`` list.
    The selected row is the first accumulated service result that agrees with
    every tracked slot; its values are mapped back to surface forms when known.
    An empty dict is stored when no row matches.
    """
    results = []
    tracked = {}
    to_canonical = {}
    to_surface = {}
    for turn in turns:
        if turn["speaker"] != "SYSTEM":
            continue
        frame = turn["frames"][0]
        if "service_results" in frame:
            results.extend(frame["service_results"])
        if len(frame["slots"]) != 0:
            # Rebind (do not clear): earlier turns keep their own dict.
            tracked = {}
        for act in frame["actions"]:
            assert len(act["canonical_values"]) == len(act["values"]), \
                "action values and canonical_values differ in length"
            for value, canonical in zip(act["values"], act["canonical_values"]):
                to_canonical[value] = canonical
                to_surface[canonical] = value
        for span in frame["slots"]:
            tracked[span["slot"]] = turn["utterance"][span["start"]:span["exclusive_end"]]

        selected = {}
        for row in results:
            # Short-circuits exactly like the explicit loop: the canonical
            # lookup only happens when the slot name exists in the row.
            if all(name in row and row[name] == to_canonical[text]
                   for name, text in tracked.items()):
                selected = copy.deepcopy(row)
                for key in selected:
                    if selected[key] in to_surface:
                        selected[key] = to_surface[selected[key]]
                break
        frame["selecteddbslots"] = tracked
        frame["selecteddb"] = selected


def _delexicalize(utterance, frame, domain, delexlevel):
    """Return *utterance* with slot values replaced by ``[domain_slot]`` placeholders.

    Level >= 1 replaces the character spans listed in ``frame["slots"]``.
    Level >= 2 additionally string-replaces the values listed in
    ``frame["actions"]`` for slots not already handled by a span.
    """
    # Right-to-left so that earlier character offsets stay valid while editing.
    spans = sorted(frame["slots"], key=lambda span: -span["start"])
    delex = utterance
    delexed = set()
    if delexlevel >= 1:
        for right, left in zip(spans, spans[1:]):
            assert right["start"] >= left["exclusive_end"], "slot spans overlap"
        for span in spans:
            domain_slot = domain + "_" + span["slot"]
            delex = delex[:span["start"]] + "[" + domain_slot + "]" + delex[span["exclusive_end"]:]
            delexed.add(domain_slot)
    if delexlevel >= 2:
        valued = [act for act in frame["actions"] if len(act["values"]) > 0]
        # Longest first value first, so longer values are replaced before shorter ones.
        valued.sort(key=lambda act: -len(act["values"][0]))
        for act in valued:
            domain_slot = domain + "_" + act["slot"]
            if domain_slot in delexed:
                continue
            for value in act["values"]:
                delex = delex.replace(value, "[" + domain_slot + "]")
            delexed.add(domain_slot)
    return delex


def _belief_strings(user_turn):
    """Return sorted ``"domain slot value"`` strings for the state in *user_turn*.

    Each ``slot_values`` list is sorted IN PLACE (so the change is visible in
    the JSON written back by ``main``) and its first element is used as the value.
    """
    belief = []
    for frame in user_turn["frames"]:
        frame_domain = frame["service"].split("_")[0].lower()
        for slot, values in frame["state"]["slot_values"].items():
            belief.append((frame_domain, slot, values))
    belief.sort(key=lambda item: item[0] + " " + item[1])
    rendered = []
    for frame_domain, slot, values in belief:
        values.sort()
        rendered.append(frame_domain + " " + slot + " " + values[0])
    return rendered


def _build_turn_examples(turns, j, delexlevel, lines):
    """Annotate SYSTEM turn ``turns[j]`` and append its training lines to *lines*.

    Adds ``delex``, ``target``, ``targetaug``, ``delexaug`` and ``context`` to
    the turn (in that order) and appends to the ``lm``, ``eval``, ``aug``,
    ``both``, ``cc`` and ``ff`` lists of *lines*.
    """
    turn = turns[j]
    domain = turn["frames"][0]["service"].split("_")[0].lower()
    assert turn["speaker"] == "SYSTEM", "odd-indexed turns must be SYSTEM turns"
    assert len(turn["frames"]) == 1, "SYSTEM turns must carry exactly one frame"
    frame = turn["frames"][0]

    delex = _delexicalize(turn["utterance"], frame, domain, delexlevel)
    turn["delex"] = delex

    # Until the action/response parts are appended below, ``target`` holds
    # only the belief segment; augmented targets are built on that prefix.
    target = '<|belief|> ' + ", ".join(_belief_strings(turns[j - 1])) + ' <|endofbelief|> '

    ordered_acts = sorted(frame["actions"], key=lambda act: act["act"])
    action = [domain + " " + act["act"].lower() + " " + act["slot"] for act in ordered_acts]
    joined_action = ", ".join(action)

    targetaug = []
    delexaug = []
    tcpos = []
    tcneg = []
    for position in ("beginning", "end"):
        for cand in turn[position]:
            chitchat = cand["candidate"].strip()
            pair = (' <|task|> ' + delex + ' <|endoftask|> '
                    + '<|chitchat|> ' + chitchat + ' <|endofchitchat|> ')
            if "social" in cand["justification"] or "useful" in cand["justification"]:
                if position == "beginning":
                    response = chitchat + ' ' + delex
                    acts = "chitchat, " + joined_action
                else:
                    response = delex + ' ' + chitchat
                    acts = joined_action + ", chitchat"
                delexaug.append(response)
                targetaug.append(target + '<|action|> ' + acts + ' <|endofaction|> '
                                 + '<|response|> ' + response + ' <|endofresponse|>')
                tcpos.append(pair)
            else:
                tcneg.append(pair)

    target += '<|action|> ' + joined_action + ' <|endofaction|> '
    target += '<|response|> ' + delex + ' <|endofresponse|>'
    turn["target"] = target
    turn["targetaug"] = targetaug
    turn["delexaug"] = delexaug

    # Even-indexed turns are labelled as user turns, odd-indexed as system turns.
    history = "".join(
        ('<|user|> ' if k % 2 == 0 else '<|system|> ') + turns[k]["utterance"] + " "
        for k in range(j)
    )
    context = '<|context|> ' + history + '<|endofcontext|>'
    turn["context"] = context

    lines["lm"].append(_one_line(context + target))
    lines["eval"].append(_one_line(context))
    if targetaug:
        for positive, augmented in zip(tcpos, targetaug):
            lines["aug"].append(_one_line(context + augmented))
            lines["both"].append(_one_line(context + augmented))
            lines["ff"].append(_one_line(context + positive + augmented))
            for negative in tcneg:
                lines["ff"].append(_one_line(context + negative + augmented))
    else:
        lines["both"].append(_one_line(context + target))
        for negative in tcneg:
            lines["ff"].append(_one_line(context + negative + target))
    lines["cc"] += [
        context.replace('<|context|>', '').replace('<|endofcontext|>', '')
        .replace('<|user|>', 'user:').replace('<|system|>', 'system:')
        .replace('\t', ' ').strip(),
        '[DONE]',
    ]


def main(argv=None):
    """Delexicalise the dialogues under ``--data`` and write training inputs.

    For each of the ``train``, ``dev`` and ``test`` sub-folders:

    * files whose name does not start with ``dialogue`` are re-serialised
      unchanged into the matching ``--target`` sub-folder;
    * dialogue files are annotated turn by turn and written to the same place;
    * six ``lm.input.<folder>*.txt`` files are written to the current working
      directory (only the plain ``lm`` list is shuffled).

    Args:
        argv: Optional list of command-line arguments.  ``None`` (the default)
            makes ``argparse`` read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # NOTE: --all is parsed for command-line compatibility; nothing in this
    # function reads it.
    parser.add_argument("--all", default=False, type=_parse_bool_flag, nargs="?", const=True, required=False, help="use all dialogues rather than only augmented dialogues")
    parser.add_argument("--delexlevel", default=2, type=int, required=False, help="0: no delex; 1: delex values in \"slots\"; 2: delex values in both \"slots\" and \"actions\"")
    parser.add_argument("--data", default="./accentor-sgd/", type=str, required=False, help="path to SGD")
    parser.add_argument("--target", default="./simpletod/", type=str, required=False, help="path to output")
    args = parser.parse_args(argv)

    for folder in ["train", "dev", "test"]:
        # os.path.join works whether or not the user supplied a trailing slash.
        source_dir = os.path.join(args.data, folder)
        target_dir = os.path.join(args.target, folder)
        os.makedirs(target_dir, exist_ok=True)

        lines = {name: [] for name in ("lm", "eval", "aug", "both", "cc", "ff")}
        for fn in sorted(os.listdir(source_dir)):
            with open(os.path.join(source_dir, fn), "r", encoding='utf8') as f:
                data = json.load(f)
            if fn.startswith("dialogue"):
                for dialogue in data:
                    turns = dialogue["turns"]
                    _annotate_selected_db(turns)
                    # SYSTEM turns are expected at the odd indices.
                    for j in range(1, len(turns), 2):
                        _build_turn_examples(turns, j, args.delexlevel, lines)
            with open(os.path.join(target_dir, fn), "w", encoding='utf8') as f:
                json.dump(data, f, indent=1)

        random.shuffle(lines["lm"])
        with open("lm.input." + folder + ".txt", "w", encoding='utf8') as f:  # SimpleTOD
            f.write('\n'.join(lines["lm"]))
        with open("lm.input." + folder + ".eval.txt", "w", encoding='utf8') as f:  # used as the input during evaluation of SimpleTOD and SimpleTOD extension
            f.write('\n'.join(lines["eval"]))
        with open("lm.input." + folder + ".aug.txt", "w", encoding='utf8') as f:  # SimpleTOD extension (augmented responses only)
            f.write('\n'.join(lines["aug"]))
        with open("lm.input." + folder + ".both.txt", "w", encoding='utf8') as f:  # SimpleTOD extension (all responses)
            f.write('\n'.join(lines["both"]))
        with open("lm.input." + folder + ".cc.txt", "w", encoding='utf8') as f:  # cc: chitchat
            f.write('\n'.join(lines["cc"] + ['[EXIT]']))
        with open("lm.input." + folder + ".ff.txt", "w", encoding='utf8') as f:  # ff: free-form
            f.write('\n'.join(lines["ff"]))