# Copyright (c) Facebook, Inc. and its affiliates.

import torch
import torch.nn as nn
import torch.nn.functional as functional
from inplace_abn import ABN
from models.util import try_index


class DeeplabV3(nn.Module):
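    """DeepLabV3 head (ASPP): four parallel branches (a 1x1 convolution and
    three 3x3 atrous convolutions) plus an image-level pooling branch.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the head.
        hidden_channels: channels of each intermediate branch.
        dilations: dilation rates of the three 3x3 atrous convolutions.
        norm_act: normalization + activation module (defaults to ABN).
        pooling_size: optional fixed pooling window (an int or an
            (height, width) pair) used at inference time; if None, global
            average pooling is used instead.
    """
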
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels=256,
        dilations=(12, 24, 36),
        norm_act=ABN,
        pooling_size=None,
    ):
        super().__init__()
        self.pooling_size = pooling_size

        # Parallel ASPP branches: a 1x1 projection plus three 3x3 atrous
        # convolutions with the given dilation rates.
        self.map_convs = nn.ModuleList(
            [
                nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
                nn.Conv2d(
                    in_channels,
                    hidden_channels,
                    3,
                    bias=False,
                    dilation=dilations[0],
                    padding=dilations[0],
                ),
                nn.Conv2d(
                    in_channels,
                    hidden_channels,
                    3,
                    bias=False,
                    dilation=dilations[1],
                    padding=dilations[1],
                ),
                nn.Conv2d(
                    in_channels,
                    hidden_channels,
                    3,
                    bias=False,
                    dilation=dilations[2],
                    padding=dilations[2],
                ),
            ]
        )
        self.map_bn = norm_act(hidden_channels * 4)

        # Image-level context branch: global average pooling followed by a
        # 1x1 projection.
        self.global_pooling_conv = nn.Conv2d(
            in_channels, hidden_channels, 1, bias=False
        )
        self.global_pooling_bn = norm_act(hidden_channels)

        # 1x1 reductions mapping both paths to the output channel count.
        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.reset_parameters(self.map_bn.activation, self.map_bn.activation_param)

    def reset_parameters(self, activation, slope):
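        """Initialize convolutions with Xavier-normal weights scaled by the
        gain of the given activation, and reset the ABN affine parameters."""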
        gain = nn.init.calculate_gain(activation, slope)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ABN):
                if hasattr(m, "weight") and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
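        """Run the four ASPP branches and the pooled context branch on x,
        fuse them, and return the normalized, activated output."""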
        # Map convolutions
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.red_conv(out)

        # Global pooling
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.pool_red_conv(pool)
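        # With global pooling the context is a 1x1 map; tile it back to the
        # input's spatial size so it can be added to the ASPP output.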
        if self.training or self.pooling_size is None:
            pool = pool.repeat(1, 1, x.size(2), x.size(3))

        # Fuse the two paths, then apply the final norm + activation.
        out += pool
        out = self.red_bn(out)
        return out

    def _global_pooling(self, x):
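        """Pool context features: global average pooling during training (or
        when pooling_size is None), otherwise stride-1 average pooling over a
        fixed window, capped at the input resolution, with replicate padding
        so the spatial size is preserved."""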
        if self.training or self.pooling_size is None:
            pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
            pool = pool.view(x.size(0), x.size(1), 1, 1)
        else:
            pooling_size = (
                min(try_index(self.pooling_size, 0), x.shape[2]),
                min(try_index(self.pooling_size, 1), x.shape[3]),
            )
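            # functional.pad expects (left, right, top, bottom). Even window
            # sizes need one extra pixel on the right/bottom so the padded
            # output matches the input resolution.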
            padding = (
                (pooling_size[1] - 1) // 2,
                (pooling_size[1] - 1) // 2
                if pooling_size[1] % 2 == 1
                else (pooling_size[1] - 1) // 2 + 1,
                (pooling_size[0] - 1) // 2,
                (pooling_size[0] - 1) // 2
                if pooling_size[0] % 2 == 1
                else (pooling_size[0] - 1) // 2 + 1,
            )

            pool = functional.avg_pool2d(x, pooling_size, stride=1)
            pool = functional.pad(pool, pad=padding, mode="replicate")
        return pool
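

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module), assuming
    # `inplace_abn` is installed. The 2048 input channels are illustrative,
    # e.g. the last stage of a ResNet-50/101 backbone.
    head = DeeplabV3(in_channels=2048, out_channels=256)
    head.eval()
    with torch.no_grad():
        out = head(torch.randn(1, 2048, 32, 32))
    print(out.shape)  # torch.Size([1, 256, 32, 32])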
