train.py (86 lines of code) (raw):

"""Train a signature-generation model (``StrScorer``) on binary samples.

The loop runs until interrupted and alternates two phases:

1. Score fresh batches from ``BinaryDataLoader`` with ``update_buffers``
   (which is handed the replay buffer and the detections file) and log the
   resulting stats to TensorBoard.
2. Draw batches from the replay buffer, softmax-sample one chunk per binary
   according to its stored chunk scores, fit the model on those chunks, and
   write the new chunk scores back into the buffer.

Artifacts go to ``<artifactpath>/<model_name>/``, where ``model_name`` is the
start time as a hex Unix timestamp (zero-padded to 8 digits):
``tensorboard/`` logs, ``det_file`` (written by ``update_buffers``; used to
eyeball signatures while the model trains) and ``model.mdl`` (the whole
model, re-saved every ``--save_every`` iterations).
"""
import argparse
import os
import time
from datetime import datetime  # noqa: F401  (currently unused)

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from model import StrScorer
from utils import (
    BinaryDataLoader,
    RpBuffer,
    softmax,
    unpackbits,
    update_buffers,
    write_stats,
)

# Command-line values that mean "off" for boolean options (compared lower-cased).
_FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off"})


def _str2bool(value):
    """Interpret a command-line value as a boolean (case-insensitive).

    Only the strings in ``_FALSE_STRINGS`` mean False; anything else is True.
    Existing invocations such as ``--cuda 1`` or ``--cuda yes`` therefore
    still enable CUDA, while ``--cuda 0`` / ``--cuda false`` now really
    disable it (a raw, non-empty string would always be truthy).
    """
    return value.strip().lower() not in _FALSE_STRINGS


def setup_model_artifacts(artifactpath, model_name):
    """Create ``<artifactpath>/<model_name>/``, including missing parents.

    Directories that already exist are accepted. Any other failure
    (permissions, a regular file in the way, ...) is raised here instead of
    being swallowed, so the run stops immediately rather than failing later
    when the artifact files are opened.
    """
    os.makedirs(os.path.join(artifactpath, model_name), exist_ok=True)


def parse_args(argv=None):
    """Parse and validate command-line options.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour). The options
    added beyond the original set all default to the values that used to be
    hard-coded, so existing command lines behave the same.
    """
    parser = argparse.ArgumentParser(description="Train a signature generation model.")
    parser.add_argument("--uripath", help="path for training sample uris", type=str, required=True)
    parser.add_argument(
        "--cuda",
        help="use cuda? (bare flag, or a value such as 1/0 or true/false)",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument("--artifactpath", help="path for tensorboard/model artifacts", type=str, default="model_artifacts")
    parser.add_argument("--lr", help="learning rate", type=float, default=1e-3)
    parser.add_argument("--kernelsz", help="convolution kernel size", type=int, default=4)
    parser.add_argument("--embed_dim", help="embedding size", type=int, default=16)
    parser.add_argument("--seqlen", help="binary chunk size", type=int, default=16384)
    parser.add_argument("--topk_count", help="number of max's to backprop through", type=int, default=10)
    parser.add_argument("--architecture", help="model architecture", type=str, default="32,64,128,192,256,512")
    parser.add_argument("--batch_size", help="binaries per data-loader batch and per replay batch", type=int, default=32)
    parser.add_argument("--num_workers", help="data-loader worker processes", type=int, default=8)
    parser.add_argument("--warmup_batches", help="data-loader batches used to pre-fill the replay buffer", type=int, default=101)
    parser.add_argument("--replay_steps", help="replay-buffer fit steps after each data-loader batch", type=int, default=40)
    parser.add_argument("--save_every", help="save the model every N iterations", type=int, default=1000)
    parser.add_argument("--rp_timeout", help="timeout passed to RpBuffer", type=int, default=3600)
    args = parser.parse_args(argv)

    # Reject values that would crash later (e.g. modulo by zero) or loop pointlessly.
    for name, minimum in (
        ("batch_size", 1),
        ("num_workers", 0),
        ("warmup_batches", 1),
        ("replay_steps", 0),
        ("save_every", 1),
    ):
        if getattr(args, name) < minimum:
            parser.error(f"--{name} must be >= {minimum}")
    return args


def _sample_replay_batch(samples, seqlen):
    """Softmax-sample one chunk from each replay-buffer entry.

    Entries are read positionally exactly as before: ``entry[1]`` is the sha,
    ``entry[2]`` the indexable chunk data, ``entry[3]`` the label and
    ``entry[4]`` the per-chunk scores (layout defined by ``RpBuffer`` in
    ``utils``; not visible here). Chunks whose score is high relative to the
    binary's other chunks are more likely to be selected.

    Returns ``(batch_x, batch_y, block_sel, sha_samples)``: the unpacked chunk
    batch, a ``(len(samples), 1)`` label tensor, and the selected chunk index
    and sha per entry, used later to write new scores back to the buffer.
    """
    # Sized from the actual sample count: a hard-coded size would leave
    # trailing zero labels if the buffer returned fewer entries.
    batch_y = torch.zeros(len(samples), 1)
    chunks = []
    block_sel = []
    sha_samples = []
    for i, entry in enumerate(samples):
        sha_sample, data, y, yps = entry[1], entry[2], entry[3], entry[4]
        # Pick a chunk with probability proportional to softmax(score).
        block_idx = np.random.choice(len(yps), p=softmax(yps))
        chunks.append(data[block_idx])
        batch_y[i] = y
        block_sel.append(block_idx)
        sha_samples.append(sha_sample)
    batch_x = torch.from_numpy(unpackbits(chunks, seqlen))
    return batch_x, batch_y, block_sel, sha_samples


def _fill_replay_buffer(dl, model, rpbuffer, det_file, seqlen, nlen, max_batches):
    """Score up to ``max_batches`` loader batches purely to populate ``rpbuffer``.

    Returns the number of batches actually processed (0 if ``dl`` is empty).
    """
    count = 0
    for sample_bytes, shas, labels in dl:
        print("warmup", count)
        update_buffers(model, sample_bytes, shas, labels, seqlen, nlen, rpbuffer, det_file)
        count += 1
        if count >= max_batches:
            break
    return count


def _replay_step(model, rpbuffer, batch_size, seqlen):
    """Run one fit step on a batch drawn from the replay buffer.

    Returns False, without fitting, when the buffer yields no samples.
    Otherwise it fits the model and returns True.
    """
    samples = list(rpbuffer.get_samples(batch_size))
    if not samples:
        return False
    batch_x, batch_y, block_sel, sha_samples = _sample_replay_batch(samples, seqlen)
    yps, _loss = model.fit(batch_x, batch_y)
    # Store each chosen chunk's fresh score so later softmax sampling sees it.
    # NOTE(review): if fit() returns scores still attached to the autograd
    # graph, storing them keeps those graphs alive; consider .detach(). Confirm.
    for sha, block_idx, score in zip(sha_samples, block_sel, yps.flatten()):
        rpbuffer.update_score(sha, block_idx, score)
    return True


def main(argv=None):
    """Parse arguments, set up artifacts, and train until interrupted."""
    args = parse_args(argv)
    architecture = [int(h) for h in args.architecture.split(",")]
    # Presumably the receptive field (in bytes) of the conv stack: each of the
    # len(architecture) layers adds kernelsz - 1 (assumes stride 1 and no
    # dilation; confirm against StrScorer).
    nlen = 1 + (args.kernelsz - 1) * len(architecture)

    # Name the run after its start time and create its artifacts directory.
    model_name = f"{int(time.time()):08x}"
    setup_model_artifacts(args.artifactpath, model_name)
    run_dir = f"{args.artifactpath}/{model_name}"

    writer = SummaryWriter(log_dir=f"{run_dir}/tensorboard/")
    # Line-buffered so the detections file can be inspected while training runs;
    # both handles are closed even if training is interrupted.
    with writer, open(f"{run_dir}/det_file", "w", buffering=1) as det_file:
        model = StrScorer(
            architecture=architecture,
            lr=args.lr,
            n=args.kernelsz,
            embed_dim=args.embed_dim,
            topk_count=args.topk_count,
            cuda=args.cuda,
        )
        rpbuffer = RpBuffer(timeout=args.rp_timeout)

        niter = 0
        warmup = 0
        while True:
            # Rebuilt for every pass over the data set.
            dl = torch.utils.data.DataLoader(
                BinaryDataLoader(args.uripath, balance=True),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers,
            )

            # One-time pre-fill of the replay buffer (defaults: 101 batches x 32,
            # i.e. ~3200 binaries). Retried on the next pass only if the loader
            # produced nothing.
            if warmup == 0:
                warmup = _fill_replay_buffer(
                    dl, model, rpbuffer, det_file, args.seqlen, nlen, args.warmup_batches
                )

            for sample_bytes, shas, labels in dl:
                # Cull the replay buffer (RpBuffer.cull_blocks; presumably drops
                # entries past the buffer timeout).
                rpbuffer.cull_blocks()
                # Get the most malicious-looking chunks from the new samples.
                full_yps, full_labels = update_buffers(
                    model, sample_bytes, shas, labels, args.seqlen, nlen, rpbuffer, det_file
                )
                write_stats(writer, niter, full_yps, full_labels)
                niter += 1

                if niter % args.save_every == 0:
                    torch.save(model, f"{run_dir}/model.mdl")

                for _ in range(args.replay_steps):
                    if not _replay_step(model, rpbuffer, args.batch_size, args.seqlen):
                        break


if __name__ == "__main__":
    main()