from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
from collections import OrderedDict

from . import (
    fbnet_builder as mbuilder,
    fbnet_modeldef as modeldef,
)
import torch.nn as nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.rpn import rpn
from maskrcnn_benchmark.modeling import poolers


logger = logging.getLogger(__name__)


def create_builder(cfg):
    bn_type = cfg.MODEL.FBNET.BN_TYPE
    if bn_type == "gn":
        bn_type = (bn_type, cfg.GROUP_NORM.NUM_GROUPS)
    factor = cfg.MODEL.FBNET.SCALE_FACTOR

    arch = cfg.MODEL.FBNET.ARCH
    arch_def = cfg.MODEL.FBNET.ARCH_DEF
    if len(arch_def) > 0:
        arch_def = json.loads(arch_def)
    if arch in modeldef.MODEL_ARCH:
        if len(arch_def) > 0:
            assert (
                arch_def == modeldef.MODEL_ARCH[arch]
            ), "Two architectures with the same name {},\n{},\n{}".format(
                arch, arch_def, modeldef.MODEL_ARCH[arch]
            )
        arch_def = modeldef.MODEL_ARCH[arch]
    else:
        assert arch_def is not None and len(arch_def) > 0
    arch_def = mbuilder.unify_arch_def(arch_def)

    rpn_stride = arch_def.get("rpn_stride", None)
    if rpn_stride is not None:
        assert (
            cfg.MODEL.RPN.ANCHOR_STRIDE[0] == rpn_stride
        ), "Needs to set cfg.MODEL.RPN.ANCHOR_STRIDE to {}, got {}".format(
            rpn_stride, cfg.MODEL.RPN.ANCHOR_STRIDE
        )
    width_divisor = cfg.MODEL.FBNET.WIDTH_DIVISOR
    dw_skip_bn = cfg.MODEL.FBNET.DW_CONV_SKIP_BN
    dw_skip_relu = cfg.MODEL.FBNET.DW_CONV_SKIP_RELU

    logger.info(
        "Building fbnet model with arch {} (without scaling):\n{}".format(
            arch, arch_def
        )
    )

    builder = mbuilder.FBNetBuilder(
        width_ratio=factor,
        bn_type=bn_type,
        width_divisor=width_divisor,
        dw_skip_bn=dw_skip_bn,
        dw_skip_relu=dw_skip_relu,
    )

    return builder, arch_def


def _get_trunk_cfg(arch_def):
    """ Get all stages except the last one """
    num_stages = mbuilder.get_num_stages(arch_def)
    trunk_stages = arch_def.get("backbone", range(num_stages - 1))
    ret = mbuilder.get_blocks(arch_def, stage_indices=trunk_stages)
    return ret


class FBNetTrunk(nn.Module):
    def __init__(
        self, builder, arch_def, dim_in,
    ):
        super(FBNetTrunk, self).__init__()
        self.first = builder.add_first(arch_def["first"], dim_in=dim_in)
        trunk_cfg = _get_trunk_cfg(arch_def)
        self.stages = builder.add_blocks(trunk_cfg["stages"])

    # return features for each stage
    def forward(self, x):
        y = self.first(x)
        y = self.stages(y)
        ret = [y]
        return ret


@registry.BACKBONES.register("FBNet")
def add_conv_body(cfg, dim_in=3):
    builder, arch_def = create_builder(cfg)

    body = FBNetTrunk(builder, arch_def, dim_in)
    model = nn.Sequential(OrderedDict([("body", body)]))
    model.out_channels = builder.last_depth

    return model


def _get_rpn_stage(arch_def, num_blocks):
    rpn_stage = arch_def.get("rpn")
    ret = mbuilder.get_blocks(arch_def, stage_indices=rpn_stage)
    if num_blocks > 0:
        logger.warn('Use last {} blocks in {} as rpn'.format(num_blocks, ret))
        block_count = len(ret["stages"])
        assert num_blocks <= block_count, "use block {}, block count {}".format(
            num_blocks, block_count
        )
        blocks = range(block_count - num_blocks, block_count)
        ret = mbuilder.get_blocks(ret, block_indices=blocks)
    return ret["stages"]


class FBNetRPNHead(nn.Module):
    def __init__(
        self, cfg, in_channels, builder, arch_def,
    ):
        super(FBNetRPNHead, self).__init__()
        assert in_channels == builder.last_depth

        rpn_bn_type = cfg.MODEL.FBNET.RPN_BN_TYPE
        if len(rpn_bn_type) > 0:
            builder.bn_type = rpn_bn_type

        use_blocks = cfg.MODEL.FBNET.RPN_HEAD_BLOCKS
        stages = _get_rpn_stage(arch_def, use_blocks)

        self.head = builder.add_blocks(stages)
        self.out_channels = builder.last_depth

    def forward(self, x):
        x = [self.head(y) for y in x]
        return x


@registry.RPN_HEADS.register("FBNet.rpn_head")
def add_rpn_head(cfg, in_channels, num_anchors):
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels

    assert in_channels == builder.last_depth
    # builder.name_prefix = "[rpn]"

    rpn_feature = FBNetRPNHead(cfg, in_channels, builder, model_arch)
    rpn_regressor = rpn.RPNHeadConvRegressor(
        cfg, rpn_feature.out_channels, num_anchors)
    return nn.Sequential(rpn_feature, rpn_regressor)


def _get_head_stage(arch, head_name, blocks):
    # use default name 'head' if the specific name 'head_name' does not existed
    if head_name not in arch:
        head_name = "head"
    head_stage = arch.get(head_name)
    ret = mbuilder.get_blocks(arch, stage_indices=head_stage, block_indices=blocks)
    return ret["stages"]


# name mapping for head names in arch def and cfg
ARCH_CFG_NAME_MAPPING = {
    "bbox": "ROI_BOX_HEAD",
    "kpts": "ROI_KEYPOINT_HEAD",
    "mask": "ROI_MASK_HEAD",
}


class FBNetROIHead(nn.Module):
    def __init__(
        self, cfg, in_channels, builder, arch_def,
        head_name, use_blocks, stride_init, last_layer_scale,
    ):
        super(FBNetROIHead, self).__init__()
        assert in_channels == builder.last_depth
        assert isinstance(use_blocks, list)

        head_cfg_name = ARCH_CFG_NAME_MAPPING[head_name]
        self.pooler = poolers.make_pooler(cfg, head_cfg_name)

        stage = _get_head_stage(arch_def, head_name, use_blocks)

        assert stride_init in [0, 1, 2]
        if stride_init != 0:
            stage[0]["block"][3] = stride_init
        blocks = builder.add_blocks(stage)

        last_info = copy.deepcopy(arch_def["last"])
        last_info[1] = last_layer_scale
        last = builder.add_last(last_info)

        self.head = nn.Sequential(OrderedDict([
            ("blocks", blocks),
            ("last", last)
        ]))

        # output_blob = builder.add_final_pool(
        #     # model, output_blob, kernel_size=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION)
        #     model,
        #     output_blob,
        #     kernel_size=int(cfg.FAST_RCNN.ROI_XFORM_RESOLUTION / stride_init),
        # )

        self.out_channels = builder.last_depth

    def forward(self, x, proposals):
        x = self.pooler(x, proposals)
        x = self.head(x)
        return x


@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FBNet.roi_head")
def add_roi_head(cfg, in_channels):
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels
    # builder.name_prefix = "_[bbox]_"

    return FBNetROIHead(
        cfg, in_channels, builder, model_arch,
        head_name="bbox",
        use_blocks=cfg.MODEL.FBNET.DET_HEAD_BLOCKS,
        stride_init=cfg.MODEL.FBNET.DET_HEAD_STRIDE,
        last_layer_scale=cfg.MODEL.FBNET.DET_HEAD_LAST_SCALE,
    )


@registry.ROI_KEYPOINT_FEATURE_EXTRACTORS.register("FBNet.roi_head_keypoints")
def add_roi_head_keypoints(cfg, in_channels):
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels
    # builder.name_prefix = "_[kpts]_"

    return FBNetROIHead(
        cfg, in_channels, builder, model_arch,
        head_name="kpts",
        use_blocks=cfg.MODEL.FBNET.KPTS_HEAD_BLOCKS,
        stride_init=cfg.MODEL.FBNET.KPTS_HEAD_STRIDE,
        last_layer_scale=cfg.MODEL.FBNET.KPTS_HEAD_LAST_SCALE,
    )


@registry.ROI_MASK_FEATURE_EXTRACTORS.register("FBNet.roi_head_mask")
def add_roi_head_mask(cfg, in_channels):
    builder, model_arch = create_builder(cfg)
    builder.last_depth = in_channels
    # builder.name_prefix = "_[mask]_"

    return FBNetROIHead(
        cfg, in_channels, builder, model_arch,
        head_name="mask",
        use_blocks=cfg.MODEL.FBNET.MASK_HEAD_BLOCKS,
        stride_init=cfg.MODEL.FBNET.MASK_HEAD_STRIDE,
        last_layer_scale=cfg.MODEL.FBNET.MASK_HEAD_LAST_SCALE,
    )
