# coding=utf-8
# -*- encoding: utf-8 -*-
'''
@File    :   transformer_defaults.py
@Time    :   2022/06/01 21:44:17
@Author  :   Ming Ding 
@Contact :   dm18@mails.tsinghua.edu.cn
'''

import math
import torch
import torch.nn.functional as F

from sat import mpu

from sat.mpu.utils import split_tensor_along_last_dim
import contextlib

def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
    # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. 
    # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. 

    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    if log_attention_weights is not None:
        attention_scores += log_attention_weights

    if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
        # if auto-regressive, skip
        attention_scores = torch.mul(attention_scores, attention_mask) - \
                           10000.0 * (1.0 - attention_mask)

    attention_probs = F.softmax(attention_scores, dim=-1)

    if attention_dropout is not None:
        if mpu.get_cuda_rng_tracker is not None:
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = attention_dropout(attention_probs)
        else:
            attention_probs = attention_dropout(attention_probs)

    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer

def attention_fn_default(query_layer, key_layer, value_layer, attention_mask,
                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
    # expand head dim to query dim, if necessary
    # only useful for multi-query attention
    batch_size, num_query_heads = query_layer.shape[:2] # [b, np, s, hn]
    num_kv_heads = key_layer.shape[1] # [b, np, s, hn]
    key_layer = key_layer.unsqueeze(2).expand(-1, -1, num_query_heads//num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *key_layer.shape[2:])
    value_layer = value_layer.unsqueeze(2).expand(-1, -1, num_query_heads//num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *value_layer.shape[2:])

    is_low_triangle = (attention_mask == torch.ones_like(attention_mask, dtype=torch.float).tril()).all()
    is_full = (attention_mask is None) or (attention_mask > 0).all()

    if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
        # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
        dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
        if dropout_p > 0 and mpu.get_cuda_rng_tracker is not None:
            context = mpu.get_cuda_rng_tracker().fork()
        else:
            context = contextlib.nullcontext()
        with context:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, 
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=not is_full
            )
        return attn_output
    else:
        return standard_attention(
            query_layer, key_layer, value_layer, attention_mask,
            attention_dropout=attention_dropout, log_attention_weights=log_attention_weights,
            scaling_attention_score=scaling_attention_score, **kwargs
        )

def attention_forward_default(self, hidden_states, mask, **kw_args):
    self = self.transformer.layers[kw_args['layer_id']].attention
    attention_fn = attention_fn_default
    if 'attention_fn' in self.hooks:
        attention_fn = self.hooks['attention_fn']

    mixed_raw_layer = self.query_key_value(hidden_states)
    (mixed_query_layer,
        mixed_key_layer,
        mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, self.stride)

    dropout_fn = self.attention_dropout if self.training else None

    query_layer = self._transpose_for_scores(mixed_query_layer)
    key_layer = self._transpose_for_scores(mixed_key_layer)
    value_layer = self._transpose_for_scores(mixed_value_layer)

    # rotary position embedding 
    if self.transformer.is_rotary_emb:
        query_layer, key_layer = self.transformer.position_embeddings(
            query_layer, key_layer, kw_args['position_ids'],max_seqlen=kw_args['position_ids'].max()+1,
            layer_id=kw_args['layer_id']
        )

    context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)
    output = self.dense(context_layer)

    if self.training:
        output = self.output_dropout(output)
    return output

def cross_attention_forward_default(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
    self = self.transformer.layers[kw_args['layer_id']].cross_attention
    attention_fn = attention_fn_default
    if 'attention_fn' in self.hooks:
        attention_fn = self.hooks['attention_fn']

    mixed_query_layer = self.query(hidden_states)
    query_layer = self._transpose_for_scores(mixed_query_layer)
    dropout_fn = self.attention_dropout if self.training else None
    if isinstance(encoder_outputs, torch.Tensor):
        mixed_x_layer = self.key_value(encoder_outputs)
        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
        # Reshape and transpose [b, np, s, hn]
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        mem_cross = (key_layer, value_layer)
    else:
        key_layer, value_layer = encoder_outputs[kw_args['layer_id']]
        mem_cross = (key_layer, value_layer)

    context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn, cross_attention=True, mem_cross=mem_cross, **kw_args)
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
    # [b, s, hp]
    context_layer = context_layer.view(*new_context_layer_shape)

    # Output. [b, s, h]
    output = self.dense(context_layer)
    if self.training:
        output = self.output_dropout(output)
    return output

def routing_forward_default(self, hidden_states, **kw_args):
    num_experts = self.transformer.num_experts
    # This is just an example that select 2 experts randomly.
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = torch.randn((batch_size*sequence_length, num_experts), device=hidden_states.device, dtype=hidden_states.dtype)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    routing_weights = routing_weights.to(hidden_states.dtype)
    return routing_weights, selected_experts

from functools import partial

def mlp_forward_default(self, hidden_states, expert_id=-1, **kw_args):
    if self.transformer.num_experts == 1 or expert_id > -1:
        self = self.transformer.layers[kw_args['layer_id']].mlp
        suffix = f"_{expert_id}" if expert_id > 0 else ""
        if self.is_gated_mlp:
            intermediate_parallel = getattr(self, "dense_h_to_4h"+suffix)(hidden_states)
            gated_intermediate_parallel = getattr(self, "dense_h_to_4h_gate"+suffix)(hidden_states)
            intermediate_parallel = self.activation_func(gated_intermediate_parallel) * intermediate_parallel
            output = getattr(self, "dense_4h_to_h"+suffix)(intermediate_parallel)
        else:
            intermediate_parallel = getattr(self, "dense_h_to_4h"+suffix)(hidden_states)
            intermediate_parallel = self.activation_func(intermediate_parallel)
            output = getattr(self, "dense_4h_to_h"+suffix)(intermediate_parallel)
        return output
    else:
        mlp_forward = self.hooks.get('mlp_forward', partial(mlp_forward_default, self))
        routing_forward = self.hooks.get('routing_forward', partial(routing_forward_default, self))
        self = self.transformer.layers[kw_args['layer_id']].mlp
        fwd_weight, fwd_idx = routing_forward(hidden_states, **kw_args)

        # Adapted from mixtral-8x7b https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(fwd_idx, num_classes=self.num_experts).permute(2, 1, 0)
        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.shape[0] == 0:
                continue
            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[top_x_list] # I don't know why using hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = mlp_forward(current_state, expert_id=expert_idx, **kw_args) * fwd_weight[top_x_list, idx_list, None]
            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        output = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return output

def word_embedding_forward_default(self, input_ids, output_cross_layer, **kw_args):
    return self.transformer.word_embeddings(input_ids)

def position_embedding_forward_default(self, position_ids, output_cross_layer, **kw_args):
    if not self.transformer.is_rotary_emb:
        return self.transformer.position_embeddings(position_ids)
    return None
        
from sat.mpu import gather_from_model_parallel_region
def final_forward_default(self, logits, **kw_args):
    logits_parallel = F.linear(logits, self.transformer.word_embeddings.weight)
    if not kw_args['parallel_output']:
        logits_parallel = gather_from_model_parallel_region(logits_parallel)
    return logits_parallel

def layer_forward_default(self, hidden_states, mask, *args, **kw_args):
    '''
        hidden_states: [batch, seq_len, hidden_size]
        mask: [(1, 1), seq_len, seq_len]
    '''
    self = self.transformer.layers[kw_args['layer_id']]
    
    # Layer norm at the begining of the transformer layer.
    attention_input = self.input_layernorm(hidden_states)
    # Self attention.
    attention_output = self.attention(attention_input, mask, **kw_args)

    # Third LayerNorm
    if self.layernorm_order == 'sandwich':
        attention_output = self.third_layernorm(attention_output)

    # DropPath for attention
    if self.training and self.drop_path > 0.:
        # drop_path percentage 0, others 1/(1-p)
        random_tensor = (1-self.drop_path
                            + torch.rand((attention_output.shape[0],), dtype=attention_output.dtype, device=attention_output.device)).floor_() / (1-self.drop_path)
        attention_output = random_tensor.view(-1, 1, 1) * attention_output
    
    # Residual connection.
    if self.layernorm_order == 'post':
        hidden_states = attention_input + attention_output
        mlp_input = self.post_attention_layernorm(hidden_states)
    else:
        hidden_states = hidden_states + attention_output

    if self.is_decoder:
        encoder_outputs = kw_args['encoder_outputs']
        if encoder_outputs is not None:
            assert 'cross_attention_mask' in kw_args
            # Cross attention
            if self.layernorm_order == 'post':
                attention_output = self.cross_attention(mlp_input, **kw_args)
                # Residual connection.
                hidden_states = mlp_input + attention_output
                # Layer norm post the cross attention
                mlp_input = self.post_cross_attention_layernorm(hidden_states)
            else:
                cross_input = self.post_cross_attention_layernorm(hidden_states)
                attention_output = self.cross_attention(cross_input, **kw_args)
                hidden_states = hidden_states + attention_output

    if self.layernorm_order != 'post':
        mlp_input = self.post_attention_layernorm(hidden_states)    

    # MLP.
    mlp_output = self.mlp(mlp_input, **kw_args)

    # Fourth LayerNorm
    if self.layernorm_order == 'sandwich':
        mlp_output = self.fourth_layernorm(mlp_output)

    # DropPath for mlp
    if self.training and self.drop_path > 0.:
        random_tensor = (1-self.drop_path
                            + torch.rand((mlp_output.shape[0],), dtype=mlp_output.dtype, device=mlp_output.device)).floor_() / (1-self.drop_path)
        mlp_output = random_tensor.view(-1, 1, 1) * mlp_output

    # Second residual connection.
    if self.layernorm_order == 'post':
        output = mlp_input + mlp_output
    else:
        output = hidden_states + mlp_output

    return output

HOOKS_DEFAULT = {
    'attention_fn': attention_fn_default,
    'attention_forward': attention_forward_default,
    'cross_attention_forward': cross_attention_forward_default,
    'routing_forward': routing_forward_default,
    'mlp_forward': mlp_forward_default,
    'word_embedding_forward': word_embedding_forward_default,
    'position_embedding_forward': position_embedding_forward_default,
    'final_forward': final_forward_default,
    'layer_forward': layer_forward_default
}

ARGS_DEFAULT = {
    'embedding_dropout_prob': ('hidden_dropout', 0),
    'attention_dropout_prob': ('attention_dropout', 0),
    'output_dropout_prob': ('hidden_dropout', 0),
    'inner_hidden_size': ('inner_hidden_size', None),
    'hidden_size_per_attention_head': ('hidden_size_per_attention_head', None),
    'cross_hidden_size_per_attention_head': ('cross_hidden_size_per_attention_head', None),
    'checkpoint_activations': ('checkpoint_activations', False),
    'checkpoint_num_layers': ('checkpoint_num_layers', 1),
    'checkpoint_skip_layers': ('checkpoint_skip_layers', 0),
    'is_decoder': ('is_decoder', False),
    'cross_attn_hidden_size': ('cross_attn_hidden_size', None),
    'use_final_layernorm': ('use_final_layernorm', True),
    'layernorm_epsilon': ('layernorm_epsilon', 1e-5),
    'use_bias': ('use_bias', True),
    'use_qkv_bias': ('use_qkv_bias', False),
    'num_multi_query_heads': ('num_multi_query_heads', 0),
    'cross_num_multi_query_heads': ('cross_num_multi_query_heads', 0),
    'drop_path': ('drop_path', 0.),
    'row_parallel_linear_final_bias': ('row_parallel_linear_final_bias', True),
    'is_gated_mlp': ('is_gated_mlp', False),
    'is_rotary_emb': ('is_rotary_emb', False),
    'parallel_output': ('parallel_output', False),
    'num_experts': ('num_experts', 1),
}

from sat.ops.layernorm import LayerNorm, RMSNorm

NO_WD_MODULES = [LayerNorm, torch.nn.LayerNorm, RMSNorm]