def run()

in src/deep_baselines/run.py [0:0]


def run():
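    """End-to-end entry point: parse the arguments, prepare output/log directories
    and the device, build the selected baseline model, then train it and run
    evaluation and prediction with the best checkpoint."""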
    args = main()
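    # log at INFO only on the main process (local_rank in [-1, 0]); WARN elsewhere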
    logging.basicConfig(
        format="%(asctime)s-%(levelname)s-%(name)s | %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    # refuse to overwrite a non-empty output dir unless --overwrite_output_dir is set; otherwise remove and recreate it
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
    else:
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        os.makedirs(args.output_dir)
    # create logger dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w")
    # create tensorboard logger dir
    if not os.path.exists(args.tb_log_dir):
        os.makedirs(args.tb_log_dir)

    # load the model configuration from JSON; vocab_size is overwritten below from the vocabulary
    config_class = BertConfig
    config = config_class(**json.load(open(args.config_path, "r")))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 1 if not args.no_cuda else 0
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = torch.cuda.device_count()
    args.device = device

    # load the sequence vocabulary and keep the config's vocab_size consistent with it
    int_2_token, token_2_int = load_vocab(args.seq_vocab_path)
    config.vocab_size = len(token_2_int)

    label_list = get_labels(label_filepath=args.label_filepath)
    save_labels(os.path.join(args.log_dir, "label.txt"), label_list)
    args.num_labels = len(label_list)

    args2config(args, config)  # merge run-time arguments into the model config

    # for CHEER models, round the max sequence length down to a multiple of channel_in
    if args.model_type in ["CHEER-CatWCNN", "CHEER-WDCNN", "CHEER-WCNN"]:
        if args.channel_in:
            args.seq_max_length = (args.seq_max_length // args.channel_in) * args.channel_in
        else:
            args.seq_max_length = (args.seq_max_length // config.channel_in) * config.channel_in

    # instantiate the selected baseline model
    if args.model_type == "CHEER-CatWCNN":
        model = CatWCNN(config, args)
    elif args.model_type == "CHEER-WDCNN":
        model = WDCNN(config, args)
    elif args.model_type == "CHEER-WCNN":
        model = WCNN(config, args)
    elif args.model_type == "VirHunter":
        model = VirHunter(config, args)
    elif args.model_type == "Virtifier":
        model = Virtifier(config, args)
    elif args.model_type == "VirSeeker":
        model = VirSeeker(config, args)
    else:
        raise Exception("Unsupported model type: %s" % args.model_type)

    # pick the per-model sequence encoding function and its arguments
    encode_func = None
    encode_func_args = {"max_len": args.seq_max_length, "vocab": token_2_int, "trunc_type": args.trunc_type}
    if args.model_type in ["CHEER-CatWCNN", "CHEER-WDCNN", "CHEER-WCNN"]:
        encode_func = cheer_seq_encode
        encode_func_args["channel_in"] = args.channel_in
    elif args.model_type == "VirHunter":
        encode_func = virhunter_seq_encode
    elif args.model_type == "Virtifier":
        encode_func = virtifier_seq_encode
    elif args.model_type == "VirSeeker":
        encode_func = virseeker_seq_encode
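    # record a JSON-serializable snapshot of the arguments for the log (the device object is excluded)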
    args_dict = {}
    for attr, value in sorted(args.__dict__.items()):
        if attr != "device":
            args_dict[attr] = value

    log_fp.write(json.dumps(args_dict, ensure_ascii=False) + "\n")

    model.to(args.device)

    # load_dataset also returns a label list, which replaces the one read from label_filepath
    train_dataset, label_list = load_dataset(args, "train", encode_func, encode_func_args)

    dev_dataset, _ = load_dataset(args, "dev", encode_func, encode_func_args)

    test_dataset, _ = load_dataset(args, "test", encode_func, encode_func_args)

    log_fp.write("Model Config:\n %s\n" % str(config))
    log_fp.write("#" * 50 + "\n")
    log_fp.write("Model Architecture:\n %s\n" % str(model))
    log_fp.write("#" * 50 + "\n")
    log_fp.write("num labels: %d\n" % args.num_labels)
    log_fp.write("#" * 50 + "\n")

    max_metric_model_info = None
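    # training: trainer() returns the last global step, the average loss, and info about the best checkpoint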
    if args.do_train:
        logger.info("++++++++++++Training+++++++++++++")
        global_step, tr_loss, max_metric_model_info = trainer(args, model, train_dataset, dev_dataset, test_dataset, log_fp=log_fp)
        logger.info("global_step = %s, average loss = %s", global_step, tr_loss)

    # save the best checkpoint
    if args.do_train:
        logger.info("++++++++++++Save Model+++++++++++++")
        # copy the checkpoint that achieved the maximum max_metric_type to <output_dir>/best

        best_output_dir = os.path.join(args.output_dir, "best")

        global_step = max_metric_model_info["global_step"]
        prefix = "checkpoint-{}".format(global_step)
        shutil.copytree(os.path.join(args.output_dir, prefix), best_output_dir)
        logger.info("Saving model checkpoint to %s", best_output_dir)
        torch.save(args, os.path.join(best_output_dir, "training_args.bin"))
        save_labels(os.path.join(best_output_dir, "label.txt"), label_list)

    # evaluate
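    # note: this re-loads the best checkpoint found during training, so it assumes do_train was set in this run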
    if args.do_eval and args.local_rank in [-1, 0]:
        logger.info("++++++++++++Validation+++++++++++++")
        log_fp.write("++++++++++++Validation+++++++++++++\n")
        global_step = max_metric_model_info["global_step"]
        logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
        log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
        prefix = "checkpoint-{}".format(global_step)
        checkpoint = os.path.join(args.output_dir, prefix)
        logger.info("checkpoint path: %s" % checkpoint)
        log_fp.write("checkpoint path: %s\n" % checkpoint)
        model = torch.load(os.path.join(checkpoint, "%s.pkl" % args.model_type))
        model.to(args.device)
        model.eval()

        result = evaluate(args, model, dev_dataset, prefix=prefix, log_fp=log_fp)
        result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
        logger.info(json.dumps(result, ensure_ascii=False))
        log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")

    # Testing
    if args.do_predict and args.local_rank in [-1, 0]:
        logger.info("++++++++++++Testing+++++++++++++")
        log_fp.write("++++++++++++Testing+++++++++++++\n")
        global_step = max_metric_model_info["global_step"]
        logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
        log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
        prefix = "checkpoint-{}".format(global_step)
        checkpoint = os.path.join(args.output_dir, prefix)
        logger.info("checkpoint path: %s" % checkpoint)
        log_fp.write("checkpoint path: %s\n" % checkpoint)
        model = torch.load(os.path.join(checkpoint, "%s.pkl" % args.model_type))
        model.to(args.device)
        model.eval()
        pred, true, result = predict(args, model, test_dataset, prefix=prefix, log_fp=log_fp)
        result = dict(("testing_" + k + "_{}".format(global_step), v) for k, v in result.items())
        logger.info(json.dumps(result, ensure_ascii=False))
        log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")
    log_fp.close()
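
For reference, a minimal sketch of reloading the saved best checkpoint for inference. It relies only on the layout written above (<output_dir>/best containing the pickled model); the output_dir and model_type values are hypothetical placeholders, and the original model classes must be importable for unpickling.

import os
import torch

output_dir = "output"        # hypothetical value of args.output_dir
model_type = "VirHunter"     # hypothetical value of args.model_type; one of the supported types
best_dir = os.path.join(output_dir, "best")

# run() copies the best checkpoint-{global_step} directory to <output_dir>/best;
# each checkpoint stores the whole pickled model as "<model_type>.pkl"
model = torch.load(os.path.join(best_dir, "%s.pkl" % model_type), map_location="cpu")
model.eval()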