in src/app/app.py [0:0]
def parse_fasta_text(fasta_text):
    """Parse FASTA-formatted text into ``[sequence_id, sequence]`` pairs.

    A line starting with ``>`` opens a new record (the header text after
    ``>`` becomes the id); every following non-header line is stripped and
    concatenated into that record's sequence. Records whose sequence is
    empty are dropped. Sequence lines appearing before the first header
    are ignored instead of raising.

    :param fasta_text: FASTA content as a single string.
    :return: list of ``[seq_id, seq]`` lists, in input order.
    """
    records = []
    seq_id = None
    seq = None
    for line in fasta_text.strip().split("\n"):
        if line.startswith(">"):
            # Flush the previous record before starting a new one.
            if seq_id is not None and seq is not None:
                records.append([seq_id, seq])
            seq_id = line.strip()[1:]
            seq = ""
        elif seq is not None:
            # Guard: silently skip sequence data before the first header
            # (the unguarded version raised TypeError on None += str).
            seq += line.strip()
    # Flush the trailing record; empty ids/sequences are skipped.
    if seq_id and seq:
        records.append([seq_id, seq])
    return records


def _init_rdrp_globals(threshold, script_dir):
    """Populate the module-level model/config globals on first use.

    Builds the default ``Args``, derives the model checkpoint and log
    directories, loads the persisted training configuration via
    ``load_args`` and copies it onto ``global_args``, resolving any
    relative ("../") string paths against *script_dir*. Finally selects
    the compute device (GPU when available, else CPU).
    """
    global global_args, global_model_dir, global_config_dir, global_config
    global_args = Args()
    global_args.emb_dir = None
    global_args.pdb_dir = None
    global_args.truncation_seq_length = 4096
    global_args.dataset_name = "rdrp_40_extend"
    global_args.dataset_type = "protein"
    global_args.task_type = "binary_class"
    global_args.model_type = "sefn"
    global_args.time_str = "20230201140320"
    global_args.step = 100000
    global_args.threshold = threshold
    global_model_dir = "%s/../models/%s/%s/%s/%s/%s/%s" % (
        script_dir,
        global_args.dataset_name,
        global_args.dataset_type,
        global_args.task_type,
        global_args.model_type,
        global_args.time_str,
        global_args.step if global_args.step == "best" else "checkpoint-{}".format(global_args.step)
    )
    global_config_dir = "%s/../logs/%s/%s/%s/%s/%s" % (
        script_dir,
        global_args.dataset_name,
        global_args.dataset_type,
        global_args.task_type,
        global_args.model_type,
        global_args.time_str
    )
    # Step 1: load the persisted model configuration.
    global_config = load_args(global_config_dir)
    for key, value in global_config.items():
        # Only string values can start with "../"; anything else raises
        # AttributeError and is left untouched.
        try:
            if value.startswith("../"):
                value = os.path.join(script_dir, value)
        except AttributeError:
            continue
        global_config[key] = value
    print("-" * 25 + "config:" + "-" * 25)
    print(global_config)
    print("-" * 60)
    global_args.dataset_name = global_config["dataset_name"]
    global_args.dataset_type = global_config["dataset_type"]
    global_args.task_type = global_config["task_type"]
    global_args.model_type = global_config["model_type"]
    global_args.has_seq_encoder = global_config["has_seq_encoder"]
    global_args.has_struct_encoder = global_config["has_struct_encoder"]
    global_args.has_embedding_encoder = global_config["has_embedding_encoder"]
    global_args.subword = global_config["subword"]
    global_args.codes_file = global_config["codes_file"]
    global_args.input_mode = global_config["input_mode"]
    global_args.label_filepath = global_config["label_filepath"]
    # Fall back to the label file shipped next to the logs when the
    # configured path no longer exists on this machine.
    if not os.path.exists(global_args.label_filepath):
        global_args.label_filepath = os.path.join(global_config_dir, "label.txt")
    global_args.output_dir = global_config["output_dir"]
    global_args.config_path = global_config["config_path"]
    global_args.do_lower_case = global_config["do_lower_case"]
    global_args.sigmoid = global_config["sigmoid"]
    global_args.loss_type = global_config["loss_type"]
    global_args.output_mode = global_config["output_mode"]
    global_args.seq_vocab_path = global_config["seq_vocab_path"]
    global_args.seq_pooling_type = global_config["seq_pooling_type"]
    global_args.seq_max_length = global_config["seq_max_length"]
    global_args.struct_vocab_path = global_config["struct_vocab_path"]
    global_args.struct_max_length = global_config["struct_max_length"]
    global_args.struct_pooling_type = global_config["struct_pooling_type"]
    global_args.trunc_type = global_config["trunc_type"]
    global_args.no_position_embeddings = global_config["no_position_embeddings"]
    global_args.no_token_type_embeddings = global_config["no_token_type_embeddings"]
    global_args.cmap_type = global_config["cmap_type"]
    # BUGFIX: the original assigned float(cmap_thresh) to cmap_type,
    # clobbering the value set on the previous line and never setting
    # cmap_thresh at all.
    global_args.cmap_thresh = float(global_config["cmap_thresh"])
    global_args.embedding_input_size = global_config["embedding_input_size"]
    global_args.embedding_pooling_type = global_config["embedding_pooling_type"]
    global_args.embedding_max_length = global_config["embedding_max_length"]
    global_args.embedding_type = global_config["embedding_type"]
    # Both multi-label and binary-class heads use a sigmoid output.
    if global_args.task_type in ["multi-label", "multi_label"]:
        global_args.sigmoid = True
    elif global_args.task_type in ["binary-class", "binary_class"]:
        global_args.sigmoid = True
    global_args.threshold = threshold
    gpu_idx = available_gpu_id()
    if gpu_idx > -1:
        print("Use Device: GPU(%d)" % gpu_idx)
        device = torch.device("cuda:%d" % gpu_idx)
    else:
        print("Use Device: CPU")
        device = torch.device("cpu")
    global_args.device = device
    print("-" * 25 + "args:" + "-" * 25)
    print(global_args.__dict__.items())
    print("-" * 60)


def virus_rdrp_identification(input_fasta, threshold):
    """Predict whether each protein sequence in *input_fasta* is a viral RdRP.

    Lazily initializes the module-level model configuration on first call,
    loads the tokenizer/model, parses the FASTA input, and runs the
    task-appropriate prediction function on each cleaned sequence.

    :param input_fasta: FASTA-formatted protein sequences as one string.
    :param threshold: classification probability cutoff in [0, 1];
        out-of-range values fall back to 0.5.
    :return: HTML string with one "<br><br>"-separated entry per sequence.
    :raises Exception: when the configured task type is unsupported.
    """
    if threshold < 0 or threshold > 1:
        threshold = 0.5
    script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    global global_args, global_model_dir, global_config_dir, global_config
    if global_args is None or global_model_dir is None or global_config_dir is None or global_config is None:
        _init_rdrp_globals(threshold, script_dir)
    # Apply the caller's threshold on every call — the cached globals
    # would otherwise keep the threshold from the first request forever.
    global_args.threshold = threshold

    # Step 2: load the tokenizer and model.
    config, subword, seq_tokenizer, struct_tokenizer, model, label_id_2_name, label_name_2_id = load_model(
        args=global_args,
        model_dir=global_model_dir
    )
    task_type = global_args.task_type
    if task_type in ["multi-label", "multi_label"]:
        predict_func = predict_multi_label
    elif task_type in ["binary-class", "binary_class"]:
        predict_func = predict_binary_class
    elif task_type in ["multi-class", "multi_class"]:
        predict_func = predict_multi_class
    else:
        raise Exception("Not Support Task Type: %s" % task_type)

    input_seqs = parse_fasta_text(input_fasta)
    print("input_seqs:")
    print(input_seqs)

    # Step 3: prediction.
    results = ""
    for seq_id, seq in input_seqs:
        row = [seq_id, clean_seq(seq_id, seq)]
        res = predict_func(global_args, label_id_2_name, seq_tokenizer, subword, struct_tokenizer, model, row)
        # NOTE(review): res[0][2] is formatted as a probability yet also
        # compared with == 1 to choose the label. A float probability is
        # rarely exactly 1, so the label may really live in a different
        # field (or should be compared against the threshold) — confirm
        # against predict_binary_class's return shape.
        results += "%s: [prob=%0.4f%%, label=%s, %s]<br><br>" % (
            row[0],
            res[0][2] * 100,
            "viral-RdRP" if res[0][2] == 1 else "non-viral-RdRP",
            "Yes" if res[0][2] == 1 else "No"
        )
    return results