import functools

from maga_transformer.utils.model_weight import W, CkptWeightInfo, \
    identity, zeros, transpose, trans_qkv, trans_qkv_b

from maga_transformer.model_loader.model_weight_info import ModelWeightInfo, ModelDeployWeightInfo
from maga_transformer.model_loader.weight_module import WeightModule, AtomicWeight
from maga_transformer.model_loader.ffn_weight import FfnAtomicWeight, FfnWeight, FfnConfig
from maga_transformer.model_loader.attn_weight import AttnAtomicWeight, AttnConfig

class GPTNeoxWeight(ModelDeployWeightInfo):
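    """Maps GPT-NeoX checkpoint tensors (HuggingFace `gpt_neox.*` names) onto the deploy-time weight slots."""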
    def __init__(self, config, tp_size, tp_rank):
        super().__init__(config, tp_size, tp_rank)
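        # norm_type selects the checkpoint key layout: 'rmsnorm' checkpoints store norm scales
        # under `.scale`, while 'layernorm' checkpoints store `.weight` / `.bias`.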
        self.norm = config.norm_type

    def _get_weight_info(self):
        weights = [
            AtomicWeight(W.embedding, [CkptWeightInfo('gpt_neox.embed_in.weight', identity)], identity),
            AtomicWeight(W.lm_head, [CkptWeightInfo('embed_out.weight', identity)], identity)
        ]

        # The final layer norm is a single global tensor, so it is registered once here
        # rather than inside the per-layer loop.
        if self.norm == 'rmsnorm':
            weights.extend([
                AtomicWeight(W.final_ln_gamma, [CkptWeightInfo('gpt_neox.final_layer_norm.scale', identity)], identity),
                AtomicWeight(W.final_ln_beta, [], functools.partial(zeros, shape=[self._hidden_size]))
            ])
        elif self.norm == 'layernorm':
            weights.extend([
                AtomicWeight(W.final_ln_gamma, [CkptWeightInfo('gpt_neox.final_layer_norm.weight', identity)], identity),
                AtomicWeight(W.final_ln_beta, [CkptWeightInfo('gpt_neox.final_layer_norm.bias', identity)], identity)
            ])

        attn_config: AttnConfig = self.attn_config
        ffn_config: FfnConfig = self.ffn_config
        layer_weights = []
        for _ in range(self._num_layers):
            layer_weight = [
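                # The checkpoint keeps query/key/value fused in one `query_key_value` tensor;
                # trans_qkv / trans_qkv_b reorder it per head into the runtime's expected QKV layout.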
                AttnAtomicWeight(W.attn_qkv_w, [CkptWeightInfo('gpt_neox.layers.{i}.attention.query_key_value.weight', identity)],
                        functools.partial(trans_qkv, hidden_size=self._hidden_size, head_num=self._head_num), config=attn_config),

                AttnAtomicWeight(W.attn_qkv_b, [CkptWeightInfo('gpt_neox.layers.{i}.attention.query_key_value.bias', identity)],
                        functools.partial(trans_qkv_b, hidden_size=self._hidden_size, head_num=self._head_num), config=attn_config),

                AttnAtomicWeight(W.attn_o_w, [CkptWeightInfo('gpt_neox.layers.{i}.attention.dense.weight', identity)],
                        transpose, config=attn_config),

                AttnAtomicWeight(W.attn_o_b, [CkptWeightInfo('gpt_neox.layers.{i}.attention.dense.bias', identity)],
                        identity, config=attn_config),
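                # ffn_w3/ffn_b3 map to the h->4h up-projection, ffn_w2/ffn_b2 to the 4h->h down-projection.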
                FfnWeight(sub_weights=[
                    FfnAtomicWeight(W.ffn_w3, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight', identity)],
                            transpose, config=ffn_config),

                    FfnAtomicWeight(W.ffn_b3, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias', identity)],
                            identity, config=ffn_config),

                    FfnAtomicWeight(W.ffn_w2, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight', identity)],
                            transpose, config=ffn_config),

                    FfnAtomicWeight(W.ffn_b2, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias', identity)],
                            identity, config=ffn_config)],
                          config=ffn_config)
            ]

            # GPT-NeoX uses a parallel residual by default: x = x + attn(ln1(x)) + mlp(ln2(x)),
            # so each layer carries an input layernorm and a post-attention layernorm.
            if self.norm == 'rmsnorm':
                layer_weight.extend([
                    AtomicWeight(W.pre_attn_ln_gamma, [CkptWeightInfo('gpt_neox.layers.{i}.input_layernorm.scale', identity)],
                            identity),
                    AtomicWeight(W.pre_ln_gamma, [CkptWeightInfo('gpt_neox.layers.{i}.post_attention_layernorm.scale', identity)],
                            identity)
                ])
            elif self.norm == 'layernorm':
                layer_weight.extend([
                    AtomicWeight(W.pre_attn_ln_gamma, [CkptWeightInfo('gpt_neox.layers.{i}.input_layernorm.weight', identity)],
                            identity),
                    AtomicWeight(W.pre_attn_ln_beta, [CkptWeightInfo('gpt_neox.layers.{i}.input_layernorm.bias', identity)],
                            identity),
                    AtomicWeight(W.pre_ln_gamma, [CkptWeightInfo('gpt_neox.layers.{i}.post_attention_layernorm.weight', identity)],
                            identity),
                    AtomicWeight(W.pre_ln_beta, [CkptWeightInfo('gpt_neox.layers.{i}.post_attention_layernorm.bias', identity)],
                            identity)
                ])
            layer_weights.append(layer_weight)
        return ModelWeightInfo(layer_weights=layer_weights, weights=weights)

class GPTNeox13BWeight(ModelDeployWeightInfo):
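    """Weight map for the GPT-NeoX 13B checkpoint, which stores its norm scales under `.scale` and loads one pre/post layer norm pair per layer."""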
    def _get_weight_info(self):
        weights = [
            AtomicWeight(W.embedding, [CkptWeightInfo('gpt_neox.embed_in.weight', identity)], identity),
            AtomicWeight(W.lm_head, [CkptWeightInfo('embed_out.weight', identity)], identity),
            AtomicWeight(W.final_ln_gamma, [CkptWeightInfo('gpt_neox.final_layer_norm.scale', identity)], identity),
            AtomicWeight(W.final_ln_beta, [], functools.partial(zeros, shape=[self._hidden_size])),
        ]
        attn_config: AttnConfig = self.attn_config
        ffn_config: FfnConfig = self.ffn_config

        layer_weights = []
        for _ in range(self._num_layers):
            layer_weight = [
                AtomicWeight(W.pre_ln_gamma, [CkptWeightInfo('gpt_neox.layers.{i}.input_layernorm.scale', identity)],
                        identity),

                AttnAtomicWeight(W.attn_qkv_w, [CkptWeightInfo('gpt_neox.layers.{i}.attention.query_key_value.weight', identity)],
                        functools.partial(trans_qkv, hidden_size=self._hidden_size, head_num=self._head_num), config=attn_config),

                AttnAtomicWeight(W.attn_qkv_b, [CkptWeightInfo('gpt_neox.layers.{i}.attention.query_key_value.bias', identity)],
                        functools.partial(trans_qkv_b, hidden_size=self._hidden_size, head_num=self._head_num), config=attn_config),

                AttnAtomicWeight(W.attn_o_w, [CkptWeightInfo('gpt_neox.layers.{i}.attention.dense.weight', identity)],
                        transpose, config=attn_config),

                AttnAtomicWeight(W.attn_o_b, [CkptWeightInfo('gpt_neox.layers.{i}.attention.dense.bias', identity)],
                        identity, config=attn_config),
                FfnWeight(sub_weights=[
                    FfnAtomicWeight(W.ffn_w3, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight', identity)],
                            transpose, config=ffn_config),

                    FfnAtomicWeight(W.ffn_b3, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias', identity)],
                            identity, config=ffn_config),

                    FfnAtomicWeight(W.ffn_w2, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight', identity)],
                            transpose, config=ffn_config),

                    FfnAtomicWeight(W.ffn_b2, [CkptWeightInfo('gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias', identity)],
                            identity, config=ffn_config)
                ], config=ffn_config),

                AtomicWeight(W.post_ln_gamma, [CkptWeightInfo('gpt_neox.layers.{i}.post_attention_layernorm.scale', identity)],
                        identity),
            ]
            layer_weights.append(layer_weight)

        return ModelWeightInfo(layer_weights=layer_weights, weights=weights)
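
# A minimal sketch of how these weight maps are typically exercised (the `config` object and the
# tensor-parallel settings below are assumptions; construction is normally driven by the model loader):
#
#   weight_info = GPTNeoxWeight(config, tp_size=1, tp_rank=0)
#   model_weight_info = weight_info._get_weight_info()
#   # model_weight_info.weights       -> global tensors (embedding, lm_head, final layer norm)
#   # model_weight_info.layer_weights -> one list of weight descriptors per transformer layer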