in aiops/ContraAD/model/PointAttention.py [0:0]
# Training losses available to ``cal_metric``; each maps (dis, z_score) -> scalar loss.
_CAL_METRIC_LOSSES = {
    # Mean absolute error between the distance scores and the target scores.
    'z-score_mae': lambda dis, z_score: F.l1_loss(dis, z_score, reduction='mean'),
    # Underscore spelling, consistent with the other 'z_score_*' modes.
    'z_score_mae': lambda dis, z_score: F.l1_loss(dis, z_score, reduction='mean'),
    # Mean squared error between the distance scores and the target scores.
    'z_score_mse': lambda dis, z_score: F.mse_loss(dis, z_score, reduction='mean'),
    # NOTE(review): where dis > z_score this penalises dis itself rather than
    # (dis - z_score); kept as-is, confirm that this asymmetry is intended.
    'z_score_clamp': lambda dis, z_score: (
        torch.where(dis > z_score, dis, z_score - dis).sum(dim=1).mean()
    ),
    # Unsupervised: minimise the total distance score (z_score is ignored).
    'distance': lambda dis, z_score: dis.sum(dim=1).mean(),
}


def _cal_metric_distance(x, soft=True, soft_mode='min'):
    """Return each point's summed Euclidean distance to every point in its window.

    Args:
        x: batched point set passed to ``torch.cdist(x, x)``; presumably
            ``(batch, win, dim)`` so the result is ``(batch, win)`` — TODO confirm.
        soft: when True, rescale each row and pass it through ``normalize``.
        soft_mode: ``'sum'`` divides each row by its total, ``'min'`` by its
            minimum. Only consulted when ``soft`` is True.

    Raises:
        ValueError: if ``soft`` is True and ``soft_mode`` is not recognised
            (previously normalisation was silently skipped).
    """
    dis = torch.cdist(x, x).sum(2)  # batch,win
    if not soft:
        return dis
    if soft_mode == 'sum':
        scale = dis.sum(dim=1, keepdim=True)
    elif soft_mode == 'min':
        scale = dis.min(dim=1, keepdim=True).values
    else:
        raise ValueError(
            f"unknown soft_mode {soft_mode!r}; expected 'sum' or 'min'"
        )
    # keepdim=True lets the (batch, 1) scale broadcast over the window.
    # NOTE(review): a row whose scale is 0 (all points identical) yields
    # 0/0 = nan here; unchanged from the original behaviour.
    return normalize(dis / scale)  # batch,win


def cal_metric(x, z_score, mode='z-score', soft=True, soft_mode='min', model_mode='train'):
    """Score points by their summed distance to their window and optionally compute a loss.

    Args:
        x: point embeddings fed to ``torch.cdist``.
        z_score: target scores used by the ``'z-score_mae'`` / ``'z_score_mae'``,
            ``'z_score_mse'`` and ``'z_score_clamp'`` modes. Ignored by
            ``'distance'`` and whenever ``model_mode != 'train'``.
        mode: which loss to compute; one of the keys of ``_CAL_METRIC_LOSSES``.
            Note that the default ``'z-score'`` is not one of them.
        soft: whether to rescale and ``normalize`` the distance scores.
        soft_mode: ``'sum'`` or ``'min'`` row rescaling, used when ``soft`` is True.
        model_mode: ``'train'`` returns ``(loss, dis)``; any other value
            returns ``dis`` alone.

    Returns:
        ``(loss, dis)`` in train mode, otherwise ``dis``.

    Raises:
        ValueError: for an unknown ``mode`` (previously the function fell
            through and returned ``None``), or an unknown ``soft_mode`` while
            ``soft`` is True.
    """
    # Validate the mode before doing any tensor work.
    loss_fn = _CAL_METRIC_LOSSES.get(mode)
    if loss_fn is None:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(_CAL_METRIC_LOSSES)}"
        )
    dis = _cal_metric_distance(x, soft=soft, soft_mode=soft_mode)
    if model_mode != 'train':
        return dis
    return loss_fn(dis, z_score), dis