in src/deep_baselines/run.py [0:0]
def run():
    args = main()
    logging.basicConfig(format="%(asctime)s-%(levelname)s-%(name)s | %(message)s",
                        datefmt="%Y/%m/%d %H:%M:%S",
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # overwrite the output dir
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
    else:
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        os.makedirs(args.output_dir)
    # create logger dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w")
    # create tensorboard logger dir
    if not os.path.exists(args.tb_log_dir):
        os.makedirs(args.tb_log_dir)
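    # load the model hyper-parameters from the JSON config file into a BertConfig object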
    config_class = BertConfig
    config = config_class(**json.load(open(args.config_path, "r")))
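    # device setup: single GPU/CPU when local_rank == -1, otherwise NCCL distributed training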
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 1 if torch.cuda.is_available() and not args.no_cuda else 0
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = torch.cuda.device_count()
    args.device = device
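    # the vocabulary size and label count are propagated into the model config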
    int_2_token, token_2_int = load_vocab(args.seq_vocab_path)
    config.vocab_size = len(token_2_int)
    label_list = get_labels(label_filepath=args.label_filepath)
    save_labels(os.path.join(args.log_dir, "label.txt"), label_list)
    args.num_labels = len(label_list)
    args2config(args, config)
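    # for the CHEER variants, the max sequence length is truncated to a multiple of the channel count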
    if args.model_type in ["CHEER-CatWCNN", "CHEER-WDCNN", "CHEER-WCNN"]:
        if args.channel_in:
            args.seq_max_length = (args.seq_max_length // args.channel_in) * args.channel_in
        else:
            args.seq_max_length = (args.seq_max_length // config.channel_in) * config.channel_in
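    # build the selected baseline model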
    if args.model_type == "CHEER-CatWCNN":
        model = CatWCNN(config, args)
    elif args.model_type == "CHEER-WDCNN":
        model = WDCNN(config, args)
    elif args.model_type == "CHEER-WCNN":
        model = WCNN(config, args)
    elif args.model_type == "VirHunter":
        model = VirHunter(config, args)
    elif args.model_type == "Virtifier":
        model = Virtifier(config, args)
    elif args.model_type == "VirSeeker":
        model = VirSeeker(config, args)
    else:
        raise Exception("Unsupported model type: %s" % args.model_type)
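    # pick the sequence encoding function that matches the model type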
    encode_func = None
    encode_func_args = {"max_len": args.seq_max_length, "vocab": token_2_int, "trunc_type": args.trunc_type}
    if args.model_type in ["CHEER-CatWCNN", "CHEER-WDCNN", "CHEER-WCNN"]:
        encode_func = cheer_seq_encode
        encode_func_args["channel_in"] = args.channel_in
    elif args.model_type == "VirHunter":
        encode_func = virhunter_seq_encode
    elif args.model_type == "Virtifier":
        encode_func = virtifier_seq_encode
    elif args.model_type == "VirSeeker":
        encode_func = virseeker_seq_encode
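    # dump the run arguments to the log file (the device object is not JSON-serializable, so it is skipped)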
    args_dict = {}
    for attr, value in sorted(args.__dict__.items()):
        if attr != "device":
            args_dict[attr] = value
    log_fp.write(json.dumps(args_dict, ensure_ascii=False) + "\n")
    model.to(args.device)
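    # load the train/dev/test datasets with the chosen encoder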
    train_dataset, label_list = load_dataset(args, "train", encode_func, encode_func_args)
    dev_dataset, _ = load_dataset(args, "dev", encode_func, encode_func_args)
    test_dataset, _ = load_dataset(args, "test", encode_func, encode_func_args)
    log_fp.write("Model Config:\n %s\n" % str(config))
    log_fp.write("#" * 50 + "\n")
    log_fp.write("Model Architecture:\n %s\n" % str(model))
    log_fp.write("#" * 50 + "\n")
    log_fp.write("num labels: %d\n" % args.num_labels)
    log_fp.write("#" * 50 + "\n")
    max_metric_model_info = None
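    # training: trainer() returns the global step, the average loss, and the info of the checkpoint with the maximum value of max_metric_type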
    if args.do_train:
        logger.info("++++++++++++Training+++++++++++++")
        global_step, tr_loss, max_metric_model_info = trainer(args, model, train_dataset, dev_dataset, test_dataset, log_fp=log_fp)
        logger.info("global_step = %s, average loss = %s", global_step, tr_loss)
    # save
    if args.do_train:
        logger.info("++++++++++++Save Model+++++++++++++")
        # copy the best checkpoint into <output_dir>/best
        best_output_dir = os.path.join(args.output_dir, "best")
        global_step = max_metric_model_info["global_step"]
        prefix = "checkpoint-{}".format(global_step)
        shutil.copytree(os.path.join(args.output_dir, prefix), best_output_dir)
        logger.info("Saving model checkpoint to %s", best_output_dir)
        torch.save(args, os.path.join(best_output_dir, "training_args.bin"))
        save_labels(os.path.join(best_output_dir, "label.txt"), label_list)
    # evaluate
    if args.do_eval and args.local_rank in [-1, 0]:
        logger.info("++++++++++++Validation+++++++++++++")
        log_fp.write("++++++++++++Validation+++++++++++++\n")
        global_step = max_metric_model_info["global_step"]
        logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
        log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
        prefix = "checkpoint-{}".format(global_step)
        checkpoint = os.path.join(args.output_dir, prefix)
        logger.info("checkpoint path: %s" % checkpoint)
        log_fp.write("checkpoint path: %s\n" % checkpoint)
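        # reload the best checkpoint; the whole model object is pickled as <model_type>.pkl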
        model = torch.load(os.path.join(checkpoint, "%s.pkl" % args.model_type))
        model.to(args.device)
        model.eval()
        result = evaluate(args, model, dev_dataset, prefix=prefix, log_fp=log_fp)
        result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
        logger.info(json.dumps(result, ensure_ascii=False))
        log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")
    # Testing
    if args.do_predict and args.local_rank in [-1, 0]:
        logger.info("++++++++++++Testing+++++++++++++")
        log_fp.write("++++++++++++Testing+++++++++++++\n")
        global_step = max_metric_model_info["global_step"]
        logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
        log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
        prefix = "checkpoint-{}".format(global_step)
        checkpoint = os.path.join(args.output_dir, prefix)
        logger.info("checkpoint path: %s" % checkpoint)
        log_fp.write("checkpoint path: %s\n" % checkpoint)
        model = torch.load(os.path.join(checkpoint, "%s.pkl" % args.model_type))
        model.to(args.device)
        model.eval()
        pred, true, result = predict(args, model, test_dataset, prefix=prefix, log_fp=log_fp)
        result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
        logger.info(json.dumps(result, ensure_ascii=False))
        log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")
    log_fp.close()
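
# A minimal entry-point sketch (an assumption for illustration; the original file may
# define its own __main__ guard outside this excerpt):
if __name__ == "__main__":
    run()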