in src/deep_baselines/virseeker.py [0:0]
def __init__(self, config, args):
'''
:param config: model configuration (vocab/embedding sizes, RNN and classifier hyperparameters)
:param args: runtime arguments (output_mode, loss_type, sigmoid flag, optional class weights, device)
'''
super(VirSeeker, self).__init__()
self.max_position_embeddings = config.max_position_embeddings
self.embedding = config.embedding
self.embedding_trainable = config.embedding_trainable
self.embedding_dim = config.embedding_dim
self.vocab_size = config.vocab_size
self.bidirectional = config.bidirectional
self.num_layers = config.num_layers
self.hidden_dim = config.hidden_dim
self.dropout = config.dropout
self.bias = config.bias
self.num_labels = config.num_labels
self.output_mode = args.output_mode
self.batch_first = config.batch_first
self.rnn_model = config.rnn_model
# padding_idx: prefer config, then args, then default to 0 (the fallback only works if config is not accessed unconditionally)
if hasattr(config, "padding_idx"):
self.padding_idx = config.padding_idx
elif hasattr(args, "padding_idx"):
self.padding_idx = args.padding_idx
else:
self.padding_idx = 0
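# binary classification is handled with a single logit and a BCE-style loss, so collapse 2 labels into 1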
if self.num_labels == 2:
self.num_labels = 1
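# optional token embedding; when disabled, the RNN consumes the raw input with input_size=1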
if self.embedding:
self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim, padding_idx=self.padding_idx)
if self.embedding_trainable:
self.embedding.weight.requires_grad = True
else:
self.embedding.weight.requires_grad = False
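# PyTorch applies RNN dropout only between stacked layers, so it is only meaningful when num_layers > 1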
rnn_dropout = self.dropout if self.num_layers > 1 else 0.0
if self.rnn_model.lower() == 'lstm':
self.rnn = nn.LSTM(input_size=self.embedding_dim if self.embedding else 1, hidden_size=self.hidden_dim,
num_layers=self.num_layers, bidirectional=self.bidirectional,
batch_first=self.batch_first, dropout=rnn_dropout)
elif self.rnn_model.lower() == 'gru':
self.rnn = nn.GRU(input_size=self.embedding_dim if self.embedding else 1, hidden_size=self.hidden_dim,
num_layers=self.num_layers, bidirectional=self.bidirectional,
batch_first=self.batch_first, dropout=rnn_dropout)
else:
raise Exception("Unsupported rnn_model: %s." % self.rnn_model)
self.linear_layer = nn.Linear(self.hidden_dim * 2 if self.bidirectional else self.hidden_dim, self.num_labels, bias=self.bias)
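# optional output activation: sigmoid when args.sigmoid is set, softmax for multiple labels, otherwise raw logits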
if args.sigmoid:
self.output = nn.Sigmoid()
else:
if self.num_labels > 1:
self.output = nn.Softmax(dim=1)
else:
self.output = None
self.loss_type = args.loss_type
# positive weight
if hasattr(config, "pos_weight"):
self.pos_weight = config.pos_weight
elif hasattr(args, "pos_weight"):
self.pos_weight = args.pos_weight
else:
self.pos_weight = None
if hasattr(config, "weight"):
self.weight = config.weight
elif hasattr(args, "weight"):
self.weight = args.weight
else:
self.weight = None
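# pick the loss by output mode; BCEWithLogitsLoss combines sigmoid and BCE internally, so it should receive raw logits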
if self.output_mode in ["regression"]:
self.loss_fct = MSELoss()
elif self.output_mode in ["multi_label", "multi-label"]:
if self.loss_type == "bce":
if self.pos_weight is not None:
# per-label positive weights, e.g. [1, 1, ..., 1] with length self.num_labels
if not isinstance(self.pos_weight, torch.Tensor):
self.pos_weight = torch.tensor(self.pos_weight, dtype=torch.float32).to(args.device)
assert self.pos_weight.ndim == 1 and self.pos_weight.shape[0] == self.num_labels
self.loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
else:
self.loss_fct = BCEWithLogitsLoss(reduction=config.loss_reduction if hasattr(config, "loss_reduction") else "sum")
elif self.loss_type == "asl":
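# asymmetric loss (ASL) down-weights easy negatives more aggressively than positives, which helps with sparse multi-label targets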
self.loss_fct = AsymmetricLossOptimized(gamma_neg=args.asl_gamma_neg if hasattr(args, "asl_gamma_neg") else 4,
gamma_pos=args.asl_gamma_pos if hasattr(args, "asl_gamma_pos") else 1,
clip=args.clip if hasattr(args, "clip") else 0.05,
eps=args.eps if hasattr(args, "eps") else 1e-8,
disable_torch_grad_focal_loss=args.disable_torch_grad_focal_loss if hasattr(args, "disable_torch_grad_focal_loss") else False)
elif self.loss_type == "focal_loss":
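# focal loss down-weights well-classified examples so training focuses on hard ones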
self.loss_fct = FocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 1,
gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 0.25,
normalization=False,
reduce=args.focal_loss_reduce if hasattr(args, "focal_loss_reduce") else False)
elif self.loss_type == "multilabel_cce":
self.loss_fct = MultiLabel_CCE(normalization=False)
elif self.output_mode in ["binary_class", "binary-class"]:
if self.loss_type == "bce":
if self.pos_weight is not None:
# scalar positive-class weight, e.g. [0.9]
if isinstance(self.pos_weight, (int, float)):
self.pos_weight = torch.tensor([self.pos_weight], dtype=torch.float32).to(args.device)
assert self.pos_weight.ndim == 1 and self.pos_weight.shape[0] == 1
self.loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
else:
self.loss_fct = BCEWithLogitsLoss()
elif self.loss_type == "focal_loss":
self.loss_fct = FocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 1,
gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 0.25,
normalization=False,
reduce=args.focal_loss_reduce if hasattr(args, "focal_loss_reduce") else False)
elif self.output_mode in ["multi_class", "multi-class"]:
if self.weight is not None:
# per-class weights, e.g. [1, 1, ..., 1] with length self.num_labels
if not isinstance(self.weight, torch.Tensor):
self.weight = torch.tensor(self.weight, dtype=torch.float32).to(args.device)
assert self.weight.ndim == 1 and self.weight.shape[0] == self.num_labels
self.loss_fct = CrossEntropyLoss(weight=self.weight)
else:
self.loss_fct = CrossEntropyLoss()
else:
raise Exception("Unsupported output mode: %s." % self.output_mode)
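# A minimal construction sketch, assuming config/args expose exactly the fields read above
# (illustrative values only; argparse.Namespace stands in for the real config/args objects):
#   from argparse import Namespace
#   config = Namespace(max_position_embeddings=512, embedding=True, embedding_trainable=True,
#                      embedding_dim=128, vocab_size=8, bidirectional=True, num_layers=2,
#                      hidden_dim=256, dropout=0.2, bias=True, num_labels=2, padding_idx=0,
#                      batch_first=True, rnn_model="lstm")
#   args = Namespace(output_mode="binary_class", sigmoid=True, loss_type="bce", device="cpu")
#   model = VirSeeker(config, args)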