def virus_rdrp_identification()

in src/app/app.py [0:0]


def _parse_fasta_records(input_fasta):
    """Split FASTA-formatted text into ``[seq_id, sequence]`` pairs.

    A line starting with ``>`` opens a new record whose ID is the rest of
    that line; the following lines are concatenated into its sequence.
    Records with an empty ID or an empty sequence are dropped, and text
    that appears before the first header is ignored instead of raising.
    """
    records = []
    seq_id = None
    chunks = []
    for line in input_fasta.strip().splitlines():
        line = line.strip()
        if line.startswith(">"):
            if seq_id and chunks:
                records.append([seq_id, "".join(chunks)])
            seq_id = line[1:]
            chunks = []
        elif line and seq_id is not None:
            chunks.append(line)
    if seq_id and chunks:
        records.append([seq_id, "".join(chunks)])
    return records


def _load_prediction_settings(script_dir, threshold):
    """Build the prediction arguments and load the stored model configuration.

    Returns ``(args, model_dir, config_dir, config)``. Nothing is written to
    module state here, so a failure leaves the caller's cache untouched.
    """
    args = Args()
    args.emb_dir = None
    args.pdb_dir = None
    args.truncation_seq_length = 4096
    args.dataset_name = "rdrp_40_extend"
    args.dataset_type = "protein"
    args.task_type = "binary_class"
    args.model_type = "sefn"
    args.time_str = "20230201140320"
    args.step = 100000
    args.threshold = threshold

    # Both directories are derived from the hard-coded defaults above, i.e.
    # before the loaded configuration overrides any of these fields.
    model_dir = "%s/../models/%s/%s/%s/%s/%s/%s" % (
        script_dir,
        args.dataset_name,
        args.dataset_type,
        args.task_type,
        args.model_type,
        args.time_str,
        args.step if args.step == "best" else "checkpoint-{}".format(args.step)
    )
    config_dir = "%s/../logs/%s/%s/%s/%s/%s" % (
        script_dir,
        args.dataset_name,
        args.dataset_type,
        args.task_type,
        args.model_type,
        args.time_str
    )

    # Step1: loading the model configuration
    config = load_args(config_dir)
    for key, value in config.items():
        if not isinstance(value, str):
            continue
        # Values starting with "../" are re-anchored on script_dir.
        if value.startswith("../"):
            value = os.path.join(script_dir, value)
        print(f'My item {value} is labelled {key}')
        config[key] = value
    print("-" * 25 + "config:" + "-" * 25)
    print(config)
    print("-" * 60)

    # Configuration entries copied verbatim onto the argument object.
    copied_keys = (
        "dataset_name", "dataset_type", "task_type", "model_type",
        "has_seq_encoder", "has_struct_encoder", "has_embedding_encoder",
        "subword", "codes_file", "input_mode", "label_filepath",
        "output_dir", "config_path", "do_lower_case", "sigmoid",
        "loss_type", "output_mode",
        "seq_vocab_path", "seq_pooling_type", "seq_max_length",
        "struct_vocab_path", "struct_max_length", "struct_pooling_type",
        "trunc_type", "no_position_embeddings", "no_token_type_embeddings",
        "cmap_type",
        "embedding_input_size", "embedding_pooling_type",
        "embedding_max_length", "embedding_type",
    )
    for key in copied_keys:
        setattr(args, key, config[key])
    # Stored under its own attribute; previously this value overwrote
    # ``cmap_type`` and ``cmap_thresh`` was never set.
    args.cmap_thresh = float(config["cmap_thresh"])

    # Fall back to the label file next to the configuration when the
    # configured path does not exist.
    if not os.path.exists(args.label_filepath):
        args.label_filepath = os.path.join(config_dir, "label.txt")

    if args.task_type in ("multi-label", "multi_label", "binary-class", "binary_class"):
        args.sigmoid = True
    return args, model_dir, config_dir, config


def virus_rdrp_identification(input_fasta, threshold):
    """Run the RdRP classifier on FASTA text and return an HTML summary.

    Args:
        input_fasta: FASTA-formatted text holding one or more sequences.
        threshold: Decision threshold; any value outside ``[0, 1]``
            (including NaN) is replaced by ``0.5``.

    Returns:
        A string with one ``"<id>: [prob=..., label=..., Yes/No]<br><br>"``
        entry per parsed sequence, or ``""`` when no sequence was parsed.
        Sequence IDs are HTML-escaped.

    Raises:
        Exception: If the configured task type is not multi-label,
            binary-class or multi-class.

    The prediction arguments, model/config directories and configuration are
    cached in module-level globals on the first successful call.
    """
    import html  # local import: only needed to escape IDs in the HTML output

    global global_args, global_model_dir, global_config_dir, global_config

    if not 0 <= threshold <= 1:
        threshold = 0.5
    script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    if global_args is None or global_model_dir is None or global_config_dir is None or global_config is None:
        # Publish the cache only after everything loaded successfully, so a
        # failed initialisation is retried on the next call instead of
        # leaving half-initialised globals behind.
        (
            global_args,
            global_model_dir,
            global_config_dir,
            global_config,
        ) = _load_prediction_settings(script_dir, threshold)

    # The threshold is per call, so refresh it even when the cache is reused.
    global_args.threshold = threshold
    gpu_idx = available_gpu_id()
    if gpu_idx > -1:
        print("Use Device: GPU(%d)" % gpu_idx)
        device = torch.device("cuda:%d" % gpu_idx)
    else:
        print("Use Device: CPU")
        device = torch.device("cpu")
    global_args.device = device

    print("-" * 25 + "args:" + "-" * 25)
    print(global_args.__dict__.items())
    print("-" * 60)

    # Step2: loading the tokenizer and model
    # NOTE(review): load_model runs on every call although the arguments are
    # cached; consider caching its result as well if loading is expensive.
    config, subword, seq_tokenizer, struct_tokenizer, model, label_id_2_name, label_name_2_id = load_model(
        args=global_args,
        model_dir=global_model_dir
    )
    if global_args.task_type in ["multi-label", "multi_label"]:
        predict_func = predict_multi_label
    elif global_args.task_type in ["binary-class", "binary_class"]:
        predict_func = predict_binary_class
    elif global_args.task_type in ["multi-class", "multi_class"]:
        predict_func = predict_multi_class
    else:
        raise Exception("Not Support Task Type: %s" % global_args.task_type)

    input_seqs = _parse_fasta_records(input_fasta)
    print("input_seqs:")
    print(input_seqs)

    # Step 3: prediction
    result_parts = []
    for seq_id, raw_seq in input_seqs:
        row = [seq_id, clean_seq(seq_id, raw_seq)]
        res = predict_func(global_args, label_id_2_name, seq_tokenizer, subword, struct_tokenizer, model, row)
        prob = res[0][2]
        # NOTE(review): the label is derived from the same value that is
        # printed as the probability, so it is positive only when that value
        # equals exactly 1. Confirm against predict_func's return layout
        # whether a different field (or the threshold) should be used.
        is_viral_rdrp = prob == 1
        result_parts.append(
            "%s: [prob=%0.4f%%, label=%s, %s]<br><br>" % (
                html.escape(seq_id),
                prob * 100,
                "viral-RdRP" if is_viral_rdrp else "non-viral-RdRP",
                "Yes" if is_viral_rdrp else "No"
            )
        )
    return "".join(result_parts)