def cal_metric()

in aiops/ContraAD/model/PointAttention.py [0:0]


def _soft_distance(x, soft, soft_mode):
    """Score each point by its summed Euclidean distance to every point in its window.

    ``torch.cdist(x, x)`` gives pairwise distances and summing over axis 2
    yields one score per point -- ``(batch, win)`` for an assumed
    ``(batch, win, dim)`` input (TODO confirm with callers).

    When ``soft`` is true the scores are divided by a per-window statistic
    (``'sum'`` or ``'min'`` over axis 1) and passed through the module-level
    ``normalize`` helper.  Any other ``soft_mode`` returns the raw scores
    unchanged, matching the previous behaviour.
    """
    dis = torch.cdist(x, x).sum(2)
    if not soft:
        return dis
    if soft_mode == 'sum':
        val = dis.sum(dim=1, keepdim=True)
    elif soft_mode == 'min':
        val = dis.min(dim=1, keepdim=True).values
    else:
        return dis
    # If every point in a window is identical, all of its scores are 0, so the
    # denominator is 0 and 0/0 would produce NaN.  Flooring it at the smallest
    # positive normal float changes nothing for any non-degenerate window.
    # (What ``normalize`` does with an all-zero row is not visible here.)
    val = val.clamp_min(torch.finfo(dis.dtype).tiny)
    # keepdim=True leaves val with a singleton axis so it broadcasts over win.
    return normalize(dis / val)  # batch,win


def _mae_loss(dis, z_score):
    """Mean absolute error between the distance scores and ``z_score``."""
    return F.l1_loss(dis, z_score, reduction='mean')


def _mse_loss(dis, z_score):
    """Mean squared error between the distance scores and ``z_score``."""
    return F.mse_loss(dis, z_score, reduction='mean')


def _clamp_loss(dis, z_score):
    """Sum over axis 1 of ``dis`` where it exceeds ``z_score``, else ``z_score - dis``; then the mean."""
    # NOTE(review): above the target the penalty is ``dis`` itself rather than
    # ``dis - z_score``.  Kept unchanged -- confirm this is intended.
    return torch.where(dis > z_score, dis, z_score - dis).sum(dim=1).mean()


def _distance_loss(dis, z_score):
    """Sum of the scores over axis 1, then the mean; ``z_score`` is ignored."""
    return dis.sum(dim=1).mean()


# Canonical (underscore-separated) mode name -> loss function.
_CAL_METRIC_LOSSES = {
    'z_score_mae': _mae_loss,
    'z_score_mse': _mse_loss,
    'z_score_clamp': _clamp_loss,
    'distance': _distance_loss,
}


def cal_metric(x, z_score, mode='z-score', soft=True, soft_mode='min', model_mode='train'):
    """Compute per-point distance scores for ``x`` and, in training, a loss on them.

    Args:
        x: tensor passed to ``torch.cdist(x, x)``; presumably
            ``(batch, win, dim)`` -- TODO confirm with callers.
        z_score: target compared with the distance scores by the
            ``z-score`` modes.  It must be shape-compatible with them.
            It is ignored by ``'distance'``.
        mode: one of ``'z-score_mae'`` (L1 loss), ``'z_score_mse'`` (MSE
            loss), ``'z_score_clamp'`` or ``'distance'``.  Hyphens and
            underscores are interchangeable.  Note that the default
            ``'z-score'`` names none of these, so it must be overridden.
        soft: whether to divide the scores by a per-window statistic and
            apply ``normalize``.
        soft_mode: ``'sum'`` or ``'min'`` -- the statistic used when ``soft``
            is true.  Any other value leaves the scores unnormalised.
        model_mode: ``'train'`` returns ``(loss, dis)``; any other value
            returns only ``dis``.

    Returns:
        ``(loss, dis)`` when ``model_mode == 'train'``, otherwise ``dis``.

    Raises:
        ValueError: if ``mode`` is not a supported mode.  Previously this
            silently returned ``None``.
    """
    key = mode.replace('-', '_') if isinstance(mode, str) else mode
    loss_fn = _CAL_METRIC_LOSSES.get(key)
    if loss_fn is None:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of "
            f"'z-score_mae', 'z_score_mse', 'z_score_clamp', 'distance'"
        )
    dis = _soft_distance(x, soft, soft_mode)
    if model_mode == 'train':
        return loss_fn(dis, z_score), dis
    return dis