def main()

in src/run.py [0:0]


def main():
    parser = argparse.ArgumentParser("Model Building for LucaProt")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="the dataset dirpath."
    )
    parser.add_argument(
        "--separate_file",
        action="store_true",
        help="load the entire dataset using memory, only the names of the pdb and embedding files are listed in the train/dev/test.csv")
    parser.add_argument(
        "--filename_pattern",
        default=None,
        type=str,
        help="the dataset filename pattern,such as {}_with_pdb_emb.csv including train_with_pdb_emb.csv, dev_with_pdb_emb.csv, test_with_pdb_emb.csv in ${data_dir}")
    parser.add_argument(
        "--tfrecords",
        action="store_true",
        help="whether the dataset is in the tfrecords. When true, only the specified number of samples(${shuffle_queue_size}) are loaded into memory at once, the tfrecords must  be in ${data_dir}/tfrecords/train/xxx.tfrecords, ${data_dir}/tfrecords/dev/xxx.tfrecords and ${data_dir}/tfrecords/test/xxx.tfrecords. xxx.tfrecords may be 01-of-01.tfrecords(only including sequence)、01-of-01_emb.records(including sequence and structural embedding)、01-of-01_pdb_emb.records(including sequence, 3d-structure contact map, and structural embedding)")
    parser.add_argument(
        "--shuffle_queue_size",
        default=5000,
        type=int,
        help="how many samples are loaded into memory at once"
    )
    parser.add_argument(
        "--multi_tfrecords",
        action="store_true",
        help="whether exists multi-tfrecords"
    )
    parser.add_argument(
        "--dataset_name",
        default="rdrp_40_extend",
        type=str,
        required=True,
        help="dataset name"
    )
    parser.add_argument(
        "--dataset_type",
        default="protein",
        type=str,
        required=True,
        choices=["protein", "dna", "rna"],
        help="dataset type"
    )
    parser.add_argument(
        "--task_type",
        default="binary_class",
        type=str,
        required=True,
        choices=["multi_label", "multi_class", "binary_class"],
        help="task type"
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        choices=["sequence", "structure", "embedding", "sefn", "ssfn"],
        help="model type selected in the list: ['sequence-based', 'structure-based', 'structural embedding based', 'sequence and structural embedding based', 'sequence and structure based']"
    )
    parser.add_argument(
        "--subword",
        action="store_true",
        help="whether use subword-level for sequence"
    )
    parser.add_argument(
        "--codes_file",
        type=str,
        default="../subword/rdrp/protein_codes_rdrp_20000.txt",
        help="subword codes filepath"
    )

    parser.add_argument(
        "--input_mode",
        type=str,
        default="single",
        choices=["single", "concat", "independent"],
        help="the input operation"
    )
    parser.add_argument(
        "--label_type",
        default="rdrp",
        type=str,
        required=True,
        help="label type"
    )
    parser.add_argument(
        "--label_filepath",
        default=None,
        type=str,
        required=True,
        help="the label list filepath"
    )

    # for structure
    parser.add_argument(
        "--cmap_type",
        default=None,
        type=str,
        choices=["C_alpha", "C_bert"],
        help="the calculation type of 3d-structure contact map"
    )
    parser.add_argument(
        "--cmap_thresh",
        default=10.0,
        type=float,
        help="contact map threshold."
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="the output dirpath"
    )

    parser.add_argument(
        "--log_dir",
        default="./logs/",
        type=str,
        required=True,
        help="log dir."
    )
    parser.add_argument(
        "--tb_log_dir",
        default="./tb-logs/",
        type=str,
        required=True,
        help="tensorboard log dir."
    )

    # Other parameters
    parser.add_argument(
        "--config_path",
        default=None,
        type=str,
        required=True,
        help="the config filepath of the running model"
    )
    parser.add_argument(
        "--seq_vocab_path",
        default=None,
        type=str,
        help="sequence token vocab filepath"
    )
    parser.add_argument(
        "--struct_vocab_path",
        default=None,
        type=str,
        help="structure node token vocab filepath"
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="cache dirpath"
    )

    # sequence pooling_type
    parser.add_argument(
        "--seq_pooling_type",
        type=str,
        default=None,
        choices=["none", "sum", "max", "avg", "attention", "context_attention", "weighted_attention", "value_attention", "transformer"],
        help="pooling type for sequence encoder"
    )
    # structure pooling_type
    parser.add_argument(
        "--struct_pooling_type",
        type=str,
        default=None,
        choices=["sum", "max", "avg", "attention", "context_attention", "weighted_attention", "value_attention", "transformer"],
        help="pooling type for structure encoder"
    )
    # embedding pooling_type
    parser.add_argument(
        "--embedding_pooling_type",
        type=str,
        default=None,
        choices=["none", "sum", "max", "avg", "attention", "context_attention", "weighted_attention", "value_attention", "transformer"],
        help="pooling type for embedding encoder"
    )
    # activate function
    parser.add_argument(
        "--activate_func",
        type=str,
        default=None,
        choices=["tanh", "relu", "leakyrelu", "gelu"],
        help="activate function type after pooling"
    )

    parser.add_argument(
        "--do_train",
        action="store_true",
        help="whether to run training."
    )
    parser.add_argument(
        "--do_eval",
        action="store_true",
        help="whether to run eval on the dev set."
    )
    parser.add_argument(
        "--do_predict",
        action="store_true",
        help="whether to run predict on the test set."
    )
    parser.add_argument(
        "--evaluate_during_training",
        action="store_true",
        help="evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="set this flag if you are using an uncased model."
    )

    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=16,
        type=int,
        help="Batch size per GPU/CPU for training."
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=16,
        type=int,
        help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--learning_rate",
        default=1e-4,
        type=float,
        help="The initial learning rate for Adam."
    )
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        help="Weight decay if we apply some."
    )
    parser.add_argument(
        "--adam_epsilon",
        default=1e-8,
        type=float,
        help="Epsilon for Adam optimizer."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--num_train_epochs",
        default=50,
        type=int,
        help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument(
        "--warmup_steps",
        default=0,
        type=int,
        help="Linear warmup over warmup_steps."
    )

    parser.add_argument(
        "--logging_steps",
        type=int,
        default=1000,
        help="Log every X updates steps."
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=1000,
        help="Save checkpoint every X updates steps.")

    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument(
        "--no_cuda",
        action="store_true",
        help="Avoid using CUDA when available"
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="random seed for initialization"
    )

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
             "See details at https://nvidia.github.io/apex/amp.html"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank"
    )

    # multi-label/ binary-class
    parser.add_argument(
        "--sigmoid",
        action="store_true",
        help="classifier add sigmoid if task_type is binary-class or multi-label"
    )

    # loss func
    parser.add_argument(
        "--loss_type",
        type=str,
        default="bce",
        choices=["focal_loss", "bce", "multilabel_cce", "asl", "cce"],
        help="loss type"
    )

    # which metric for model finalization selected
    parser.add_argument(
        "--max_metric_type",
        type=str,
        default="f1",
        required=True,
        choices=["acc", "jaccard", "prec", "recall", "f1", "fmax", "roc_auc", "pr_auc"],
        help="which metric for model selected"
    )

    # for BCE Loss
    parser.add_argument(
        "--pos_weight",
        type=float,
        default=40,
        help="positive weight for bce"
    )

    # for CE Loss
    parser.add_argument(
        "--weight",
        type=str,
        default=None,
        help="positive weight for multi-class"
    )

    # for focal Loss
    parser.add_argument(
        "--focal_loss_alpha",
        type=float,
        default=0.7,
        help="focal loss alpha"
    )
    parser.add_argument(
        "--focal_loss_gamma",
        type=float,
        default=2.0,
        help="focal loss gamma"
    )
    parser.add_argument(
        "--focal_loss_reduce",
        action="store_true",
        help="mean for one sample(default sum)"
    )

    # for asymmetric Loss
    parser.add_argument(
        "--asl_gamma_neg",
        type=float,
        default=4.0,
        help="negative gamma for asl"
    )
    parser.add_argument(
        "--asl_gamma_pos",
        type=float,
        default=1.0,
        help="positive gamma for asl"
    )

    # for sequence and structure graph node size(contact map shape)
    parser.add_argument(
        "--seq_max_length",
        default=2048,
        type=int,
        help="the length of input sequence more than max length will be truncated, shorter will be padded."
    )
    parser.add_argument(
        "--struct_max_length",
        default=2048,
        type=int,
        help="the length of input contact map more than max length will be truncated, shorter will be padded."
    )
    parser.add_argument(
        "--trunc_type",
        default="right",
        type=str,
        required=True,
        choices=["left", "right"],
        help="truncate type for whole input"
    )
    parser.add_argument(
        "--no_position_embeddings",
        action="store_true",
        help="Whether not to use position_embeddings"
    )
    parser.add_argument(
        "--no_token_type_embeddings",
        action="store_true",
        help="Whether not to use token_type_embeddings"
    )

    # for embedding input
    parser.add_argument(
        "--embedding_input_size",
        default=2560,
        type=int,
        help="the length of input embedding dim."
    )
    parser.add_argument(
        "--embedding_type",
        type=str,
        default="matrix",
        choices=[None, "contacts", "bos", "matrix"],
        help="the type of the structural embedding info"
    )
    parser.add_argument(
        "--embedding_max_length",
        default=2048,
        type=int,
        help="the length of input embedding more than max length will be truncated, shorter will be padded."
    )

    parser.add_argument(
        "--model_dirpath",
        default=None,
        type=str,
        help="load the trained model to continue training."
    )
    parser.add_argument(
        "--save_all",
        action="store_true",
        help="save all checkpoints during training"
    )
    parser.add_argument(
        "--delete_old",
        action="store_true",
        help="delete old checkpoint by the specific metric"
    )
    args = parser.parse_args()

    if args.model_type == "sequence":
        output_input_col_names = [args.dataset_type, "seq"]
        args.has_seq_encoder = True
        args.has_struct_encoder = False
        args.has_embedding_encoder = False
        args.cmap_type = None
        args.embedding_type = None
    elif args.model_type == "structure":
        output_input_col_names = [args.dataset_type, "structure"]
        args.has_seq_encoder = False
        args.has_struct_encoder = True
        args.has_embedding_encoder = False
        args.embedding_type = None
    elif args.model_type == "embedding":
        output_input_col_names = [args.dataset_type, "embedding"]
        args.has_seq_encoder = False
        args.has_struct_encoder = False
        args.has_embedding_encoder = True
        args.cmap_type = None
    elif args.model_type == "sefn":
        output_input_col_names = [args.dataset_type, "seq", "embedding"]
        args.has_seq_encoder = True
        args.has_struct_encoder = False
        args.has_embedding_encoder = True
        args.cmap_type = None
    elif args.model_type == "ssfn":
        output_input_col_names = [args.dataset_type, "seq", "structure"]
        args.has_seq_encoder = True
        args.has_struct_encoder = True
        args.has_embedding_encoder = False
        args.embedding_type = None
    else:
        raise Exception("Not support this model_type=%s" % args.model_type)

    # overwrite the output dir
    if os.path.exists(args.output_dir) \
            and os.listdir(args.output_dir) \
            and args.do_train \
            and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)
        )
    else:
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        os.makedirs(args.output_dir)
    # create the logs dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w")
    # create tensorboard logs dir
    if not os.path.exists(args.tb_log_dir):
        os.makedirs(args.tb_log_dir)

    # input types
    log_fp.write("Inputs:\n")
    log_fp.write("Input Name List: %s\n" % ",".join(output_input_col_names))
    log_fp.write("#" * 50 + "\n")

    # Setup CUDA, GPU & distributed training
    if args.no_cuda or not torch.cuda.is_available():
        device = torch.device("cpu")
        args.n_gpu = 0
    else:
        args.n_gpu = torch.cuda.device_count()
        if args.n_gpu > 1:
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
            torch.distributed.init_process_group(backend="nccl")
            if args.local_rank == 0:
                print('world size: %d' % torch.distributed.get_world_size())
        else:
            device = torch.device("cuda")

    args.device = device

    logger.info("#" * 50)
    logger.info(str(args))
    logger.info("#" * 50)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s-%(levelname)s-%(name)s | %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN
    )
    logger.warning(
        "Process rank: %d, device: %s, n_gpu: %d, distributed training: %r, 16-bits training: %r" % (
            args.local_rank,
            str(device),
            args.n_gpu,
            bool(args.local_rank != -1),
            args.fp16
        )
    )

    # Set seed
    set_seed(args)

    # Prepare task
    args.dataset_name = args.dataset_name.lower()

    # Data processor, different processors for different tasks
    processor = SequenceStructureProcessor(
        model_type=args.model_type,
        separate_file=args.separate_file,
        filename_pattern=args.filename_pattern
    )

    # the output type
    args.output_mode = args.task_type
    # For binary_class/multi_label tasks, the sigmoid needs to be added to the last layer
    if args.output_mode in ["multi_label", "multi-label", "binary_class", "binary-class"]:
        args.sigmoid = True

    # get label list
    label_list = processor.get_labels(
        label_filepath=args.label_filepath
    )
    num_labels = len(label_list)
    logger.info("#" * 25 + "Labels Num:" + "#" * 25)
    logger.info("Num Labels: %d" % num_labels)
    save_labels(os.path.join(args.log_dir, "label.txt"), label_list)

    # Get different task models according to the task type name
    args.model_type = args.model_type.lower()
    # Load config
    config_class = BertConfig
    config = config_class(**json.load(open(args.config_path, "r")))
    config.max_position_embeddings = int(args.seq_max_length)
    config.num_labels = num_labels
    config.embedding_pooling_type = args.embedding_pooling_type
    if args.activate_func:
        config.activate_func = args.activate_func
    if args.pos_weight:
        config.pos_weight = args.pos_weight
    # tokenization
    subword = None
    if args.has_seq_encoder:
        seq_tokenizer_class = BertTokenizer
        seq_tokenizer = seq_tokenizer_class(args.seq_vocab_path, do_lower_case=args.do_lower_case)
        config.vocab_size = seq_tokenizer.vocab_size
        if args.subword:
            bpe_codes_prot = codecs.open(args.codes_file)
            subword = BPE(bpe_codes_prot, merges=-1, separator='')
    else:
        seq_tokenizer_class = None
        seq_tokenizer = None

    if args.has_struct_encoder:
        struct_tokenizer_class = BertTokenizer
        struct_tokenizer = struct_tokenizer_class(args.struct_vocab_path, do_lower_case=args.do_lower_case)
        config.struct_vocab_size = struct_tokenizer.vocab_size
    else:
        struct_tokenizer_class = None
        struct_tokenizer = None

    # model class
    model_class = SequenceAndStructureFusionNetwork
    if args.model_dirpath and os.path.exists(args.model_dirpath):
        model = load_trained_model(config, args, model_class, args.model_dirpath)
    else:
        model = model_class(config, args)

    # output model hyperparameters in logger
    if len(config.id2label) > 10:
        str_config = copy.deepcopy(config)
        str_config.id2label = {}
        str_config.label2id = {}
    else:
        str_config = copy.deepcopy(config)
    log_fp.write("Model Config:\n %s\n" % str(str_config))
    log_fp.write("#" * 50 + "\n")
    log_fp.write("Mode Architecture:\n %s\n" % str(model))
    log_fp.write("#" * 50 + "\n")

    model.to(args.device)
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    # output training/evaluation hyperparameters in logger
    logger.info("====Training/Evaluation Parameters:=====")
    for attr, value in sorted(args.__dict__.items()):
        logger.info("\t{}={}".format(attr, value))
    logger.info("====Parameters End=====\n")

    args_dict = {}
    for attr, value in sorted(args.__dict__.items()):
        if attr != "device":
            args_dict[attr] = value

    log_fp.write(json.dumps(args_dict, ensure_ascii=False) + "\n")
    log_fp.write("#" * 50 + "\n")
    log_fp.write("num labels: %d\n" % num_labels)
    log_fp.write("#" * 50 + "\n")

    model_size_info = get_parameter_number(model)
    log_fp.write(json.dumps(model_size_info, ensure_ascii=False) + "\n")
    log_fp.write("#" * 50 + "\n")
    log_fp.flush()

    # Training
    max_metric_model_info = None
    if args.do_train:
        logger.info("++++++++++++Training+++++++++++++")
        global_step, tr_loss, max_metric_model_info = train(
            args,
            model,
            processor,
            seq_tokenizer,
            subword,
            struct_tokenizer=struct_tokenizer,
            log_fp=log_fp
        )
        logger.info("global_step = %s, average loss = %s", global_step, tr_loss)

    # Save
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        logger.info("++++++++++++Save Model+++++++++++++")
        # Create output directory if needed
        best_output_dir = os.path.join(args.output_dir, "best")
        global_step = max_metric_model_info["global_step"]
        prefix = "checkpoint-{}".format(global_step)
        shutil.copytree(os.path.join(args.output_dir, prefix), best_output_dir)
        logger.info("Saving model checkpoint to %s", best_output_dir)
        torch.save(args, os.path.join(best_output_dir, "training_args.bin"))
        save_labels(os.path.join(best_output_dir, "label.txt"), label_list)

    # Evaluate
    if args.do_eval and args.local_rank in [-1, 0]:
        logger.info("++++++++++++Validation+++++++++++++")
        log_fp.write("++++++++++++Validation+++++++++++++\n")
        global_step = max_metric_model_info["global_step"]
        logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
        log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
        prefix = "checkpoint-{}".format(global_step)
        checkpoint = os.path.join(args.output_dir, prefix)
        if seq_tokenizer is None and seq_tokenizer_class:
            seq_tokenizer = seq_tokenizer_class.from_pretrained(checkpoint, do_lower_case=args.do_lower_case)
        if struct_tokenizer_class and struct_tokenizer is None:
            struct_tokenizer = struct_tokenizer_class.from_pretrained(checkpoint, do_lower_case=args.do_lower_case)

        logger.info("checkpoint path: %s" % checkpoint)
        log_fp.write("checkpoint path: %s\n" % checkpoint)
        model = model_class.from_pretrained(checkpoint, args=args)
        model.to(args.device)
        result = evaluate(args, model, processor, seq_tokenizer, subword, struct_tokenizer, prefix=prefix, log_fp=log_fp)
        result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
        logger.info(json.dumps(result, ensure_ascii=False))
        log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")

    # Testing
    if args.do_predict and args.local_rank in [-1, 0]:
        logger.info("++++++++++++Testing+++++++++++++")
        log_fp.write("++++++++++++Testing+++++++++++++\n")
        global_step = max_metric_model_info["global_step"]
        logger.info("max %s global step: %d" % (args.max_metric_type, global_step))
        log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step))
        prefix = "checkpoint-{}".format(global_step)
        checkpoint = os.path.join(args.output_dir, prefix)
        if seq_tokenizer is None and seq_tokenizer_class:
            seq_tokenizer = seq_tokenizer_class.from_pretrained(checkpoint, do_lower_case=args.do_lower_case)
        if struct_tokenizer_class and struct_tokenizer is None:
            struct_tokenizer = struct_tokenizer_class.from_pretrained(checkpoint, do_lower_case=args.do_lower_case)
        logger.info("checkpoint path: %s" % checkpoint)
        log_fp.write("checkpoint path: %s\n" % checkpoint)
        model = model_class.from_pretrained(checkpoint, args=args)
        model.to(args.device)
        pred, true, result = predict(args, model, processor, seq_tokenizer, subword, struct_tokenizer, prefix=prefix, log_fp=log_fp)
        result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items())
        logger.info(json.dumps(result, ensure_ascii=False))
        log_fp.write(json.dumps(result, ensure_ascii=False) + "\n")

    # close fp
    if args.local_rank in [-1, 0] and log_fp:
        log_fp.close()

    if args.n_gpu > 1:
        torch.distributed.barrier()