# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility modules."""

from __future__ import absolute_import, division, print_function
import math

import torch

from easycv.core.sailfish.function import (all_cat, all_log_softmax,
                                           all_nll_loss, all_sum,
                                           shard_correct_mask,
                                           shard_correct_predictions,
                                           shard_target_and_mask,
                                           shard_topk_correct_predictions)
from easycv.framework.errors import NotImplementedError, ValueError


class DistributedParallel:
    """Base class of parallelism."""

    def __init__(self, rank, world_size):
        self._rank = rank
        self._world_size = world_size

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def correct_mask(self, target, inputs):
        mask = torch.zeros(
            inputs.size(), device=inputs.device, dtype=inputs.dtype)
        mask.scatter_(1, target.view(-1, 1).long(), 1)
        return mask

    def correct_predictions(self, target, logits, k=1):
        if k == 1:
            pred = torch.max(logits, dim=1)[1]
            return (pred == target.view(-1, 1)).sum().item()
        pred = torch.topk(logits, k, dim=1)[1]
        return (pred == target.view(-1, 1)).sum().item()

    def xavier_uniform_(self, weight, gain=1.):
        return torch.nn.init.xavier_uniform_(weight, gain=gain)


class ModelParallel(DistributedParallel):
    """All-to-All Model Parallelism."""

    def gather(self, inputs, dim=0, requires_grad=True):
        if requires_grad:
            return all_cat(
                inputs, dim=dim, rank=self.rank, world_size=self.world_size)
        all_inputs = [
            torch.zeros(
                inputs.size(), dtype=inputs.dtype, device=inputs.device)
            for _ in range(self.world_size)
        ]
        torch.distributed.all_gather(all_inputs, inputs)
        return torch.cat(all_inputs, dim=dim)

    def gather_target(self, target):
        return self.gather(target, requires_grad=False)

    def reduce_sum(self, inputs):
        return all_sum(inputs)

    def log_softmax(self, logits, epsilon=1e-8):
        return all_log_softmax(logits, epsilon=epsilon)

    def nll_loss(self, inputs, correct_mask):
        return all_nll_loss(inputs, correct_mask)

    def correct_mask(self, target, inputs):
        return shard_correct_mask(target, inputs, rank=self.rank)

    def target_and_mask(self, target, output_features):
        return shard_target_and_mask(target, output_features, rank=self.rank)

    def correct_predictions(self, target, logits, k=1):
        if k == 1:
            return shard_correct_predictions(
                target, logits, world_size=self.world_size)
        return shard_topk_correct_predictions(
            target, logits, k, world_size=self.world_size)


class ParameterInitializer:
    r"""Base class for parameter initializer."""

    def __call__(self, param, shard_rank=0, num_shards=1):
        raise NotImplementedError


class ZerosInitializer(ParameterInitializer):

    def __call__(self, param, parallel=None):
        torch.nn.init.zeros_(param)


class OnesInitializer(ParameterInitializer):

    def __call__(self, param, parallel=None):
        torch.nn.init.ones_(param)


class XavierUniformInitializer(ParameterInitializer):

    def __init__(self, gain=1.):
        self.gain = gain

    def __call__(self, param, parallel=None):
        if isinstance(parallel, ModelParallel):
            if param.dim() != 2:
                raise ValueError(
                    'param with dimensions other than 2 not supported')
            r = param.size(1) + param.size(0) * parallel.world_size
            a = self.gain * math.sqrt(3.0) * math.sqrt(2.0 / float(r))
            torch.nn.init.uniform_(param, -a, a)
        else:
            torch.nn.init.xavier_uniform_(param, gain=self.gain)


class KaimingUniformInitializer(ParameterInitializer):

    def __init__(self, bound):
        self.bound = bound

    def __call__(self, param, parallel=None):
        torch.nn.init.kaiming_uniform_(param, a=self.bound)


class BiasUniformInitializer(ParameterInitializer):

    def __init__(self, weight_in_features):
        self.weight_in_features = weight_in_features

    def __call__(self, param, parallel=None):
        a = 1 / math.sqrt(float(self.weight_in_features))
        torch.nn.init.uniform_(param, -a, a)


class RenormUniformInitializer(ParameterInitializer):

    def __init__(self, maxnorm=1e-5, scale=1e5):
        self.maxnorm = maxnorm
        self.scale = scale

    def __call__(self, param, parallel=None):
        param.data.uniform_(-1, 1).renorm_(
            2, 0, maxnorm=self.maxnorm).mul_(self.scale)


class NormalInitializer(ParameterInitializer):

    def __call__(self, param, parallel=None):
        torch.nn.init.normal_(param)
