def main()

in baseline_model/run_tree_transformer_multi_gpu.py
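
Entry point for the multi-GPU tree-transformer baseline. It loads the pickled gold C ASTs together with the assembly-side inputs (feature and graph .npy files, plus per-sample cached trace files when dump_trace is set), assembles everything into a torchtext Dataset, builds the source vocabulary, and then trains either on a single device or across n_dist_gpu worker processes.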


def main():
    title = 'trf-tree'
    # Register Tree in sys.modules so that pickle can resolve Tree objects
    # that were serialized under the top-level module name 'Tree'.
    sys.modules['Tree'] = Tree
    argParser = config.get_arg_parser(title)
    args = argParser.parse_args()
    args.summary = TrainingSummaryWriter(args.log_dir)
    logging = get_logger(
        log_path=os.path.join(args.log_dir, "log" + time.strftime('%Y%m%d-%H%M%S') + '.txt'),
        print_=True, log_=True)

    max_len_trg, max_len_src = 0, 0
    if not os.path.exists(args.checkpoint_path):
        os.makedirs(args.checkpoint_path)

    # Load the pickled gold C ASTs (the decompilation targets).
    with open(args.golden_c_path, 'rb') as file_c:
        trg = pickle.load(file_c)

    # Assembly-side inputs: per-function graph structure and node features.
    src_g = np.load(args.input_g_path, allow_pickle=True)
    src_f = np.load(args.input_f_path, allow_pickle=True)

    graphs_asm = load_graphs(args, src_f, src_g)

    SEED = 1234
    torch.manual_seed(SEED)

    exp_list = []
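    # torchtext fields: only the assembly tokens (SRC) get a numericalized
    # vocabulary; the remaining columns pass through unchanged as RawFields.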
    SRC        = Field(init_token = '<sos>', eos_token = '<eos>')
    TRG        = RawField()
    ID         = RawField()
    DICT_INFO  = RawField()
    GRAPHS_ASM = RawField()
    NODE_NUM   = RawField()

    # Build one torchtext Example per function, skipping samples whose cached
    # trace files are missing, empty, or unreadable.
    for i in range(args.gen_num):
        src_elem = src_f[i]
        broken_file_flag = 0
        dict_info = {}  # defined up front so the checks below cannot raise a NameError
        if args.dump_trace:
            for path in glob.glob(os.path.join(args.cache_path, str(i), '*')):
                if os.path.getsize(path) > 0:
                    with open(path, 'rb') as f:
                        dict_info[path] = pickle.load(f)
                else:
                    print("broken file! " + path)
                    broken_file_flag = 1
                    break

        if broken_file_flag == 1:
            continue
        if not dict_info:
            continue
        trg_elem = trg[i]['tree']
        len_elem_src = graphs_asm[i].number_of_nodes()
        exp = Example.fromlist(
            [src_elem, trg_elem, i, dict_info, graphs_asm[i], len_elem_src],
            fields=[('src', SRC), ('trg', TRG), ('id', ID), ('dict_info', DICT_INFO),
                    ('graphs_asm', GRAPHS_ASM), ('src_len', NODE_NUM)])
        exp_list.append(exp)
        len_elem_trg = trg[i]['treelen']

        # Track the longest source/target seen; the extra +2 leaves headroom
        # for the <sos>/<eos> markers.
        if len_elem_src >= max_len_src:
            max_len_src = len_elem_src + 2
        if len_elem_trg >= max_len_trg:
            max_len_trg = len_elem_trg + 2

    data_sets = Dataset(exp_list, fields=[('src', SRC), ('trg', TRG), ('id', ID),
                                          ('dict_info', DICT_INFO),
                                          ('graphs_asm', GRAPHS_ASM),
                                          ('src_len', NODE_NUM)])
    trn, tst, vld = data_sets.split([0.8, 0.15, 0.05])  # relative 80/15/5 split
    SRC.build_vocab(trn, min_freq=2)

    print("Number of training examples: %d" % (len(trn.examples)))
    print("Number of validation examples: %d" % (len(vld.examples)))
    print("Number of testing examples: %d" % (len(tst.examples)))
    print("Unique tokens in source assembly vocabulary: %d" % (len(SRC.vocab)))
    print("Max input length : %d" % (max_len_src))
    print("Max output length : %d" % (max_len_trg))


    if args.dist_gpu:
        # Halve the training set log2(n_dist_gpu) times so that each worker
        # gets its own shard (n_dist_gpu is assumed to be a power of two).
        for p in range(int(math.log(args.n_dist_gpu, 2))):
            trn = split_dataset(trn, p)

        # Spawn one training process per GPU, then wait for all of them.
        procs = []
        for proc_id in range(args.n_dist_gpu):
            p = mp.Process(target=run_spawn,
                           args=(proc_id, args, trn[proc_id], vld, tst, SRC,
                                 max_len_src, max_len_trg, logging))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()
    else:
        # Single-process run; -1 marks the non-distributed case.
        run_spawn(-1, args, trn, vld, tst, SRC, max_len_src, max_len_trg, logging)
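
split_dataset is referenced above but not shown in this excerpt. Judging from the call pattern (log2(n_dist_gpu) successive calls, then indexing trn[proc_id]), each call plausibly halves every existing shard, so that log2(n) rounds turn one Dataset into n shards. A minimal sketch under that assumption, using torchtext's Dataset.split; the step argument mirrors the loop variable p and is unused here:

def split_dataset(data, step):
    # Hypothetical sketch, not the repository's implementation: split every
    # existing shard 50/50, so log2(n) successive calls yield n shards.
    shards = data if isinstance(data, list) else [data]
    halved = []
    for shard in shards:
        first, second = shard.split(split_ratio=0.5)  # torchtext Dataset.split
        halved.extend([first, second])
    return halved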