# Copyright (c) Facebook, Inc. and its affiliates.

import sys
from collections import OrderedDict
from functools import partial

import torch.nn as nn
from inplace_abn import ABN
from modules import GlobalAvgPool2d, DenseModule

from .util import try_index


class DenseNet(nn.Module):
    """DenseNet built from a stem (``mod1``), four dense modules (``mod2``..``mod5``)
    separated by three transition modules (``tra2``..``tra4``), a final
    normalization/activation (``bn_out``) and an optional classifier head.

    Submodule attribute names determine state-dict keys; keep them stable.
    """

    def __init__(
        self,
        structure,
        norm_act=ABN,
        input_3x3=False,
        growth=32,
        theta=0.5,
        classes=0,
        dilation=1,
    ):
        """DenseNet

        Parameters
        ----------
        structure : list of int
            Number of layers in each of the four dense blocks of the network. The sequence is copied, so later
            changes to the caller's object do not affect this network (and vice versa).
        norm_act : callable
            Function to create normalization / activation Module.
        input_3x3 : bool
            If `True` use three `3x3` convolutions in the input module instead of a single `7x7` one.
        growth : int
            Number of channels in each layer, i.e. the "growth" factor of the DenseNet.
        theta : float
            Reduction factor for the transition blocks.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        dilation : int or list of int
            List of dilation factors, or `1` to ignore dilation. If the dilation factor for a module is greater than `1`
            skip the pooling in the transition block right before it.

        Raises
        ------
        ValueError
            If `structure` does not contain exactly four values.
        """
        # Copy instead of aliasing: the module-level ``net_densenet*`` factories are
        # ``functools.partial`` objects that pass the *same* list object to every
        # instance, so storing it by reference would let a mutation of one network's
        # ``structure`` leak into all later constructions.
        structure = list(structure)
        if len(structure) != 4:
            raise ValueError("Expected a structure with four values")

        super(DenseNet, self).__init__()
        self.structure = structure

        # Initial layers (stem). Both variants reduce resolution by 4x (stride-2 conv + stride-2 pool).
        # There is no norm/activation after the last stem conv; the transition blocks and ``bn_out``
        # below follow a pre-activation layout (norm before conv), so the first dense module presumably
        # normalizes its own input — DenseModule internals are not visible here, confirm.
        if input_3x3:
            layers = [
                ("conv1", nn.Conv2d(3, growth * 2, 3, stride=2, padding=1, bias=False)),
                ("bn1", norm_act(growth * 2)),
                (
                    "conv2",
                    nn.Conv2d(
                        growth * 2, growth * 2, 3, stride=1, padding=1, bias=False
                    ),
                ),
                ("bn2", norm_act(growth * 2)),
                (
                    "conv3",
                    nn.Conv2d(
                        growth * 2, growth * 2, 3, stride=1, padding=1, bias=False
                    ),
                ),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1)),
            ]
        else:
            layers = [
                ("conv1", nn.Conv2d(3, growth * 2, 7, stride=2, padding=3, bias=False)),
                ("pool", nn.MaxPool2d(3, stride=2, padding=1)),
            ]
        self.mod1 = nn.Sequential(OrderedDict(layers))

        in_channels = growth * 2
        for mod_id in range(4):
            d = try_index(dilation, mod_id)
            # Downsample in the transition only when this module is not dilated; the first
            # dense module has no transition in front of it.
            s = 2 if d == 1 and mod_id > 0 else 1

            # Create transition module: norm_act -> 1x1 conv compressing channels by `theta`
            # -> optional 2x2 average pooling.
            if mod_id > 0:
                out_channels = int(in_channels * theta)
                layers = [
                    ("bn", norm_act(in_channels)),
                    ("conv", nn.Conv2d(in_channels, out_channels, 1, bias=False)),
                ]
                if s == 2:
                    layers.append(("pool", nn.AvgPool2d(2, 2)))
                self.add_module(
                    "tra%d" % (mod_id + 1), nn.Sequential(OrderedDict(layers))
                )
                in_channels = out_channels

            # Create dense module; its output width is reported by the module itself.
            mod = DenseModule(
                in_channels, growth, structure[mod_id], norm_act=norm_act, dilation=d
            )
            self.add_module("mod%d" % (mod_id + 2), mod)
            in_channels = mod.out_channels

        # Pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = nn.Sequential(
                OrderedDict(
                    [
                        ("avg_pool", GlobalAvgPool2d()),
                        ("fc", nn.Linear(in_channels, classes)),
                    ]
                )
            )

    def forward(self, x):
        """Run the network.

        Parameters
        ----------
        x : torch.Tensor
            Input batch laid out as ``(N, C, H, W)`` with ``C == 3``, as required by ``conv1``.

        Returns
        -------
        torch.Tensor
            The feature map after ``bn_out`` when no classifier was built; otherwise the classifier
            output with ``classes`` features in its last dimension (presumably ``(N, classes)`` —
            depends on GlobalAvgPool2d's output shape, confirm).
        """
        x = self.mod1(x)
        x = self.mod2(x)
        x = self.tra2(x)
        x = self.mod3(x)
        x = self.tra3(x)
        x = self.mod4(x)
        x = self.tra4(x)
        x = self.mod5(x)
        x = self.bn_out(x)

        # The head is only registered when `classes != 0` was passed to the constructor.
        if hasattr(self, "classifier"):
            x = self.classifier(x)
        return x


# Variant name -> keyword arguments for DenseNet. Every entry is exposed below as a
# module-level constructor named ``net_densenet<variant>`` (e.g. ``net_densenet121``).
_NETS = {
    "121": {"structure": [6, 12, 24, 16]},
    "169": {"structure": [6, 12, 32, 32]},
    "201": {"structure": [6, 12, 48, 32]},
    "264": {"structure": [6, 12, 64, 48]},
}

# Build the constructors with a comprehension rather than a top-level ``for`` loop so
# that no loop variables (variant name, params, ...) leak into the module namespace.
_CONSTRUCTORS = {
    "net_densenet" + variant: partial(DenseNet, **params)
    for variant, params in _NETS.items()
}
# Bind each constructor as a module attribute (same effect as setattr on this module).
globals().update(_CONSTRUCTORS)
# Public API: exactly the generated constructors, in `_NETS` order.
__all__ = list(_CONSTRUCTORS)
