# Copyright (c) Facebook, Inc. and its affiliates.

import sys
from collections import OrderedDict
from functools import partial

import torch.nn as nn
from inplace_abn import ABN
from modules import IdentityResidualBlock, GlobalAvgPool2d


class WiderResNet(nn.Module):
    def __init__(self, structure, norm_act=ABN, classes=0):
        """Wider ResNet with pre-activation (identity mapping) blocks

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the six modules of the network.
        norm_act : callable
            Function to create normalization / activation Module.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        """
        super(WiderResNet, self).__init__()
        self.structure = structure

        if len(structure) != 6:
            raise ValueError("Expected a structure with six values")

        # Initial layers
        self.mod1 = nn.Sequential(
            OrderedDict(
                [("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))]
            )
        )

        # Groups of residual blocks
        in_channels = 64
        channels = [
            (128, 128),
            (256, 256),
            (512, 512),
            (512, 1024),
            (512, 1024, 2048),
            (1024, 2048, 4096),
        ]
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(
                            in_channels, channels[mod_id], norm_act=norm_act
                        ),
                    )
                )

                # Update channels and p_keep
                in_channels = channels[mod_id][-1]

            # Create module
            if mod_id <= 4:
                self.add_module(
                    "pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)
                )
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

        # Pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    def forward(self, img):
        out = self.mod1(img)
        out = self.mod2(self.pool2(out))
        out = self.mod3(self.pool3(out))
        out = self.mod4(self.pool4(out))
        out = self.mod5(self.pool5(out))
        out = self.mod6(self.pool6(out))
        out = self.mod7(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            out = self.classifier(out)

        return out


class WiderResNetA2(nn.Module):
    def __init__(self, structure, norm_act=ABN, classes=0, dilation=False):
        """Wider ResNet with pre-activation (identity mapping) blocks

        This variant uses down-sampling by max-pooling in the first two blocks and by strided convolution in the others.

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the six modules of the network.
        norm_act : callable
            Function to create normalization / activation Module.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        dilation : bool
            If `True` apply dilation to the last three modules and change the down-sampling factor from 32 to 8.
        """
        super(WiderResNetA2, self).__init__()
        self.structure = structure
        self.dilation = dilation

        if len(structure) != 6:
            raise ValueError("Expected a structure with six values")

        # Initial layers
        self.mod1 = nn.Sequential(
            OrderedDict(
                [("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))]
            )
        )

        # Groups of residual blocks
        in_channels = 64
        channels = [
            (128, 128),
            (256, 256),
            (512, 512),
            (512, 1024),
            (512, 1024, 2048),
            (1024, 2048, 4096),
        ]
        for mod_id, num in enumerate(structure):
            # Create blocks for module
            blocks = []
            for block_id in range(num):
                if not dilation:
                    dil = 1
                    stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1
                else:
                    if mod_id == 3:
                        dil = 2
                    elif mod_id > 3:
                        dil = 4
                    else:
                        dil = 1
                    stride = 2 if block_id == 0 and mod_id == 2 else 1

                if mod_id == 4:
                    drop = partial(nn.Dropout2d, p=0.3)
                elif mod_id == 5:
                    drop = partial(nn.Dropout2d, p=0.5)
                else:
                    drop = None

                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(
                            in_channels,
                            channels[mod_id],
                            norm_act=norm_act,
                            stride=stride,
                            dilation=dil,
                            dropout=drop,
                        ),
                    )
                )

                # Update channels and p_keep
                in_channels = channels[mod_id][-1]

            # Create module
            if mod_id < 2:
                self.add_module(
                    "pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)
                )
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

        # Pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    def forward(self, img):
        out = self.mod1(img)
        out = self.mod2(self.pool2(out))
        out = self.mod3(self.pool3(out))
        out = self.mod4(out)
        out = self.mod5(out)
        out = self.mod6(out)
        out = self.mod7(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            return self.classifier(out)
        else:
            return out


_NETS = {
    "16": {"structure": [1, 1, 1, 1, 1, 1]},
    "20": {"structure": [1, 1, 1, 3, 1, 1]},
    "38": {"structure": [3, 3, 6, 3, 1, 1]},
}

__all__ = []
for name, params in _NETS.items():
    net_name = "net_wider_resnet" + name
    setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params))
    __all__.append(net_name)
for name, params in _NETS.items():
    net_name = "net_wider_resnet" + name + "_a2"
    setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params))
    __all__.append(net_name)
