# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from .. import builder
from ..registry import HEADS


@HEADS.register_module
class LatentPredictHead(nn.Module):
    '''Head that predicts a target latent from an input latent.

    The input is passed through a predictor module, both the prediction and
    the target are L2-normalised along dim 1, and the loss is ``-2`` times
    the sum of their element-wise product.

    Args:
        predictor: Config handed to ``builder.build_neck`` to construct the
            predictor module.
        size_average (bool): If True, the summed loss is divided by
            ``input.size(0)``. Default: True.
    '''

    def __init__(self, predictor, size_average=True):
        super().__init__()
        self.predictor = builder.build_neck(predictor)
        self.size_average = size_average

    def init_weights(self, init_linear='normal'):
        '''Initialise the predictor, forwarding ``init_linear`` to it.'''
        self.predictor.init_weights(init_linear=init_linear)

    def forward(self, input, target):
        '''Compute the prediction loss.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': Tensor}`` holding the scalar loss.
        '''
        normalize = nn.functional.normalize

        # The predictor is called with a list and its first output is used.
        prediction = self.predictor([input])[0]
        unit_prediction = normalize(prediction, dim=1)
        unit_target = normalize(target, dim=1)

        similarity = (unit_prediction * unit_target).sum()
        loss = -2 * similarity
        if self.size_average:
            batch_size = input.size(0)
            loss /= batch_size
        return {'loss': loss}


@HEADS.register_module
class LatentClsHead(nn.Module):
    '''Linear classification head trained on its own pseudo-labels.

    A single linear layer scores both the input and the target features. The
    argmax of the target scores (computed without gradient) is used as the
    label for a cross-entropy loss on the input scores.

    Args:
        predictor: Object exposing ``in_channels`` and ``num_classes``
            attributes, used as the input and output sizes of the linear
            layer.
    '''

    def __init__(self, predictor):
        super().__init__()
        in_features = predictor.in_channels
        out_features = predictor.num_classes
        self.predictor = nn.Linear(in_features, out_features)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self, init_linear='normal'):
        '''Initialise the linear layer with ``normal_init`` (std=0.01).

        ``init_linear`` is accepted but not used by this head.
        '''
        normal_init(self.predictor, std=0.01)

    def forward(self, input, target):
        '''Compute the cross-entropy loss against target-derived labels.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': Tensor}`` holding the cross-entropy loss.
        '''
        input_scores = self.predictor(input)

        # Labels come from the target branch; no gradient flows through it.
        with torch.no_grad():
            target_scores = self.predictor(target)
            label = torch.argmax(target_scores, dim=1).detach()

        loss = self.criterion(input_scores, label)
        return {'loss': loss}
