in preprocess_multiwoz/extract_examples.py [0:0]
def read_file(file_name, gating_dict, SLOTS, dataset, lang, mem_lang, sequicity, training, max_line=None, args=None):
    """
    Reads examples from train / dev / test files.

    Acknowledgement: most of this code is taken from the trade-dst repo
    (https://github.com/jasonwu0731/trade-dst), implementation of the function read_langs.

    Args:
        file_name: path to a JSON file containing a list of dialogue dicts. Each dialogue is
            read for "dialogue_idx", "domains" and "dialogue"; each turn for "domain",
            "turn_idx", "transcript", "system_transcript" and "belief_state".
        gating_dict, lang, mem_lang, sequicity, training: not used here; kept so the
            signature stays compatible with existing callers.
        SLOTS: slot names (presumably "domain-slot" strings — confirm with caller); the
            domain options below are matched against them by substring.
        dataset: split name. It selects how "except_domain" filters dialogues ("test" vs.
            anything else) and slots ("train"/"dev" vs. anything else).
        max_line: if truthy, stop after this many dialogues have been read. Dialogues
            dropped by the domain filters are not counted.
        args: optional dict with "except_domain" / "only_domain" entries. A missing, empty
            or None entry disables that filter. Defaults to no filtering.

    Returns:
        (data, max_resp_len, slot_temp):
        - data: one dict per turn, with keys "ID", "domains", "turn_domain", "turn_id",
          "dialog_history", "turn_belief" and "turn_uttr".
        - max_resp_len: largest whitespace-token count of any "dialog_history" (0 if none).
        - slot_temp: SLOTS restricted by the domain options (SLOTS itself when none apply).
    """
    if args is None:
        args = {}
    # `or ""` also treats an explicit None as "filter disabled".
    except_domain = args.get("except_domain") or ""
    only_domain = args.get("only_domain") or ""

    print("Reading from {}".format(file_name))

    # The slot filter depends only on the split and the domain options, so decide it once.
    #   train/dev + except_domain: hold that domain's slots out.
    #   other splits + except_domain: keep only that domain's slots.
    #   only_domain (any split): keep only that domain's slots.
    if except_domain:
        if dataset == "train" or dataset == "dev":
            def keep_slot(slot):
                return except_domain not in slot
        else:
            def keep_slot(slot):
                return except_domain in slot
    elif only_domain:
        def keep_slot(slot):
            return only_domain in slot
    else:
        keep_slot = None
    # Computed before the loop so it is defined even when no turn is read
    # (previously an empty or fully filtered file raised UnboundLocalError).
    slot_temp = SLOTS if keep_slot is None else [k for k in SLOTS if keep_slot(k)]

    data = []
    max_resp_len = 0
    domain_counter = {}
    with open(file_name, encoding="utf-8") as f:
        dials = json.load(f)

    dialogues_read = 0
    for dial_dict in dials:
        # Domains are counted for every dialogue, before the filters below are applied.
        for domain in dial_dict["domains"]:
            if domain in EXPERIMENT_DOMAINS:
                domain_counter[domain] = domain_counter.get(domain, 0) + 1

        # Dialogue-level filtering (a no-op when both domain options are empty).
        if only_domain and only_domain not in dial_dict["domains"]:
            continue
        if except_domain:
            if dataset == "test":
                # Test split: keep only dialogues that involve the held-out domain.
                if except_domain not in dial_dict["domains"]:
                    continue
            elif [except_domain] == dial_dict["domains"]:
                # Other splits: drop dialogues whose only domain is the held-out one.
                continue

        dialog_history = ""
        for turn in dial_dict["dialogue"]:
            user_utt = "[User]: {0}".format(turn["transcript"])
            if turn["system_transcript"]:
                # The system transcript always comes first, then the user transcript.
                turn_uttr = "[Agent]: {0}".format(turn["system_transcript"]) + " ; " + user_utt
            else:
                turn_uttr = user_utt
            turn_uttr_strip = turn_uttr.strip()
            # History holds every turn so far, each followed by " ; ".
            dialog_history += turn_uttr_strip + " ; "
            source_text = dialog_history.strip()

            turn_belief_dict = fix_general_label_error(turn["belief_state"], False, SLOTS)
            turn_belief_list = []
            for slot, value in turn_belief_dict.items():
                if keep_slot is not None and not keep_slot(slot):
                    continue
                # Split only the key, on its first hyphen, so values that contain "-"
                # no longer break the unpacking.
                domain, label = str(slot).split("-", 1)
                if label in SLOT_TO_NATURAL:
                    label = SLOT_TO_NATURAL[label]
                turn_belief_list.append(domain + "-" + label + "-" + str(value))

            data.append({
                "ID": dial_dict["dialogue_idx"],
                "domains": dial_dict["domains"],
                "turn_domain": turn["domain"],
                "turn_id": turn["turn_idx"],
                "dialog_history": source_text,
                "turn_belief": turn_belief_list,
                "turn_uttr": turn_uttr_strip,
            })
            max_resp_len = max(max_resp_len, len(source_text.split()))

        # Counted per dialogue that passed the filters. max_line=N now yields N dialogues
        # (the old counter started at 1 and yielded N-1).
        dialogues_read += 1
        if max_line and dialogues_read >= max_line:
            break

    print("domain_counter", domain_counter)
    return data, max_resp_len, slot_temp