# aiops/ContrastiveLearningLogClustering/utils/losses.py
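"""Loss functions for contrastive log clustering.

``MNR_Hyper_Loss`` combines a MultipleNegativesRanking-style contrastive
loss (in-batch negatives) with an auxiliary term that pulls each log's
embedding toward the precomputed center of its event template.
"""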
from typing import Dict, Iterable

import torch
import torch.nn as nn
from torch import Tensor

from sentence_transformers import SentenceTransformer, util

def calculate_center(model, all_event_log):
    """Compute a mean embedding ("center") for each event template.

    ``all_event_log`` maps an event id to the list of raw log messages
    belonging to that event. Embeddings are L2-normalized before averaging,
    so each center is the mean direction of its event's logs (note that the
    mean of unit vectors is generally not itself unit-length).
    """
    event_center = {}
    for event in all_event_log:
        corpus = all_event_log[event]
        with torch.no_grad():
            embeddings = model.encode(
                corpus,
                convert_to_numpy=False,
                convert_to_tensor=True,
                normalize_embeddings=True,
            ).clone()
        event_center[event] = torch.sum(embeddings, dim=0) / len(corpus)
    return event_center
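
# A minimal usage sketch (the model name, event ids, and log texts below are
# hypothetical illustrations, not part of this repo):
#
#     model = SentenceTransformer("all-MiniLM-L6-v2")
#     all_event_log = {
#         "E1": ["Receiving block blk_1 src 10.0.0.1",
#                "Receiving block blk_2 src 10.0.0.2"],
#         "E2": ["PacketResponder 1 for block blk_3 terminating"],
#     }
#     event_center = calculate_center(model, all_event_log)
#     # event_center["E1"] is a 1-D tensor of size (embedding_dim,)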

class MNR_Hyper_Loss(nn.Module):
    """MultipleNegativesRanking loss plus a center-pulling regularizer.

    The contrastive term treats each (anchor, positive) pair in the batch as
    the only correct match and every other in-batch candidate as a negative.
    The auxiliary term pulls each anchor embedding toward the precomputed
    center of its event template, weighted by ``hyper_ratio``.
    """

    def __init__(self, model: SentenceTransformer, scale: float = 20,
                 similarity_fct=util.cos_sim, hyper_ratio: float = 0.01,
                 log_to_event=None, event_center=None):
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        self.hyper_ratio = hyper_ratio
        # Avoid mutable default arguments: fall back to fresh empty dicts.
        self.log_to_event = log_to_event if log_to_event is not None else {}
        self.event_center = event_center if event_center is not None else {}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        # Encode anchors (reps[0]) and their positives / hard negatives (reps[1:]).
        reps = [self.model(sentence_feature)['sentence_embedding']
                for sentence_feature in sentence_features]
        embeddings_a = reps[0]              # (b, dim)
        embeddings_b = torch.cat(reps[1:])  # (b * num_candidates, dim)

        # Contrastive term: score every anchor against every in-batch candidate.
        # The i-th anchor's positive sits at column i, so the target is the diagonal.
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale  # (b, b * num_candidates)
        MNR_labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        MNR_loss = self.cross_entropy_loss(scores, MNR_labels)
        if self.hyper_ratio == 0:
            return MNR_loss

        # Look up each anchor's event center via its (unpadded) token ids.
        center_embeddings = []
        for b in range(len(sentence_features[0]['input_ids'])):
            log_token = sentence_features[0]['input_ids'][b].cpu().numpy()
            token_mask = sentence_features[0]['attention_mask'][b].cpu().numpy()
            log_token = log_token[token_mask != 0]  # drop padding positions
            event_id = self.log_to_event[tuple(log_token.tolist())]
            center_embeddings.append(self.event_center[event_id].unsqueeze(dim=0))
        center_embeddings = torch.cat(center_embeddings, dim=0)

        # Regularizer: push each anchor's cosine similarity to its event center
        # toward 1 (MSE against an all-ones target).
        hyper_similarity = torch.cosine_similarity(embeddings_a, center_embeddings)
        hyper_labels = torch.ones_like(hyper_similarity)
        hyper_loss = self.mse_loss(hyper_similarity, hyper_labels)
        return MNR_loss + self.hyper_ratio * hyper_loss

    def get_config_dict(self):
        return {'scale': self.scale,
                'similarity_fct': self.similarity_fct.__name__,
                'hyper_ratio': self.hyper_ratio}
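

if __name__ == "__main__":
    # A minimal end-to-end sketch of how this loss might be wired into
    # SentenceTransformer training. Everything below (model name, event ids,
    # example logs) is a hypothetical illustration, not this repo's pipeline.
    from torch.utils.data import DataLoader
    from sentence_transformers import InputExample

    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Hypothetical parsed logs grouped by event template.
    all_event_log = {
        "E1": ["Receiving block blk_1 src 10.0.0.1",
               "Receiving block blk_2 src 10.0.0.2"],
        "E2": ["PacketResponder 1 for block blk_3 terminating"],
    }
    event_center = calculate_center(model, all_event_log)

    # Map each log's unpadded token-id tuple to its event id, mirroring the
    # lookup performed in MNR_Hyper_Loss.forward.
    log_to_event = {}
    for event_id, logs in all_event_log.items():
        tokenized = model.tokenize(logs)
        for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"]):
            log_to_event[tuple(ids[mask != 0].tolist())] = event_id

    # Toy positive pairs: each anchor paired with a log of the same event.
    train_examples = [
        InputExample(texts=["Receiving block blk_1 src 10.0.0.1",
                            "Receiving block blk_2 src 10.0.0.2"]),
        InputExample(texts=["PacketResponder 1 for block blk_3 terminating",
                            "PacketResponder 1 for block blk_3 terminating"]),
    ]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
    loss = MNR_Hyper_Loss(model, log_to_event=log_to_event,
                          event_center=event_center)
    model.fit(train_objectives=[(train_dataloader, loss)], epochs=1,
              show_progress_bar=False)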