in T5DST/data_loader.py [0:0]
def read_data(args, path_name, SLOTS, tokenizer, description, dataset=None):
    """Load a dialogue JSON file and expand it into one sample per (turn, slot).

    Each sample's input is the dialogue history accumulated up to the turn,
    followed by ``tokenizer.sep_token`` and a slot query; the target is that
    slot's value in the turn state, or ``'none'`` when the slot is absent.

    Args:
        args: Config mapping. Reads ``fewshot``, ``seed``, ``only_domain``,
            ``except_domain``, ``fix_label``, ``model_name`` and (for non-GPT
            models) ``slot_lang``. The string ``"none"`` disables a domain
            setting.
        path_name: Path to a JSON file holding a list of dialogues, each with
            ``dial_id``, ``domains`` and ``turns`` (every turn has ``system``,
            ``user`` and ``state["slot_values"]``).
        SLOTS: Full list of ``"domain-slot"`` names.
        tokenizer: Provides ``sep_token`` and ``eos_token``; the GPT branch
            also uses ``bos_token``.
        description: Mapping slot -> prompt texts for the ``slot_lang`` modes;
            non-GPT samples also copy its ``"values"`` entry.
        dataset: Split name (``"train"``, ``"dev"``, ``"test"`` or ``None``).
            Controls few-shot sub-sampling and how the domain settings apply.

    Returns:
        ``(data, slot_temp)``: the list of sample dicts and the slot list kept
        for this split after domain filtering.
    """
    print(f"Reading all files from {path_name}")
    data = []
    domain_counter = {}

    # read files (JSON text is UTF-8 by spec; don't rely on the locale default)
    with open(path_name, encoding="utf-8") as f:
        dials = json.load(f)

    # Few-shot training: keep a seeded random fraction of the dialogues.
    if dataset == "train" and args["fewshot"] > 0:
        random.Random(args["seed"]).shuffle(dials)
        dials = dials[:int(len(dials) * args["fewshot"])]

    only_domain = args["only_domain"]
    except_domain = args["except_domain"]
    is_gpt = "gpt" in args["model_name"]

    # The slot list depends only on the split and the domain settings, so it is
    # computed once; this also keeps it defined when no turn is read.
    keep_slot = _slot_domain_filter(args, dataset)
    slot_temp = SLOTS if keep_slot is None else [k for k in SLOTS if keep_slot(k)]
    # Outside the test split, the except-domain setting also drops slots whose
    # domain is not listed for the current dialogue.
    skip_absent_domains = except_domain != "none" and dataset != "test"

    for dial_dict in dials:
        domains = dial_dict["domains"]

        # Counting domains (done before the unseen-domain filters below)
        for domain in domains:
            if domain in EXPERIMENT_DOMAINS:
                domain_counter[domain] = domain_counter.get(domain, 0) + 1

        # Unseen domain setting
        if only_domain != "none" and only_domain not in domains:
            continue
        if except_domain != "none":
            if dataset == "test":
                # the test split keeps only dialogues involving the held-out domain
                if except_domain not in domains:
                    continue
            elif [except_domain] == domains:
                # other splits drop dialogues that use the held-out domain alone
                continue

        # Reading data
        dialog_history = ""
        for turn_id, turn in enumerate(dial_dict["turns"]):
            # accumulate dialogue utterances
            dialog_history += " System: " + turn["system"] + " User: " + turn["user"]
            raw_values = turn["state"]["slot_values"]
            if args["fix_label"]:
                slot_values = fix_general_label_error(raw_values, SLOTS)
            else:
                slot_values = raw_values
            if keep_slot is not None:
                slot_values = {k: v for k, v in slot_values.items() if keep_slot(k)}
            turn_belief_list = [str(k) + '-' + str(v) for k, v in slot_values.items()]

            # The GPT baseline skips turns whose history exceeds 800
            # whitespace-separated words.
            if is_gpt and len(dialog_history.split()) > 800:
                continue

            for slot in slot_temp:
                # skip irrelevant slots for the out-of-domain setting
                if skip_absent_domains and slot.split("-")[0] not in domains:
                    continue
                if is_gpt:
                    # NOTE(review): the GPT baseline reads the raw turn state, so
                    # `fix_label` does not change its targets (unlike
                    # `turn_belief`). Kept as-is; confirm this is intended.
                    value_text = raw_values.get(slot, 'none').strip()
                    input_text = f"{dialog_history} {tokenizer.sep_token} {slot} {tokenizer.bos_token}"
                    output_text = f"{input_text} {value_text} {tokenizer.eos_token}"
                else:
                    value_text = slot_values.get(slot, 'none').strip()
                    prompt = _slot_prompt(slot, args["slot_lang"], description)
                    input_text = f"{dialog_history} {tokenizer.sep_token} {prompt}"
                    output_text = f"{value_text} {tokenizer.eos_token}"
                data_detail = {
                    "ID": dial_dict["dial_id"],
                    "domains": domains,
                    "turn_id": turn_id,
                    "dialog_history": dialog_history,
                    "turn_belief": turn_belief_list,
                    # "intput_text" (sic) is the original key spelling, presumably
                    # read by code outside this function -- verify before renaming.
                    "intput_text": input_text,
                    "output_text": output_text,
                    "slot_text": slot,
                    "value_text": value_text,
                }
                if not is_gpt:
                    data_detail["value_list"] = description[slot]["values"]
                data.append(data_detail)

    # Show a few samples as a sanity check; slicing tolerates fewer than 10.
    for sample in data[:10]:
        print(sample)
    print("domain_counter", domain_counter)
    return data, slot_temp


# slot_lang mode -> (description field used as the query, text appended to it).
# Modes not listed here query with the bare slot name.
_SLOT_LANG_PROMPTS = {
    "human": ("description_human", "?"),
    "naive": ("naive", "?"),
    "value": ("naive", ""),
    "question": ("question", ""),
    "slottype": ("slottype", "?"),
}


def _slot_domain_filter(args, dataset):
    """Return the slot-name predicate for the unseen-domain settings, or None.

    ``None`` means every slot is kept. With ``except_domain`` set, train/dev
    keep the slots that do not mention the held-out domain and every other
    split keeps only the slots that do; otherwise ``only_domain`` keeps only
    the slots that mention it. Matching is a substring test on the slot name.
    """
    except_domain = args["except_domain"]
    only_domain = args["only_domain"]
    if except_domain != "none":
        if dataset in ("train", "dev"):
            return lambda slot: except_domain not in slot
        return lambda slot: except_domain in slot
    if only_domain != "none":
        return lambda slot: only_domain in slot
    return None


def _slot_prompt(slot, slot_lang, description):
    """Return the query text placed after the separator token for ``slot``."""
    style = _SLOT_LANG_PROMPTS.get(slot_lang)
    if style is None:
        return slot
    field, suffix = style
    return f"{description[slot][field]}{suffix}"