in baseline_model/run_similarity_check.py [0:0]
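# Standard-library and PyTorch imports required by main() below.
import os
import pickle

import torch

# The project-local names used here (config, GNNDataset, Graph_NN, Encoder,
# CODE_SIM_ASM_Model, initialize_weights, NoamOpt, CircleLoss,
# TrainingSummaryWriter, iterations, validate, test) are assumed to be
# imported from the repository's own modules; their exact import paths are
# not shown in this excerpt.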
def main():
    # Parse command-line arguments and make sure the cache directory exists.
    title = 'similarity-check'
    arg_parser = config.get_arg_parser(title)
    args = arg_parser.parse_args()
    if not os.path.exists(args.cache_path):
        os.makedirs(args.cache_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Load the assembly-function dataset and the train/valid/test split.
    dataset = GNNDataset(args.dataset_dir, asm=args.asm, max_len=args.max_tolerate_len)
    with open(args.split, 'rb') as f:
        split = pickle.load(f)

    # Fix the random seed for reproducibility.
    SEED = 1234
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    max_len_src = args.max_tolerate_len

    # Graph neural network over the assembly graph.
    gnn = Graph_NN(annotation_size=len(dataset.vocab_asm),
                   out_feats=args.hid_dim,
                   n_steps=args.n_gnn_layers,
                   device=device,
                   tok_embedding=2,
                   residual=False)

    # Transformer encoder over the assembly token sequence.
    enc = Encoder(len(dataset.vocab_asm),
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  device,
                  embedding_flag=args.embedding_flag,
                  max_length=max_len_src,
                  mem_dim=args.mem_dim)

    SRC_PAD_IDX = 0
    model = CODE_SIM_ASM_Model(gnn, enc, args.hid_dim, SRC_PAD_IDX, device).to(device)
    model.apply(initialize_weights)

    # Noam learning-rate schedule wrapped around Adam, with Circle loss
    # for metric learning on function similarity.
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    criterion = CircleLoss(gamma=args.gamma, m=args.margin)

    # P-K sampling: each batch draws P classes with K samples per class.
    train_gen_fun = dataset.get_pk_sample_generator_function(
        split[0], args.p, args.k)
    valid_gen_fun = dataset.get_pk_sample_generator_function(
        split[1], args.p, args.k)
    train_num_iters = args.train_epoch_size
    valid_num_iters = args.valid_epoch_size
    criterion.to(device)

    args.summary = TrainingSummaryWriter(args.log_dir)

    best_val = None
    best_epoch = 0
    print("start training")
    for epoch in range(1, args.epoch_num + 1):
        # One training pass, then validation; checkpoint whenever this epoch
        # achieves the best validation result so far.
        iterations(args, epoch, model, criterion, optimizer,
                   train_gen_fun(), train_num_iters, True, device)
        best_val, best_epoch = validate(args, model, dataset, split[1], criterion,
                                        epoch, best_val, best_epoch, device)
        print(f'Epoch {epoch}')
        if epoch == best_epoch and (args.checkpoint_path is not None):
            output_path = os.path.join(args.checkpoint_path, 'model_sim_check.pt')
            torch.save(model.state_dict(), output_path)

    # Reload the best checkpoint and evaluate on the held-out test split.
    model.load_state_dict(torch.load(os.path.join(args.checkpoint_path, 'model_sim_check.pt')))
    test(args, model, dataset, split[2], device)