# Copyright (c) Alibaba, Inc. and its affiliates.
# Reference: https://github.com/Alpha-VL/FastConvMAE
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

from easycv.models.registry import BACKBONES
from easycv.models.utils import DropPath
from easycv.models.utils.pos_embed import get_2d_sincos_pos_embed
from .vision_transformer import Block


class PatchEmbed(nn.Module):
    """Convolutional patch embedding followed by LayerNorm and GELU.

    Args:
        img_size (int | list | tuple): Expected input height and width. A
            scalar is used for both dimensions.
        patch_size (int | list | tuple): Kernel size and stride of the
            projection convolution. A scalar is used for both dimensions.
        in_channels (int): The num of input channels. Default: 3
        embed_dim (int): The num of output (embedding) channels. Default: 768
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768):
        super().__init__()
        # Scalars are expanded to (h, w) pairs; lists/tuples are kept as given.
        self.img_size = img_size if isinstance(
            img_size, (list, tuple)) else (img_size, img_size)
        self.patch_size = patch_size if isinstance(
            patch_size, (list, tuple)) else (patch_size, patch_size)

        # Number of patches along each axis (floor division).
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols

        # Non-overlapping patches: stride equals kernel size.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.act = nn.GELU()

    def forward(self, x):
        """Embed a 4-D input whose last two dims equal ``img_size``.

        Returns the activated feature map with channels on dim 1.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[
            1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        feat = self.proj(x)
        # LayerNorm normalizes the last dim, so move channels last and back.
        feat = feat.permute(0, 2, 3, 1)
        feat = self.norm(feat)
        feat = feat.permute(0, 3, 1, 2)
        return self.act(feat)


class ConvMlp(nn.Module):
    """Two-layer MLP implemented with 1x1 convolutions.

    Args:
        in_features (int): Number of input channels.
        hidden_features (int, optional): Channels after the first
            convolution. Falls back to ``in_features`` when falsy.
        out_features (int, optional): Number of output channels. Falls back
            to ``in_features`` when falsy.
        act_layer (callable): Zero-argument factory for the activation
            module. Default: nn.GELU
        drop (float): Dropout probability. Default: 0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> act -> dropout -> fc2 -> dropout."""
        # The same dropout module is used at both positions.
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x


class ConvBlock(nn.Module):
    """Residual block made of a depthwise 5x5 convolution and a ``ConvMlp``.

    Both sub-layers are preceded by a LayerNorm over the channel dim and
    wrapped in a residual connection with optional stochastic depth.

    Args:
        dim (int): Number of input/output channels.
        mlp_ratio (float): Hidden width of the MLP as a multiple of ``dim``.
        drop (float): Dropout probability inside the MLP.
        drop_path (float): Stochastic depth rate; ``nn.Identity`` is used
            when it is not greater than 0.
        act_layer (callable): Activation factory passed to ``ConvMlp``.
    """

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=1)
        # groups=dim makes this a per-channel (depthwise) convolution.
        self.attn = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    @staticmethod
    def _norm_channels(norm, x):
        # Move channels last for LayerNorm, then restore the original layout.
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x, mask=None):
        """Run the block.

        Args:
            x: Feature map with channels on dim 1.
            mask: ``None``, or an indexable of four masks that are multiplied
                with the features. Each mask is applied both before and after
                the depthwise convolution and the four results are summed.
        """
        feat = self.conv1(self._norm_channels(self.norm1, x))
        if mask is None:
            mixed = self.attn(feat)
        else:
            branches = [self.attn(mask[idx] * feat) for idx in range(4)]
            mixed = (
                mask[0] * branches[0] + mask[1] * branches[1] +
                mask[2] * branches[2] + mask[3] * branches[3])
        x = x + self.drop_path(self.conv2(mixed))

        mlp_out = self.mlp(self._norm_channels(self.norm2, x))
        x = x + self.drop_path(mlp_out)
        return x


@BACKBONES.register_module
class FastConvMAEViT(nn.Module):
    """ Fast ConvMAE framework is a superiorly fast masked modeling scheme via
    complementary masking and mixture of reconstructors based on the ConvMAE(https://arxiv.org/abs/2205.03892).

    The backbone runs two convolutional stages (``ConvBlock``) followed by a
    transformer stage (``Block``). When ``with_fuse`` is set, the outputs of
    the two convolutional stages are projected to the stage-3 width and added
    to the transformer output.

    Args:
        img_size (list | tuple): Input image size for three stages.
        patch_size (list | tuple): The patch size for three stages.
        in_channels (int): The num of input channels. Default: 3
        embed_dim (list | tuple): The dimensions of embedding for three stages.
        depth (list | tuple): depth for three stages.
        num_heads (int): Parallel attention heads
        mlp_ratio (list | tuple): Mlp expansion ratio.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_layer (nn.Module): normalization layer
        init_pos_embed_by_sincos: initialize pos_embed by sincos strategy
        with_fuse(bool): Whether to use fuse layers.
        global_pool: If True, average the output tokens and apply
            ``fc_norm``; otherwise apply ``norm`` to every token.

    """

    def __init__(
        self,
        img_size=[224, 56, 28],
        patch_size=[4, 2, 2],
        in_channels=3,
        embed_dim=[256, 384, 768],
        depth=[2, 2, 11],
        num_heads=12,
        mlp_ratio=[4, 4, 4],
        drop_rate=0.,
        drop_path_rate=0.1,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_pos_embed_by_sincos=True,
        with_fuse=True,
        global_pool=False,
    ):

        super().__init__()
        self.init_pos_embed_by_sincos = init_pos_embed_by_sincos
        self.with_fuse = with_fuse
        self.global_pool = global_pool

        assert len(img_size) == len(patch_size) == len(embed_dim) == len(
            mlp_ratio)

        # Product of the per-stage patch sizes.
        self.patch_size = patch_size[0] * patch_size[1] * patch_size[2]
        self.patch_embed1 = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size[0],
            in_channels=in_channels,
            embed_dim=embed_dim[0])
        self.patch_embed2 = PatchEmbed(
            img_size=img_size[1],
            patch_size=patch_size[1],
            in_channels=embed_dim[0],
            embed_dim=embed_dim[1])
        self.patch_embed3 = PatchEmbed(
            img_size=img_size[2],
            patch_size=patch_size[2],
            in_channels=embed_dim[1],
            embed_dim=embed_dim[2])

        self.patch_embed4 = nn.Linear(embed_dim[2], embed_dim[2])

        if with_fuse:
            self._make_fuse_layers(embed_dim)

        self.num_patches = self.patch_embed3.num_patches
        # The sin-cos table is fixed, so it is excluded from gradient updates;
        # otherwise the embedding is learned.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, embed_dim[2]),
            requires_grad=not init_pos_embed_by_sincos)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, sum(depth))

        self.blocks1 = nn.ModuleList([
            ConvBlock(
                dim=embed_dim[0],
                mlp_ratio=mlp_ratio[0],
                drop=drop_rate,
                drop_path=dpr[i],
            ) for i in range(depth[0])
        ])
        self.blocks2 = nn.ModuleList([
            ConvBlock(
                dim=embed_dim[1],
                mlp_ratio=mlp_ratio[1],
                drop=drop_rate,
                drop_path=dpr[depth[0] + i],
            ) for i in range(depth[1])
        ])
        self.blocks3 = nn.ModuleList([
            Block(
                dim=embed_dim[2],
                num_heads=num_heads,
                mlp_ratio=mlp_ratio[2],
                qkv_bias=True,
                qk_scale=None,
                drop=drop_rate,
                drop_path=dpr[depth[0] + depth[1] + i],
                norm_layer=norm_layer) for i in range(depth[2])
        ])

        # Exactly one of `norm` / `fc_norm` is a module; the other is None.
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim[-1])
            self.norm = None
        else:
            self.norm = norm_layer(embed_dim[-1])
            self.fc_norm = None

    def init_weights(self):
        """Initialize the position embedding and all Linear/LayerNorm layers."""
        if self.init_pos_embed_by_sincos:
            # initialize (and freeze) pos_embed by sin-cos embedding
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.num_patches**.5),
                cls_token=False)
            self.pos_embed.data.copy_(
                torch.from_numpy(pos_embed).float().unsqueeze(0))
        else:
            trunc_normal_(self.pos_embed, std=.02)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed3.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer used with ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _make_fuse_layers(self, embed_dim):
        """Create the convs that project stage-1/2 features to stage-3 width."""
        self.stage1_output_decode = nn.Conv2d(
            embed_dim[0], embed_dim[2], 4, stride=4)
        self.stage2_output_decode = nn.Conv2d(
            embed_dim[1], embed_dim[2], 2, stride=2)

    def random_masking(self, x, mask_ratio=None):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        The shuffled token indices are split into four consecutive groups of
        ``int(num_patches * (1 - mask_ratio))`` tokens each.

        Args:
            x: Tensor whose first dim is the batch size and whose device is
                used for the generated tensors.
            mask_ratio (float): Fraction of tokens hidden in each group. A
                value must be supplied even though the default is ``None``.

        Returns:
            tuple: ``(ids_keep, masks, ids_restore)`` where ``ids_keep`` is a
            list of four index tensors, ``masks`` is a list of four
            ``[N, num_patches]`` binary masks (0 is keep, 1 is remove) and
            ``ids_restore`` is the inverse of the shuffle permutation.
        """
        num_groups = 4
        N = x.shape[0]
        L = self.num_patches
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = []
        masks = []
        for group in range(num_groups):
            start = group * len_keep
            end = start + len_keep
            # The kept indices and the zeroed mask span use the same bounds,
            # so every group has the same length and the groups can be
            # concatenated along the batch dim later on.
            ids_keep.append(ids_shuffle[:, start:end])

            # generate the binary mask: 0 is keep, 1 is remove
            mask = torch.ones([N, L], device=x.device)
            mask[:, start:end] = 0
            # unshuffle to get the binary mask
            masks.append(torch.gather(mask, dim=1, index=ids_restore))

        return ids_keep, masks, ids_restore

    @staticmethod
    def _gather_groups(tokens, ids_keep):
        """Select each group's tokens and stack the groups along dim 0."""
        channels = tokens.shape[-1]
        return torch.cat([
            torch.gather(
                tokens,
                dim=1,
                index=ids.unsqueeze(-1).repeat(1, 1, channels))
            for ids in ids_keep
        ])

    @staticmethod
    def _upsample_keep_masks(masks, grid_size, target_size):
        """Turn token-level remove-masks into feature-map-level keep-masks.

        Each mask is reshaped to ``grid_size``, inverted (1 is keep) and every
        cell is replicated so the result covers ``target_size``. The output
        tensors have a singleton channel dim at position 1.
        """
        grid_h, grid_w = grid_size
        scale_h = target_size[0] // grid_h
        scale_w = target_size[1] // grid_w
        keep_masks = []
        for mask in masks:
            keep = 1 - mask.reshape(-1, grid_h, grid_w)
            keep = keep.repeat_interleave(scale_h, dim=1)
            keep = keep.repeat_interleave(scale_w, dim=2)
            keep_masks.append(keep.unsqueeze(1))
        return keep_masks

    def _fuse_forward(self, s1, s2, ids_keep=None, mask_ratio=None):
        """Project stage-1/2 feature maps to token sequences of stage-3 width.

        When ``mask_ratio`` is not None, only the tokens listed in
        ``ids_keep`` are retained and the groups are concatenated along the
        batch dim.
        """
        stage1_embed = self.stage1_output_decode(s1).flatten(2).permute(
            0, 2, 1)
        stage2_embed = self.stage2_output_decode(s2).flatten(2).permute(
            0, 2, 1)

        if mask_ratio is not None:
            stage1_embed = self._gather_groups(stage1_embed, ids_keep)
            stage2_embed = self._gather_groups(stage2_embed, ids_keep)

        return stage1_embed, stage2_embed

    def forward(self, x, mask_ratio=None):
        """Encode a batch of images, optionally with complementary masking.

        Args:
            x: Input images; height and width must match ``img_size[0]``
                (checked by the first patch embedding).
            mask_ratio (float, optional): When given, tokens are split into
                four masked groups by ``random_masking`` and the groups are
                processed stacked along the batch dim. Requires
                ``with_fuse=True``.

        Returns:
            tuple: ``(x, mask, ids_restore)``. ``mask`` and ``ids_restore``
            are ``None`` when ``mask_ratio`` is ``None``; otherwise ``mask``
            is the four binary masks concatenated along dim 0.
        """
        use_mask = mask_ratio is not None
        if use_mask:
            assert self.with_fuse, 'masked forward requires with_fuse=True'

        # Defined up front so the unmasked path can hand `ids_keep` to
        # `_fuse_forward` without hitting an unbound local.
        ids_keep = masks = ids_restore = None
        mask_for_patch1 = mask_for_patch2 = None

        # embed patches
        if use_mask:
            ids_keep, masks, ids_restore = self.random_masking(x, mask_ratio)
            token_grid = self.patch_embed3.grid_size
            mask_for_patch1 = self._upsample_keep_masks(
                masks, token_grid, self.patch_embed1.grid_size)
            mask_for_patch2 = self._upsample_keep_masks(
                masks, token_grid, self.patch_embed2.grid_size)

        s1 = self.patch_embed1(x)
        s1 = self.pos_drop(s1)
        for blk in self.blocks1:
            s1 = blk(s1, mask_for_patch1)

        s2 = self.patch_embed2(s1)
        for blk in self.blocks2:
            s2 = blk(s2, mask_for_patch2)

        if self.with_fuse:
            stage1_embed, stage2_embed = self._fuse_forward(
                s1, s2, ids_keep, mask_ratio)

        x = self.patch_embed3(s2)
        x = x.flatten(2).permute(0, 2, 1)
        x = self.patch_embed4(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed

        if use_mask:
            x = self._gather_groups(x, ids_keep)

        # apply Transformer blocks
        for blk in self.blocks3:
            x = blk(x)

        if self.with_fuse:
            x = x + stage1_embed + stage2_embed

        if self.global_pool:
            # NOTE(review): no cls token is prepended in this model, so this
            # slice drops the first patch token from the average. Kept as-is
            # to avoid changing existing behavior - confirm whether intended.
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            x = self.fc_norm(x)
        else:
            x = self.norm(x)

        if use_mask:
            mask = torch.cat(masks)
            return x, mask, ids_restore

        return x, None, None
