aiops/ContraLSP/mortalty/classifier.py (56 lines of code) (raw):

from typing import Callable, Union

import torch as th
from tint.models import Net
from torchmetrics import Accuracy, Precision, Recall, AUROC

from hmm.classifier import StateClassifier


class MimicClassifierNet(Net):
    """``Net`` wrapper around a ``StateClassifier`` with binary metrics.

    The constructor builds a ``StateClassifier`` from the model-related
    arguments, hands it to ``Net`` as ``layers`` together with the
    optimisation-related arguments, and attaches one ``Accuracy``,
    ``Precision``, ``Recall`` and ``AUROC`` (all ``task="binary"``) per
    stage (``train``, ``val``, ``test``) as ``<stage>_<metric>`` attributes.

    Args:
        feature_size: Forwarded to ``StateClassifier``.
        n_state: Forwarded to ``StateClassifier``.
        hidden_size: Forwarded to ``StateClassifier``.
        rnn: Forwarded to ``StateClassifier``. Default ``"GRU"``.
        dropout: Forwarded to ``StateClassifier``. Default ``0.5``.
        regres: Forwarded to ``StateClassifier``. Default ``True``.
        bidirectional: Forwarded to ``StateClassifier``. Default ``False``.
        loss: Forwarded to ``Net``. Default ``"mse"``.
        optim: Forwarded to ``Net``. Default ``"adam"``.
        lr: Forwarded to ``Net``. Default ``0.001``.
        lr_scheduler: Forwarded to ``Net``. Default ``None``.
        lr_scheduler_args: Forwarded to ``Net``. Default ``None``.
        l2: Forwarded to ``Net``. Default ``0.0``.
    """

    def __init__(
        self,
        feature_size: int,
        n_state: int,
        hidden_size: int,
        rnn: str = "GRU",
        dropout: float = 0.5,
        regres: bool = True,
        bidirectional: bool = False,
        loss: Union[str, Callable] = "mse",
        optim: str = "adam",
        lr: float = 0.001,
        lr_scheduler: Union[dict, str] = None,
        lr_scheduler_args: dict = None,
        l2: float = 0.0,
    ):
        classifier = StateClassifier(
            feature_size=feature_size,
            n_state=n_state,
            hidden_size=hidden_size,
            rnn=rnn,
            dropout=dropout,
            regres=regres,
            bidirectional=bidirectional,
        )

        super().__init__(
            layers=classifier,
            loss=loss,
            optim=optim,
            lr=lr,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args,
            l2=l2,
        )

        # NOTE(review): presumably this records the constructor arguments by
        # inspecting this frame, so it is kept directly in ``__init__`` and
        # ahead of any extra local variables -- confirm against ``Net``.
        self.save_hyperparameters()

        # One independent metric object per (stage, metric) pair, exposed as
        # ``<stage>_<suffix>`` (e.g. ``val_auroc``).
        metric_factories = {
            "acc": Accuracy,
            "pre": Precision,
            "rec": Recall,
            "auroc": AUROC,
        }
        for stage in ("train", "val", "test"):
            for suffix, factory in metric_factories.items():
                setattr(self, f"{stage}_{suffix}", factory(task="binary"))

    def forward(self, *args, **kwargs) -> th.Tensor:
        """Pass all arguments straight through to ``self.net``.

        NOTE(review): ``self.net`` is not assigned in this class; presumably
        ``Net`` sets it from ``layers`` -- confirm.
        """
        return self.net(*args, **kwargs)

    def step(self, batch, batch_idx, stage):
        """Compute the loss for one batch and update/log the stage metrics.

        Args:
            batch: An ``(x, y)`` pair; ``x`` is fed to the model and ``y`` is
                the target.
            batch_idx: Unused here.
            stage: Prefix selecting the metric attributes, i.e. one of
                ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The value of ``self.loss`` on the model output and ``y``.
        """
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.loss(outputs, targets)

        for suffix in ("acc", "pre", "rec", "auroc"):
            name = f"{stage}_{suffix}"
            metric = getattr(self, name)
            # Metrics receive column 1 of the output (presumably the
            # positive-class score -- confirm) and integer targets.
            metric(outputs[:, 1], targets.long())
            self.log(name, metric)

        return loss