def main()

in baseline_model/run_tree_transformer.py


def main():
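    """Train and evaluate the tree-Transformer baseline end to end: load golden
    ASTs and assembly graphs, build torchtext datasets, train with a
    Noam-scheduled Adam optimizer, checkpoint on validation loss, and report
    test accuracy."""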
    title = 'trf-tree'
    sys.modules['Tree'] = Tree  # alias so Tree objects pickled under module name 'Tree' resolve during pickle.load
    argParser = config.get_arg_parser(title)
    args = argParser.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.summary = TrainingSummaryWriter(args.log_dir)
    # note: `logging` is the callable returned by get_logger (it shadows the stdlib module name)
    logging = get_logger(log_path=os.path.join(args.log_dir, "log" + time.strftime('%Y%m%d-%H%M%S') + '.txt'),
                         print_=True, log_=True)

    max_len_trg, max_len_src = 0, 0
    if not os.path.exists(args.checkpoint_path):
        os.makedirs(args.checkpoint_path)
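    # Targets: pickled golden ASTs; sources: assembly features and graph structure from numpy archives.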
    with open(args.golden_c_path,'rb') as file_c:
        trg = pickle.load(file_c)

    src_g = np.load(args.input_g_path, allow_pickle=True)
    src_f = np.load(args.input_f_path, allow_pickle=True)

    graphs_asm = load_graphs(args, src_f, src_g) 

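    # Fix the RNG seed and make cuDNN deterministic so runs are reproducible.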
    SEED=1234
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

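    # torchtext fields: SRC is tokenized text (it gets <sos>/<eos> and a vocabulary); the rest are
    # RawFields carried through unchanged (target trees, ids, trace dicts, assembly graphs, lengths).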
    exp_list = []
    SRC        = Field(init_token = '<sos>', eos_token = '<eos>')
    TRG        = RawField()
    ID         = RawField()
    DICT_INFO  = RawField()
    GRAPHS_ASM = RawField()
    NODE_NUM   = RawField()
    # args.gen_num = 500

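    # Build one Example per sample, skipping samples whose cached trace files are empty or unreadable.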
    for i in range(args.gen_num):
        src_elem = src_f[i]
        broken_file_flag = 0
        dict_info = {}  # reset per sample; defining this only under dump_trace raised a NameError below when tracing was off
        # edge_elem = src_g[i]
        if args.dump_trace:
            for path in glob.glob(os.path.join(args.cache_path, str(i) + '/*')):
                if os.path.getsize(path) > 0:
                    with open(path, 'rb') as f:
                        dict_info[path] = pickle.load(f)
                else:
                    print("broken file!" + path)
                    broken_file_flag = 1
                    break

        if broken_file_flag == 1:
            continue

        if dict_info == {}:
            print('no trace data for sample %d, skipping' % i)
            continue
        trg_elem = trg[i]['tree']
        len_elem_src = graphs_asm[i].number_of_nodes()
        exp = Example.fromlist(
            [src_elem, trg_elem, i, dict_info, graphs_asm[i], len_elem_src],
            fields=[('src', SRC), ('trg', TRG), ('id', ID), ('dict_info', DICT_INFO),
                    ('graphs_asm', GRAPHS_ASM), ('src_len', NODE_NUM)])
        exp_list.append(exp)
        len_elem_trg = trg[i]['treelen']

        if len_elem_src >= max_len_src:
            max_len_src = len_elem_src + 2  # +2 leaves room for the <sos>/<eos> tokens
        if len_elem_trg >= max_len_trg:
            max_len_trg = len_elem_trg + 2

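    # Assemble the torchtext Dataset, split it, and build the source vocabulary from the training split.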
    data_sets = Dataset(exp_list, fields=[('src', SRC), ('trg', TRG), ('id', ID), ('dict_info', DICT_INFO),
                                          ('graphs_asm', GRAPHS_ASM), ('src_len', NODE_NUM)])
    # torchtext reads a 3-element ratio list as (train, test, valid) and returns (train, valid, test),
    # so this yields an 80/15/5 train/valid/test split.
    trn, vld, tst = data_sets.split([0.8, 0.05, 0.15])
    SRC.build_vocab(trn, min_freq=2)

    logging("Number of training examples: %d" % (len(trn.examples)))
    logging("Number of validation examples: %d" % (len(vld.examples)))
    logging("Number of testing examples: %d" % (len(tst.examples)))
    logging("Unique tokens in source assembly vocabulary: %d "%(len(SRC.vocab)))
    logging("Max input length : %d" % (max_len_src))
    logging("Max output length : %d" % (max_len_trg))
    print(args.device)

    num_workers = 0

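    # One DataLoader per split; text_data_collator turns lists of torchtext Examples into batches.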
    collate = text_data_collator(trn)
    train_iterator = DataLoader(trn, batch_size=args.bsz, collate_fn=collate, num_workers=num_workers, shuffle=False)
    collate = text_data_collator(vld)
    valid_iterator = DataLoader(vld, batch_size=args.bsz, collate_fn=collate, num_workers=num_workers, shuffle=False)
    collate = text_data_collator(tst)
    test_iterator = DataLoader(tst, batch_size=args.bsz, collate_fn=collate, num_workers=num_workers, shuffle=False)

    best_valid_loss = float('inf')
    INPUT_DIM = len(SRC.vocab)

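    # Two graph encoders: one over the assembly graphs (annotations sized by the source vocabulary),
    # one over AST nodes (no annotation embedding).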
    gnn_asm = Graph_NN( annotation_size = len(SRC.vocab),
                        out_feats = args.hid_dim,
                        n_steps = args.n_gnn_layers,
                        device = args.device
                        )

    gnn_ast = Graph_NN( annotation_size = None,
                        out_feats = args.hid_dim,
                        n_steps = args.n_gnn_layers,
                        device = args.device)

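    # Transformer encoder over assembly tokens plus an AST decoder, both sized to the max lengths measured above.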
    enc = Encoder(INPUT_DIM,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  args.device,
                  args.mem_dim,
                  embedding_flag=args.embedding_flag,
                  max_length = max_len_src)

    dec = Decoder_AST(
                  args.output_dim,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  args.device,
                  max_length = max_len_trg)

    SRC_PAD_IDX = 0
    TRG_PAD_IDX = 0

    if args.parallel_gpu:
        enc = torch.nn.DataParallel(enc)
        dec = torch.nn.DataParallel(dec)
    
    model = Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, args.device,
                        gnn=gnn_ast, gnn_asm=gnn_asm).to(args.device)
                
    model.apply(initialize_weights)
    criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

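    # Training loop: run train/valid passes each epoch and checkpoint whenever validation loss improves.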
    if args.training and not args.eval:
        for epoch in range(args.n_epoch):
            start_time = time.time()
            train_loss, train_acc = train_eval_tree(args, model, train_iterator, optimizer,
                                                    args.device, criterion, max_len_trg, train_flag=True)

            valid_loss, valid_acc = train_eval_tree(args, model, valid_iterator, None,
                                                    args.device, criterion, max_len_trg, train_flag=False)
            end_time = time.time()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if valid_loss < best_valid_loss and (args.checkpoint_path is not None):
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'model.pt'))

            logging('Epoch: %d | Time: %dm %ds | learning rate: %.3fe-4' % (epoch, epoch_mins, epoch_secs, optimizer._rate * 10000))
            print_performances('Training', train_loss, train_acc, start_time, logging=logging)
            print_performances('Validation', valid_loss, valid_acc, start_time, logging=logging)
            args.summary.add_scalar('train/acc', train_acc)
            args.summary.add_scalar('valid/acc', valid_acc)

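    # Reload the best checkpoint and report final loss/accuracy on the held-out test split.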
    start_time = time.time()
    model.load_state_dict(torch.load(os.path.join(args.checkpoint_path, 'model.pt'), map_location=args.device))
    test_loss, test_acc = test_tree(args, model, test_iterator, TRG_PAD_IDX, args.device, args.label_smoothing, criterion, args.clip)
    print_performances('Test', test_loss, test_acc, start_time, logging=logging)
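
For reference, NoamOpt above is constructed as NoamOpt(model_size, factor, warmup, inner_optimizer),
the call signature of the learning-rate wrapper popularised by the "Annotated Transformer". Assuming
this repo's NoamOpt follows that implementation, the schedule it applies is sketched below;
noam_rate is a hypothetical helper written for illustration, not part of the repo:

def noam_rate(step, model_size, factor=1.0, warmup=4000):
    # Sketch of the assumed Noam schedule -- not this repo's actual NoamOpt source.
    # lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5):
    # linear warm-up for `warmup` steps, then inverse-square-root decay.
    step = max(step, 1)  # the formula is undefined at step 0
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# With the arguments used in main(): noam_rate(step, args.hid_dim, args.lr_ratio, args.warmup)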