def main()

in src/predict_one_sample.py [0:0]


def main(argv=None):
    """Build the argument parser for single-sample RdRP prediction and parse it.

    Despite its name, this function does not run a prediction: it only defines
    the command-line options and returns the parsed namespace.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default (``None``) reads the real command
            line, so existing callers are unaffected; passing a list lets the
            parser be reused from other code and tests.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse when a required option is missing, a
            value fails type conversion, ``--task_type`` is not one of the
            allowed choices, or ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="Prediction RdRP")
    # for llm
    parser.add_argument(
        "--torch_hub_dir",
        default=None,
        type=str,
        help="set the torch hub dir path for saving pretrained model(default:~/.cache/torch/hub)"
    )
    # for input
    parser.add_argument(
        "--protein_id",
        type=str,
        required=True,
        help="the protein id"
    )
    parser.add_argument(
        "--sequence",
        type=str,
        required=True,
        help="the protein sequence"
    )
    # Previously declared required=True alongside default=4096; argparse never
    # applies the default of a required option, so the default was dead. The
    # option is now optional and falls back to the declared default.
    parser.add_argument(
        "--truncation_seq_length",
        default=4096,
        type=int,
        help="truncation seq length. default: 4096"
    )
    parser.add_argument(
        "--emb_dir",
        default=None,
        type=str,
        help="the llm embedding save dir. default: None"
    )
    parser.add_argument(
        "--pdb_dir",
        default=None,
        type=str,
        help="the 3d-structure pdb save dir. default: None"
    )

    # for trained checkpoint
    parser.add_argument(
        "--chain",
        default=None,
        type=str,
        help="pdb chain for contact map computing"
    )
    # dataset_name / dataset_type: same required+default conflict as
    # --truncation_seq_length; made optional so the declared defaults apply.
    parser.add_argument(
        "--dataset_name",
        default="rdrp_40_extend",
        type=str,
        help="the dataset name for model building. default: rdrp_40_extend"
    )
    parser.add_argument(
        "--dataset_type",
        default="protein",
        type=str,
        help="the dataset type for model building. default: protein"
    )
    parser.add_argument(
        "--task_type",
        type=str,
        required=True,
        choices=["multi_label", "multi_class", "binary_class"],
        help="the task type for model building."
    )
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        help="model type."
    )
    parser.add_argument(
        "--time_str",
        type=str,
        required=True,
        help="the running time string(yyyymmddHimiss) of trained checkpoint building."
    )
    # Kept as a string (not int) exactly as before.
    parser.add_argument(
        "--step",
        type=str,
        required=True,
        help="the training global step of model finalization."
    )
    # NOTE(review): type=float means "None" cannot be passed on the command
    # line; the "None for multi-class" wording presumably means the caller
    # ignores the threshold for multi_class -- confirm against the caller.
    parser.add_argument(
        "--threshold",
        default=0.5,
        type=float,
        help="sigmoid threshold for binary-class or multi-label classification, None for multi-class classification, default: 0.5."
    )
    # NOTE(review): the default is None, not -1; how None is interpreted
    # (CPU vs. auto-select GPU) is decided by the caller, not here.
    parser.add_argument(
        "--gpu_id",
        default=None,
        type=int,
        help="the used gpu index, -1 for cpu"
    )
    input_args = parser.parse_args(argv)
    return input_args