aiops/ContraLSP/mortalty/classifier.py (56 lines of code) (raw):
import torch as th
from torchmetrics import Accuracy, Precision, Recall, AUROC
from typing import Callable, Union
from hmm.classifier import StateClassifier
from tint.models import Net
class MimicClassifierNet(Net):
    """
    ``Net`` wrapper around a ``StateClassifier`` with binary metrics.

    The constructor builds a ``StateClassifier`` from the model arguments,
    hands it to the ``Net`` base class as ``layers`` together with the
    optimisation arguments, and then attaches one ``Accuracy``,
    ``Precision``, ``Recall`` and ``AUROC`` metric (all ``task="binary"``)
    for each of the ``train``, ``val`` and ``test`` stages.

    Args:
        feature_size (int): Forwarded to ``StateClassifier``.
        n_state (int): Forwarded to ``StateClassifier``.
        hidden_size (int): Forwarded to ``StateClassifier``.
        rnn (str): Forwarded to ``StateClassifier``. Default to ``"GRU"``
        dropout (float): Forwarded to ``StateClassifier``.
            Default to 0.5
        regres (bool): Forwarded to ``StateClassifier``.
            Default to ``True``
        bidirectional (bool): Forwarded to ``StateClassifier``.
            Default to ``False``
        loss (str, callable): Forwarded to ``Net``. Default to ``"mse"``
        optim (str): Forwarded to ``Net``. Default to ``"adam"``
        lr (float): Forwarded to ``Net``. Default to 0.001
        lr_scheduler (dict, str): Forwarded to ``Net``.
            Default to ``None``
        lr_scheduler_args (dict): Forwarded to ``Net``.
            Default to ``None``
        l2 (float): Forwarded to ``Net``. Default to 0.0
    """

    def __init__(
        self,
        feature_size: int,
        n_state: int,
        hidden_size: int,
        rnn: str = "GRU",
        dropout: float = 0.5,
        regres: bool = True,
        bidirectional: bool = False,
        loss: Union[str, Callable] = "mse",
        optim: str = "adam",
        lr: float = 0.001,
        lr_scheduler: Union[dict, str] = None,
        lr_scheduler_args: dict = None,
        l2: float = 0.0,
    ):
        # The classifier is created inline and passed straight to the base
        # class, which receives it as its ``layers`` argument.
        super().__init__(
            layers=StateClassifier(
                feature_size=feature_size,
                n_state=n_state,
                hidden_size=hidden_size,
                rnn=rnn,
                dropout=dropout,
                regres=regres,
                bidirectional=bidirectional,
            ),
            loss=loss,
            optim=optim,
            lr=lr,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            l2=l2,
        )
        # Must be called from ``__init__`` itself.
        # NOTE(review): presumably the Lightning helper that records the
        # constructor arguments -- confirm against the ``Net`` base class.
        self.save_hyperparameters()

        # One independent metric object per (stage, metric) pair, stored as
        # ``<stage>_<suffix>`` (e.g. ``val_auroc``). Registration order is
        # acc, pre, rec, auroc within each stage.
        metric_classes = {
            "acc": Accuracy,
            "pre": Precision,
            "rec": Recall,
            "auroc": AUROC,
        }
        for stage in ("train", "val", "test"):
            for suffix, metric_cls in metric_classes.items():
                setattr(self, f"{stage}_{suffix}", metric_cls(task="binary"))

    def forward(self, *args, **kwargs) -> th.Tensor:
        """
        Pass all arguments through to ``self.net`` and return its output.

        ``self.net`` is not assigned in this class; it is expected to be
        provided by the ``Net`` base class -- TODO confirm.
        """
        return self.net(*args, **kwargs)

    def step(self, batch, batch_idx, stage):
        """
        Compute the loss for one batch and update/log the stage metrics.

        Args:
            batch: A pair ``(x, y)`` of inputs and targets.
            batch_idx: Unused by this method.
            stage (str): Prefix selecting which metrics to update; the
                constructor registers ``"train"``, ``"val"`` and ``"test"``.

        Returns:
            The value returned by ``self.loss(y_hat, y)``.
        """
        inputs, targets = batch
        outputs = self(inputs)
        loss_value = self.loss(outputs, targets)

        for suffix in ("acc", "pre", "rec", "auroc"):
            metric_name = f"{stage}_{suffix}"
            metric = getattr(self, metric_name)
            # Metrics are fed column 1 of the model output and the targets
            # cast to integers.
            metric(outputs[:, 1], targets.long())
            # The metric object itself is passed to ``self.log``.
            self.log(metric_name, metric)

        return loss_value