in src/baselines/dnn.py [0:0]
def _load_checkpoint_model(args, global_step, log_fp):
    """Load the pickled DNN saved under checkpoint-<global_step>.

    Moves the model to ``args.device`` and switches it to eval mode.
    Returns ``(model, prefix)`` where ``prefix`` is the checkpoint dir name.
    """
    prefix = "checkpoint-{}".format(global_step)
    checkpoint = os.path.join(args.output_dir, prefix)
    logger.info("checkpoint path: %s" % checkpoint)
    log_fp.write("checkpoint path: %s\n" % checkpoint)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints produced by this pipeline, never untrusted files.
    model = torch.load(os.path.join(checkpoint, "dnn.pkl"))
    model.to(args.device)
    model.eval()
    return model, prefix


def run():
    """Driver for the DNN baseline.

    Parses args via ``main()``, builds config/model/datasets, then — depending
    on ``args.do_train`` / ``do_eval`` / ``do_predict`` — trains, snapshots the
    best checkpoint to ``<output_dir>/best``, evaluates on dev, and predicts on
    test. Progress is mirrored to the module ``logger`` and to
    ``<log_dir>/logs.txt``.
    """
    args = main()
    logging.basicConfig(
        format="%(asctime)s-%(levelname)s-%(name)s | %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    # Refuse to clobber a non-empty output dir unless --overwrite_output_dir.
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) \
            and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
    # BUGFIX: the original called os.makedirs(args.output_dir) in the
    # else-branch, which raises FileExistsError whenever the directory already
    # exists (empty dir, or --overwrite_output_dir given).
    os.makedirs(args.output_dir, exist_ok=True)
    # create logger dir
    os.makedirs(args.log_dir, exist_ok=True)
    # create tensorboard logger dir
    os.makedirs(args.tb_log_dir, exist_ok=True)

    config_class = BertConfig
    # BUGFIX: close the config file (the original json.load(open(...)) leaked
    # the handle).
    with open(args.config_path, "r") as cfg_fp:
        config = config_class(**json.load(cfg_fp))
    args2config(args, config)

    # Device / distributed setup.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 1 if not args.no_cuda else 0
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = torch.cuda.device_count()
    args.device = device

    model = DNN(config, args)
    model.to(args.device)

    train_dataset, label_list = load_dataset(args, "train")
    dev_dataset, _ = load_dataset(args, "dev")
    test_dataset, _ = load_dataset(args, "test")
    # num_labels is only known after the label list is loaded.
    config.num_labels = len(label_list)
    args.num_labels = config.num_labels

    log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w")
    # BUGFIX: try/finally guarantees the log file is closed even when any of
    # the train/eval/predict stages raises (the original only closed it on the
    # success path).
    try:
        log_fp.write("Model Config:\n %s\n" % str(config))
        log_fp.write("#" * 50 + "\n")
        # BUGFIX: log-message typo "Mode Architecture" -> "Model Architecture".
        log_fp.write("Model Architecture:\n %s\n" % str(model))
        log_fp.write("#" * 50 + "\n")
        log_fp.write("num labels: %d\n" % args.num_labels)
        log_fp.write("#" * 50 + "\n")

        max_metric_model_info = None
        if args.do_train:
            logger.info("++++++++++++Training+++++++++++++")
            global_step, tr_loss, max_metric_model_info = trainer(
                args, model, train_dataset, dev_dataset, test_dataset, log_fp=log_fp)
            logger.info("global_step = %s, average loss = %s", global_step, tr_loss)

        # BUGFIX: without training there is no best-checkpoint record; the
        # original crashed later with an opaque TypeError on
        # max_metric_model_info["global_step"].
        if (args.do_eval or args.do_predict) and max_metric_model_info is None:
            raise ValueError("--do_eval/--do_predict require --do_train: no best checkpoint info is available.")

        # save the best checkpoint under <output_dir>/best
        if args.do_train:
            logger.info("++++++++++++Save Model+++++++++++++")
            best_output_dir = os.path.join(args.output_dir, "best")
            global_step = max_metric_model_info["global_step"]
            prefix = "checkpoint-{}".format(global_step)
            # BUGFIX: copytree fails if the destination exists (e.g. rerun
            # with --overwrite_output_dir) — drop any stale snapshot first.
            if os.path.exists(best_output_dir):
                shutil.rmtree(best_output_dir)
            shutil.copytree(os.path.join(args.output_dir, prefix), best_output_dir)
            logger.info("Saving model checkpoint to %s", best_output_dir)
            torch.save(args, os.path.join(best_output_dir, "training_args.bin"))
            save_labels(os.path.join(best_output_dir, "label.txt"), label_list)

        # evaluate the best checkpoint on the dev set (main process only)
        if args.do_eval and args.local_rank in [-1, 0]:
            logger.info("++++++++++++Validation+++++++++++++")
            log_fp.write("++++++++++++Validation+++++++++++++\n")
            global_step = max_metric_model_info["global_step"]
            logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
            log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
            model, prefix = _load_checkpoint_model(args, global_step, log_fp)
            result = evaluate(args, model, dev_dataset, prefix=prefix, log_fp=log_fp)
            result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
            logger.info(json.dumps(result, ensure_ascii=False))
            log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")

        # predict with the best checkpoint on the test set (main process only)
        if args.do_predict and args.local_rank in [-1, 0]:
            logger.info("++++++++++++Testing+++++++++++++")
            log_fp.write("++++++++++++Testing+++++++++++++\n")
            global_step = max_metric_model_info["global_step"]
            logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
            log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
            model, prefix = _load_checkpoint_model(args, global_step, log_fp)
            pred, true, result = predict(args, model, test_dataset, prefix=prefix, log_fp=log_fp)
            # NOTE(review): keys keep the original "evaluation_" prefix even
            # for test results, to stay compatible with downstream parsers.
            result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
            logger.info(json.dumps(result, ensure_ascii=False))
            log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")
    finally:
        log_fp.close()