src/deep_baselines/virhunter.py (212 lines of code) (raw):

#!/usr/bin/env python
# encoding: utf-8
'''
*Copyright (c) 2023, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at

*   http://www.apache.org/licenses/LICENSE-2.0

*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.

@author: Hey
@email: sanyuan.**alibaba-inc.com
@tel: 137****6540
@datetime: 2023/2/28 16:36
@project: DeepProtFunc
@file: virhunter
@desc: VirHunter: A Deep Learning-Based Method for Detection of Novel RNA Viruses in Plant Sequencing Data
'''
import logging
import sys

import torch
from torch.nn.functional import one_hot
import torch.nn.functional
from torch import nn
from torch.nn import BCEWithLogitsLoss, MSELoss, CrossEntropyLoss

sys.path.append(".")
sys.path.append("..")
sys.path.append("../..")
sys.path.append("../../src")
sys.path.append("../../src/common")
try:
    from loss import *
    from utils import *
    from multi_label_metrics import *
    from metrics import *
except ImportError:
    from src.common.loss import *
    from src.utils import *
    from src.common.multi_label_metrics import *
    from src.common.metrics import *

logger = logging.getLogger(__name__)


class VirHunter(nn.Module):
    """Single-layer 1D-CNN sequence classifier (VirHunter baseline).

    Pipeline: token ids -> embedding or one-hot -> Conv1d -> LeakyReLU ->
    MaxPool1d -> Dropout -> Linear -> ReLU -> Dropout -> Linear.

    The pooling window is ``max_position_embeddings - kernel_size + 1``, so the
    length axis collapses to exactly one position only when the input length
    equals ``config.max_position_embeddings``.
    """

    def __init__(self, config, args):
        """Build the layers and select the loss function.

        Args:
            config: object providing ``one_hot_encode``, ``vocab_size``,
                ``max_position_embeddings``, ``embedding_trainable``,
                ``embedding_dim``, ``num_labels``, ``kernel_num``,
                ``kernel_size``, ``dropout``, ``reverse``, ``bias``,
                ``fc_size`` and ``padding_idx``; optionally ``pos_weight``,
                ``weight`` and ``loss_reduction``.
            args: object providing ``output_mode``, ``sigmoid`` and
                ``loss_type``; optionally ``pos_weight``, ``weight``,
                ``device`` and the loss hyper-parameters read below.

        Raises:
            Exception: if ``args.output_mode`` is not a supported mode.
        """
        super(VirHunter, self).__init__()
        self.one_hot_encode = config.one_hot_encode
        self.vocab_size = config.vocab_size
        self.max_position_embeddings = config.max_position_embeddings
        self.embedding_trainable = config.embedding_trainable
        self.embedding_dim = config.embedding_dim
        self.num_labels = config.num_labels
        self.kernel_num = config.kernel_num
        self.kernel_size = config.kernel_size
        self.dropout = config.dropout
        self.reverse = config.reverse
        self.bias = config.bias
        self.fc_size = config.fc_size
        self.output_mode = args.output_mode
        self.padding_idx = config.padding_idx

        # Two labels are modelled as a single logit.
        if self.num_labels == 2:
            self.num_labels = 1

        if not self.one_hot_encode:
            self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim, padding_idx=self.padding_idx)
            self.embedding.weight.requires_grad = bool(self.embedding_trainable)

        in_channels = self.vocab_size if self.one_hot_encode else self.embedding_dim
        self.hidden_layers = nn.ModuleList(
            [
                nn.Conv1d(in_channels, self.kernel_num, self.kernel_size, bias=self.bias),
                nn.LeakyReLU(negative_slope=0.1),
                nn.MaxPool1d(self.max_position_embeddings - self.kernel_size + 1, stride=1),
                nn.Dropout(self.dropout)
            ]
        )

        # With reverse input, forward and reverse features are concatenated,
        # which doubles the width seen by the dense layer.
        dense_in_features = self.kernel_num * 2 if self.reverse else self.kernel_num
        self.dense = nn.Linear(dense_in_features, self.fc_size, bias=self.bias)
        self.relu = nn.ReLU()
        # Rebinds ``self.dropout`` from the configured rate to a module; the
        # convolutional dropout above has already captured the configured rate.
        self.dropout = nn.Dropout(0.1)
        self.linear_layer = nn.Linear(self.fc_size, self.num_labels, bias=self.bias)

        if args.sigmoid:
            self.output = nn.Sigmoid()
        elif self.num_labels > 1:
            self.output = nn.Softmax(dim=1)
        else:
            self.output = None

        self.loss_type = args.loss_type

        # Positive-class / per-class weights: ``config`` wins over ``args``.
        self.pos_weight = getattr(config, "pos_weight", getattr(args, "pos_weight", None))
        self.weight = getattr(config, "weight", getattr(args, "weight", None))

        if self.output_mode in ["regression"]:
            self.loss_fct = MSELoss()
        elif self.output_mode in ["multi_label", "multi-label"]:
            if self.loss_type == "bce":
                if self._is_provided(self.pos_weight):
                    # [1, 1, 1, ,1, 1...] length: self.num_labels
                    assert self.pos_weight.ndim == 1 and self.pos_weight.shape[0] == self.num_labels
                    self.loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
                else:
                    self.loss_fct = BCEWithLogitsLoss(reduction=getattr(config, "loss_reduction", "sum"))
            elif self.loss_type == "asl":
                self.loss_fct = AsymmetricLossOptimized(
                    gamma_neg=getattr(args, "asl_gamma_neg", 4),
                    gamma_pos=getattr(args, "asl_gamma_pos", 1),
                    clip=getattr(args, "clip", 0.05),
                    eps=getattr(args, "eps", 1e-8),
                    disable_torch_grad_focal_loss=getattr(args, "disable_torch_grad_focal_loss", False)
                )
            elif self.loss_type == "focal_loss":
                self.loss_fct = self._build_focal_loss(args)
            elif self.loss_type == "multilabel_cce":
                self.loss_fct = MultiLabel_CCE(normalization=False)
        elif self.output_mode in ["binary_class", "binary-class"]:
            if self.loss_type == "bce":
                if self._is_provided(self.pos_weight):
                    # [0.9]
                    if isinstance(self.pos_weight, (int, float)):
                        # Always a float tensor: an integer weight used to
                        # become a ``torch.long`` tensor.
                        self.pos_weight = torch.tensor(
                            [self.pos_weight], dtype=torch.float32
                        ).to(args.device)
                    assert self.pos_weight.ndim == 1 and self.pos_weight.shape[0] == 1
                    self.loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
                else:
                    self.loss_fct = BCEWithLogitsLoss()
            elif self.loss_type == "focal_loss":
                self.loss_fct = self._build_focal_loss(args)
        elif self.output_mode in ["multi_class", "multi-class"]:
            if self._is_provided(self.weight):
                # [1, 1, 1, ,1, 1...] length: self.num_labels
                assert self.weight.ndim == 1 and self.weight.shape[0] == self.num_labels
                self.loss_fct = CrossEntropyLoss(weight=self.weight)
            else:
                self.loss_fct = CrossEntropyLoss()
        else:
            raise Exception("Not support output mode: %s." % self.output_mode)

    @staticmethod
    def _is_provided(weight):
        """Return True when a usable loss weight was supplied.

        ``bool(tensor)`` raises for tensors with more than one element, so
        tensors are checked by element count instead of truthiness.
        """
        if weight is None:
            return False
        if torch.is_tensor(weight):
            return weight.numel() > 0
        return bool(weight)

    @staticmethod
    def _build_focal_loss(args):
        """Create the focal loss from the optional hyper-parameters on ``args``."""
        return FocalLoss(
            alpha=getattr(args, "focal_loss_alpha", 1),
            gamma=getattr(args, "focal_loss_gamma", 0.25),
            normalization=False,
            reduce=getattr(args, "focal_loss_reduce", False)
        )

    def _encode(self, token_ids):
        """Turn token ids into channel-first features for ``Conv1d``.

        The permute needs a 3-D tensor, so ``token_ids`` must be 2-D
        (batch, length); the result is (batch, channels, length).
        """
        if self.one_hot_encode:
            encoded = one_hot(token_ids, num_classes=self.vocab_size).to(torch.float32)
        else:
            encoded = self.embedding(token_ids)
        return encoded.permute(0, 2, 1)

    def _extract(self, features):
        """Run features through conv -> activation -> max-pool -> dropout."""
        for layer in self.hidden_layers:
            features = layer(features)
        return features

    def forward(self, x, reverse_x=None, lengths=None, labels=None):
        """Compute logits, the activated output and optionally the loss.

        Args:
            x: 2-D tensor of token ids (batch, length).
            reverse_x: optional second input, encoded and processed with the
                same layers as ``x``.
                NOTE(review): assumed to hold token ids like ``x`` (e.g. the
                reverse strand) - confirm against the data loader.
            lengths: unused; kept for interface compatibility.
            labels: optional targets; when given the loss is prepended.

        Returns:
            ``[logits, output]`` or, when ``labels`` is given,
            ``[loss, logits, output]``. ``output`` equals ``logits`` when no
            output activation is configured.
        """
        x = self._extract(self._encode(x))
        # ``is not None``: truth-testing a multi-element tensor raises.
        if reverse_x is not None:
            reverse_x = self._extract(self._encode(reverse_x))
            # Join on the channel axis so the result matches the
            # ``2 * kernel_num`` input width of ``self.dense``.
            x = torch.cat([x, reverse_x], dim=1)
        x = torch.squeeze(x, -1)

        x = self.dense(x)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.linear_layer(x)

        output = self.output(logits) if self.output is not None else logits
        outputs = [logits, output]

        if labels is not None:
            if self.output_mode in ["regression"]:
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.output_mode in ["multi_label", "multi-label"]:
                loss = self.loss_fct(logits.view(-1, self.num_labels),
                                     labels.view(-1, self.num_labels).float())
            elif self.output_mode in ["binary_class", "binary-class"]:
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.output_mode in ["multi_class", "multi-class"]:
                flat_logits = logits.view(-1, self.num_labels)
                if labels.numel() == flat_logits.numel():
                    # One value per class: one-hot / class-probability targets.
                    loss = self.loss_fct(flat_logits, labels.view(-1, self.num_labels).float())
                else:
                    # One value per sample: class-index targets.
                    loss = self.loss_fct(flat_logits, labels.view(-1).long())
            outputs = [loss, *outputs]

        return outputs


def seq_encode(seq, max_len, trunc_type, vocab):
    """Map a sequence to exactly ``max_len`` vocabulary ids.

    Sequences of at least ``max_len`` characters are truncated, keeping the
    head when ``trunc_type == "right"`` and the tail otherwise; shorter ones
    are right-padded with ``"[PAD]"``.

    Args:
        seq: the sequence to encode, one token per character.
        max_len: target length.
        trunc_type: ``"right"`` cuts the end off; anything else cuts the start.
        vocab: mapping from token (including ``"[PAD]"``) to integer id.

    Returns:
        ``(ids, actual_len)``, where ``actual_len`` is the number of
        non-padding positions.

    Raises:
        KeyError: if a token is missing from ``vocab``.
    """
    seq_len = len(seq)
    if seq_len >= max_len:
        actual_len = max_len
        if trunc_type == "right":
            tokens = list(seq[:max_len])
        else:
            # Explicit start offset: ``seq[-max_len:]`` would return the whole
            # sequence when ``max_len`` is 0.
            tokens = list(seq[seq_len - max_len:])
    else:
        actual_len = seq_len
        tokens = list(seq) + ["[PAD]"] * (max_len - seq_len)
    return [vocab[token] for token in tokens], actual_len


def one_hot_encode(seq, max_len, trunc_type, vocab):
    """One-hot encode a sequence after truncating/padding it to ``max_len``.

    Args:
        seq: the sequence to encode, one token per character.
        max_len: target length.
        trunc_type: truncation side, as in ``seq_encode``.
        vocab: mapping from token (including ``"[PAD]"``) to integer id.

    Returns:
        ``(rows, actual_len)``, where ``rows`` holds ``max_len`` float lists
        of length ``len(vocab)``, each with a single ``1.0``.
    """
    processed_seq_id, actual_len = seq_encode(seq, max_len, trunc_type, vocab)
    vocab_size = len(vocab)
    # Named ``rows`` so the imported torch ``one_hot`` is not shadowed.
    rows = []
    for idx in processed_seq_id:
        row = [0.0] * vocab_size
        row[idx] = 1.0
        rows.append(row)
    return rows, actual_len


if __name__ == "__main__":
    # protein_list = ["[PAD]", "I", "M", "T", "N", "K", "S", "R", "L", "P", "H", "Q", "V", "A", "D", "E", "G", "S", "F", "Y", "W", "C", "O"]
    protein_list = ["[PAD]", "A", "T", "C", "G"]
    int_to_protein = {idx: char for idx, char in enumerate(protein_list)}
    protein_to_int = {char: idx for idx, char in int_to_protein.items()}
    # ``trunc_type`` is a required positional argument.
    print(one_hot_encode("AATCG", 6, "right", protein_to_int))