in baseline_model/run_tree_transformer_multi_gpu.py [0:0]
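# Imports this entry point relies on. The project-local names (config, Tree,
# TrainingSummaryWriter, get_logger, load_graphs, split_dataset, run_spawn)
# are assumed to be defined elsewhere in baseline_model; the torchtext
# imports assume the legacy (pre-0.9) torchtext.data API.
import glob
import math
import os
import pickle
import sys
import time

import numpy as np
import torch
import torch.multiprocessing as mp  # assumed; plain multiprocessing also fits the usage below
from torchtext.data import Dataset, Example, Field, RawField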
def main():
    title = 'trf-tree'
    # Register the Tree class under the top-level module name 'Tree' so that
    # pickled trees referencing that module path can be unpickled below.
    sys.modules['Tree'] = Tree
argParser = config.get_arg_parser(title)
args = argParser.parse_args()
args.summary = TrainingSummaryWriter(args.log_dir)
    logging = get_logger(
        log_path=os.path.join(args.log_dir,
                              "log" + time.strftime('%Y%m%d-%H%M%S') + '.txt'),
        print_=True, log_=True)
max_len_trg, max_len_src = 0, 0
    os.makedirs(args.checkpoint_path, exist_ok=True)
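    # Load the pickled golden C target trees and the assembly inputs, then
    # build one assembly graph per example.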
with open(args.golden_c_path,'rb') as file_c:
trg = pickle.load(file_c)
src_g = np.load(args.input_g_path, allow_pickle=True)
src_f = np.load(args.input_f_path, allow_pickle=True)
graphs_asm = load_graphs(args, src_f, src_g)
    SEED = 1234
torch.manual_seed(SEED)
exp_list = []
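    # SRC is the only tokenized field (wrapped with <sos>/<eos> markers); the
    # remaining fields are RawFields that pass their values through untouched.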
    SRC = Field(init_token='<sos>', eos_token='<eos>')
TRG = RawField()
ID = RawField()
DICT_INFO = RawField()
GRAPHS_ASM = RawField()
NODE_NUM = RawField()
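    # Build one torchtext Example per sample, skipping any sample whose cached
    # trace files are corrupt or missing.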
    for i in range(args.gen_num):
        src_elem = src_f[i]
        broken_file_flag = False
        # Initialize dict_info unconditionally so the Example built below never
        # references an unbound name when args.dump_trace is off.
        dict_info = {}
        if args.dump_trace:
            for path in glob.glob(os.path.join(args.cache_path, str(i), '*')):
if os.path.getsize(path) > 0:
with open(path, 'rb') as f:
dict_info[path] = pickle.load(f)
else:
print("broken file!" + path)
broken_file_flag = 1
break
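            # Skip this sample if its trace cache is corrupt or empty.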
            if broken_file_flag:
continue
if dict_info == {}:
continue
trg_elem = trg[i]['tree']
len_elem_src = graphs_asm[i].number_of_nodes()
        exp = Example.fromlist(
            [src_elem, trg_elem, i, dict_info, graphs_asm[i], len_elem_src],
            fields=[('src', SRC), ('trg', TRG), ('id', ID), ('dict_info', DICT_INFO),
                    ('graphs_asm', GRAPHS_ASM), ('src_len', NODE_NUM)])
exp_list.append(exp)
len_elem_trg = trg[i]['treelen']
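        # Track the longest source/target seen so far, plus 2 for <sos>/<eos>.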
if len_elem_src >= max_len_src:
max_len_src = len_elem_src + 2
if len_elem_trg >= max_len_trg:
max_len_trg = len_elem_trg + 2
    data_sets = Dataset(exp_list,
                        fields=[('src', SRC), ('trg', TRG), ('id', ID),
                                ('dict_info', DICT_INFO), ('graphs_asm', GRAPHS_ASM),
                                ('src_len', NODE_NUM)])
    # torchtext's legacy Dataset.split takes ratios in (train, test, valid)
    # order but returns the splits as (train, valid, test); unpack accordingly.
    trn, vld, tst = data_sets.split([0.8, 0.15, 0.05])
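    # Build the source vocabulary from the training split only, dropping
    # tokens that appear fewer than twice.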
    SRC.build_vocab(trn, min_freq=2)
print("Number of training examples: %d" % (len(trn.examples)))
print("Number of validation examples: %d" % (len(vld.examples)))
print("Number of testing examples: %d" % (len(tst.examples)))
print("Unique tokens in source assembly vocabulary: %d "%(len(SRC.vocab)))
print("Max input length : %d" % (max_len_src))
print("Max output length : %d" % (max_len_trg))
if args.dist_gpu:
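        # Repeatedly halve the training set (n_dist_gpu is assumed to be a
        # power of two) until there is one shard per worker, then spawn one
        # training process per shard.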
        for level in range(int(math.log2(args.n_dist_gpu))):
            trn = split_dataset(trn, level)
procs = []
for proc_id in range(args.n_dist_gpu):
            p = mp.Process(target=run_spawn,
                           args=(proc_id, args, trn[proc_id], vld, tst, SRC,
                                 max_len_src, max_len_trg, logging))
p.start()
procs.append(p)
for p in procs:
p.join()
else:
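        # A proc_id of -1 signals single-process (non-distributed) training.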
run_spawn(-1, args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging)