# in TransferQA/create_data_mwoz.py [0:0]
def _read_name_list(path):
    """Return the set of lines (one name per line) stored in the file at *path*."""
    with open(path, 'r') as fin:
        # rstrip('\n') rather than line[:-1]: a final line without a trailing
        # newline must not lose its last character.
        return {line.rstrip('\n') for line in fin}


def divideData(data,args):
    """Split the dialogues in *data* into train/dev/test sets and save them.

    Dialogue names listed in ``testListFile.json`` / ``valListFile.json``
    (read from ``args.main_dir``, one name per line) go to the test / dev
    split; every other dialogue goes to the train split.  Dialogues for which
    ``get_dial`` returns a falsy value are skipped entirely.

    Args:
        data: mapping from dialogue name to a dialogue item.  Each item must
            have a ``'goal'`` mapping and be accepted by ``get_dial``.
        args: object with ``main_dir`` (directory holding ``ontology.json``
            and the two list files) and ``target_path`` (created if missing).

    Side effects:
        * copies ``ontology.json`` from ``args.main_dir`` to
          ``args.target_path``;
        * writes the names of the train dialogues to
          ``<args.target_path>/trainListFile``;
        * writes ``dev_dials.json``, ``test_dials.json``, ``train_dials.json``
          and ``ontology.json`` (slot -> list of values seen in the processed
          dialogues) into the ``data/`` directory relative to the current
          working directory.
    """
    os.makedirs(args.target_path, exist_ok=True)
    copyfile(os.path.join(args.main_dir, 'ontology.json'),
             os.path.join(args.target_path, 'ontology.json'))

    # Sets give O(1) membership tests in the per-dialogue split below.
    test_names = _read_name_list(os.path.join(args.main_dir, 'testListFile.json'))
    val_names = _read_name_list(os.path.join(args.main_dir, 'valListFile.json'))

    test_dials = []
    val_dials = []
    train_dials = []
    ontology = {}

    # The context manager guarantees the train list is flushed and closed,
    # even if processing a dialogue raises.
    with open(os.path.join(args.target_path, 'trainListFile'), 'w') as train_list_file:
        for dialogue_name in data:
            dial_item = data[dialogue_name]

            # Keep only goal entries that actually contain something.
            # Goal keys are unique, so no de-duplication is needed and the
            # emitted order stays deterministic across runs.
            domains = [
                dom_k
                for dom_k, dom_v in dial_item['goal'].items()
                if dom_v and dom_k not in IGNORE_KEYS_IN_GOAL
            ]

            dial = get_dial(dial_item)
            if not dial:
                continue

            dial_example = {"dial_id": dialogue_name, "domains": domains, "turns": []}
            for turn_i, turn in enumerate(dial):
                # Each entry of turn['bvs'] is indexed as (slot, value).
                slot_values = {s[0]: s[1] for s in turn['bvs']}
                dial_example['turns'].append({
                    # The system utterance shown with a user turn is the one
                    # from the previous turn; the first turn has none.
                    "system": dial[turn_i - 1]['sys'] if turn_i > 0 else "none",
                    "user": turn['usr'],
                    "state": {"active_intent": "none", "slot_values": slot_values},
                })

                # Accumulate every distinct value seen for each slot,
                # in first-seen order.
                for slot, value in slot_values.items():
                    seen_values = ontology.setdefault(slot, [])
                    if value not in seen_values:
                        seen_values.append(value)

            if dialogue_name in test_names:
                test_dials.append(dial_example)
            elif dialogue_name in val_names:
                val_dials.append(dial_example)
            else:
                train_list_file.write(dialogue_name + '\n')
                train_dials.append(dial_example)

    print("# of dialogues: Train {}, Val {}, Test {}".format(
        len(train_dials), len(val_dials), len(test_dials)))

    # save all dialogues
    # Make sure the output directory exists so the run does not fail at the
    # very last step after all the processing is done.
    os.makedirs('data', exist_ok=True)
    outputs = (
        ('data/dev_dials.json', val_dials),
        ('data/test_dials.json', test_dials),
        ('data/train_dials.json', train_dials),
        ('data/ontology.json', ontology),
    )
    for out_path, payload in outputs:
        with open(out_path, 'w') as f:
            json.dump(payload, f, indent=4)