in baseline_model/run_vulnerability_detection.py [0:0]
# NOTE: the excerpt starts at main(), so the imports below are reconstructed
# from the identifiers used in this file; the project-local module paths
# (config, model, utils) are assumptions about this repo's layout.
import os
import pickle

import dgl
import torch
from torchtext.data import RawField, Example, Dataset, BucketIterator
# (on torchtext 0.9-0.11 these live under torchtext.legacy.data)

import config
from model import (Graph_NN, Encoder, VUL_DETECT_ASM_Model, NoamOpt,
                   initialize_weights)
from utils import (TrainingSummaryWriter, init_tqdm, preprocessing_batch,
                   train_eval, report, report_to_str)

def main():
    title = 'vulnerability-detection'
    argParser = config.get_arg_parser(title)
    args = argParser.parse_args()
    if not os.path.exists(args.cache_path):
        os.makedirs(args.cache_path)

    max_len_src = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

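    # Load the pickled artifacts produced by preprocessing: the assembly token
    # vocabulary, the per-function CFGs (node tokens + edges), and the labels.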
    with open(os.path.join(args.dataset_dir, 'vocab_asm.pkl'), 'rb') as f:
        vocab_asm = pickle.load(f)
    with open(os.path.join(args.dataset_dir, 'dataset_asm.pkl'), 'rb') as f:
        dataset_asm = pickle.load(f)
    with open(os.path.join(args.dataset_dir, 'tgt_asm.pkl'), 'rb') as f:
        tgt_asm = pickle.load(f)

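    # Fix the RNG seed so the dataset split and weight initialization are
    # reproducible across runs.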
    SEED = 1234
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

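    # RawField passes graphs, edge lists, and labels through torchtext
    # batching untouched (no tokenization or numericalization).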
    exp_list = []
    Graph = RawField()
    EDGE = RawField()
    Nodes_num = RawField()
    TRG = RawField()

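    # Wrap every function whose CFG is small enough as a torchtext Example.
    # From the usage below, edges_asm appears to be (src_ids, edge_types,
    # dst_ids); node ids and token annotations are stored on the graph.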
    for i in range(len(dataset_asm)):
        nodes_asm, edges_asm = dataset_asm[i]
        if len(nodes_asm) > args.max_tolerate_len:
            continue
        g = dgl.DGLGraph((edges_asm[0], edges_asm[2]))
        src_len = len(nodes_asm)
        idmap = range(src_len)
        g.ndata['node_id'] = torch.tensor(idmap, dtype=torch.long)
        g.ndata['annotation'] = torch.tensor(nodes_asm, dtype=torch.long)
        g.edata['type'] = torch.tensor(edges_asm[1])
        tgt = tgt_asm[i]
        exp = Example.fromlist(
            [g, edges_asm, tgt, src_len],
            fields=[('graph', Graph), ('edge', EDGE), ('trg', TRG),
                    ('nodes_num', Nodes_num)])
        exp_list.append(exp)

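    # Assemble the examples into a torchtext Dataset and split 90/8/2.
    # NOTE: torchtext's Dataset.split() reads the ratio list as
    # [train, test, val] and returns splits in (train, val, test) order, so
    # with this unpacking `vld` receives the 8% slice and `tst` the 2% slice.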
    data_sets = Dataset(exp_list, fields=[('graph', Graph), ('edge', EDGE),
                                          ('trg', TRG), ('nodes_num', Nodes_num)])
    trn, tst, vld = data_sets.split([0.9, 0.08, 0.02])
    max_len_src = args.max_tolerate_len
    print("Number of training examples: %d" % (len(trn.examples)))
    print("Number of validation examples: %d" % (len(vld.examples)))
    print("Number of testing examples: %d" % (len(tst.examples)))

    args.summary = TrainingSummaryWriter(args.log_dir)
    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (trn, vld, tst),
        batch_size=args.batch_size,
        sort_key=None,
        sort_within_batch=False,
        sort=False,
        device=device)

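    # Model: a gated graph neural network (GGNN) embeds each assembly CFG,
    # and a Transformer encoder runs over the resulting node representations.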
    gnn = Graph_NN(annotation_size=len(vocab_asm),
                   out_feats=args.hid_dim,
                   n_steps=args.n_gnn_layers,
                   device=device,
                   gnn_type='ggnn',
                   tok_embedding=2,
                   residual=False)
    enc = Encoder(None,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  device,
                  mem_dim=args.mem_dim,
                  embedding_flag=args.embedding_flag,
                  max_length=max_len_src)

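    # Two output classes: vulnerable vs. not vulnerable.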
    SRC_PAD_IDX = 0
    model = VUL_DETECT_ASM_Model(gnn, enc, SRC_PAD_IDX, device,
                                 args.hid_dim, 2).to(device)
    model.apply(initialize_weights)

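    # Cross-entropy for one-hot labels, otherwise BCE on sigmoid outputs;
    # Adam is wrapped in the Noam warmup schedule from "Attention Is All You
    # Need" (lr=0 here because NoamOpt sets the rate at each step).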
    criterion = torch.nn.CrossEntropyLoss() if args.one_hot_label else torch.nn.BCELoss()
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup,
                        torch.optim.Adam(model.parameters(), lr=0,
                                         betas=(0.9, 0.98), eps=1e-9))
    criterion.to(device)

    best_val = None
    best_epoch = 0
    print("start training")
    best_valid_loss = float('inf')

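    # Training: one pass over the training iterator per epoch, accumulating
    # predictions, labels, and per-batch losses for the epoch-level report.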
    if args.training and not args.eval:
        for epoch in range(args.epoch_num):
            all_preds = []
            all_labels = []
            all_loss = []
            for i, batch in init_tqdm(enumerate(train_iterator), 'train',
                                      log=args.log_dir):
                batch_graph_tmp = preprocessing_batch(max(batch.nodes_num),
                                                      batch.graph, batch.edge,
                                                      device)
                batch_graph = dgl.batch(batch.graph).to(device)
                labels = torch.tensor(batch.trg)
                loss, preds = train_eval(batch_graph, labels, batch_graph_tmp,
                                         model, device, optimizer, criterion,
                                         train=True)
                all_preds += [preds]
                all_labels += [labels]
                all_loss += [loss]
            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            metrics = report(all_labels, all_preds)
            loss = sum(all_loss).item() / len(all_loss)
            print('==> Epoch {}, Train Loss {:.4f}\t'.format(epoch, loss)
                  + report_to_str(metrics, keys=True))

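            # Validation pass: same loop with gradient updates disabled
            # (train=False).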
            all_preds = []
            all_labels = []
            all_loss = []
            for i, batch in enumerate(valid_iterator):
                batch_graph_tmp = preprocessing_batch(max(batch.nodes_num),
                                                      batch.graph, batch.edge,
                                                      device)
                # Move the batched graph to the model's device, as in the
                # training loop above (it was left on the CPU here).
                batch_graph = dgl.batch(batch.graph).to(device)
                labels = torch.tensor(batch.trg)
                loss, preds = train_eval(batch_graph, labels, batch_graph_tmp,
                                         model, device, optimizer, criterion,
                                         train=False)
                all_preds += [preds]
                all_labels += [labels]
                all_loss += [loss]
            all_preds = torch.cat(all_preds, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            metrics = report(all_labels, all_preds)
            loss = sum(all_loss).item() / len(all_loss)
            print('==> Epoch {}, Valid Loss {:.4f}\t'.format(epoch, loss)
                  + report_to_str(metrics, keys=True))

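            # Keep only the checkpoint with the best validation loss so far.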
            if loss < best_valid_loss and (args.checkpoint_path is not None):
                best_valid_loss = loss
                torch.save(model.state_dict(),
                           os.path.join(args.checkpoint_path,
                                        'model_vul_detection.pt'))
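
# Entry-point guard assumed here: the excerpt ends at the checkpoint save,
# but a run_*.py script conventionally closes with this.
if __name__ == '__main__':
    main()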