# -*- encoding: utf-8 -*-
'''
@File    :   base_model.py
@Time    :   2021/10/01 22:40:33
@Author  :   Ming Ding 
@Contact :   dm18@mails.tsinghua.edu.cn
'''

from functools import partial
import os
import sys
import math
import random
import torch
import inspect
import warnings
import argparse
from sat.model.registry import model_registry, MetaModel

from sat.model.transformer import BaseTransformer, standard_attention
from sat.arguments import update_args_with_file, overwrite_args_by_dict, set_random_seed
from sat.training.model_io import load_checkpoint
from sat.helpers import print_rank0

from sat.transformer_defaults import HOOKS_DEFAULT, ARGS_DEFAULT
from sat.resources import auto_create
from sat.mpu.initialize import get_node_rank, get_model_parallel_rank, destroy_model_parallel, initialize_model_parallel
from sat.mpu.operation import mp_split_model_rank0, mp_split_model_receive, mp_merge_model_rank0, mp_merge_model_send
from sat.arguments import reset_random_seed

def non_conflict(func):
    '''Mark a hook function as non-conflict,
    so that it composes with any previously defined hook:
    the earlier implementation is passed in as `old_impl`.
    e.g. PrefixTuningMixin.attention_fn
    '''
    func.non_conflict = True
    return func

def replacable(func):
    '''Mark a hook function as replaceable,
    so that it can be overridden by mixins added after it.
    e.g. FP32AttentionMixin.attention_fn
    '''
    func.replacable = True
    return func

class BaseMixin(torch.nn.Module):
    non_conflict = non_conflict
    replacable = replacable
    def __init__(self):
        super(BaseMixin, self).__init__()
        # define new params

    def reinit(self, parent_model=None):
        # reload the initial params from previously trained modules
        # you can also access other mixins through parent_model.get_mixin(name).
        pass

    # hook functions can be defined here
    # a hook, if default or replacable, can be overridden by mixins added after it.
    # a hook can be augmented by non_conflict hooks added after it.
    # resolution order: default -> 0~n replacable -> 0~n non_conflict
    # ...

    # If the hook is just a pre- or post-transformation,
    # you can mark it with @non_conflict
    # and call `old_impl` inside to stay compatible with other mixins.
    # E.g.,
    # 
    # @non_conflict
    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
    #     new_q, new_k, new_v = pre_hack(q, k, v)
    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
    #     attn_result = post_hack(attn_result)
    #     return attn_result
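
    # If you want a hook that later mixins are allowed to override (a warning is
    # emitted instead of an error), mark it with @replacable. A minimal sketch
    # (the hook name is one of HOOKS_DEFAULT; `my_position_embedding` is a
    # hypothetical submodule):
    #
    # @replacable
    # def position_embedding_forward(self, position_ids, **kw_args):
    #     return self.my_position_embedding(position_ids)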


class BaseModel(torch.nn.Module, metaclass=MetaModel):
    def __init__(self, args, transformer=None, params_dtype=torch.float, **kwargs):
        super(BaseModel, self).__init__()
        self.mixins = torch.nn.ModuleDict()
        self.collect_hooks_()
        if transformer is not None:
            self.transformer = transformer
        else:
            # check if model-only mode
            from sat.arguments import _simple_init
            success = _simple_init(model_parallel_size=args.model_parallel_size, seed=args.seed if hasattr(args, 'seed') else 1234)

            args_dict = {k: (getattr(args, v[0]) if hasattr(args, v[0]) else v[1]) for k, v in ARGS_DEFAULT.items()}

            self.transformer = BaseTransformer(
                num_layers=args.num_layers,
                vocab_size=args.vocab_size,
                hidden_size=args.hidden_size,
                num_attention_heads=args.num_attention_heads,
                max_sequence_length=args.max_sequence_length,
                layernorm_order=args.layernorm_order,
                **args_dict,
                hooks=self.hooks,
                params_dtype=params_dtype,
                skip_init=args.skip_init,
                device=torch.cuda.current_device() if hasattr(args, 'use_gpu_initialization') and args.use_gpu_initialization else torch.device('cpu'),
                **kwargs
            )

    def reinit(self, mixin_names=None):  # will be called when loading model, None means all
        # override this function if some mixins need special re-initialization when loading
        for k, m in self.mixins.items():
            if mixin_names is None or k in mixin_names:
                m.reinit(self)

    def add_mixin(self, name, new_mixin, reinit=False):
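        # Typical usage (a sketch; the mixin name and class are hypothetical):
        #   model.add_mixin('prefix-tuning', PrefixTuningMixin(args), reinit=True)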
        assert name not in self.mixins
        assert isinstance(new_mixin, BaseMixin)

        self.mixins[name] = new_mixin  # will auto-register parameters
        object.__setattr__(new_mixin, 'transformer', self.transformer)  # bypass nn.Module.__setattr__ so the shared transformer is not registered as a submodule of the mixin

        self.collect_hooks_()
        if reinit:
            new_mixin.reinit(self)  # also pass current mixins

    def del_mixin(self, name):
        assert name in self.mixins
        del self.mixins[name]
        self.collect_hooks_()

    def get_mixin(self, name):
        return self.mixins[name]

    def forward(self, *args, **kwargs):
        # update hooks to those of the current model (overridden forwards)
        # Attention! The transformer might be shared by multiple models.
        self.transformer.hooks.clear()
        self.transformer.hooks.update(self.hooks)
        return self.transformer(*args, **kwargs)

    def collect_hooks_(self):
        names = list(HOOKS_DEFAULT.keys())
        hooks = {}
        hook_origins = {}
        for name in names:
            if hasattr(self, name):
                hooks[name] = getattr(self, name)
                hook_origins[name] = 'model'

            for mixin_name, m in self.mixins.items():
                if hasattr(m, name):
                    if hasattr(getattr(m, name), 'non_conflict'):
                        # check that getattr(m, name) accepts old_impl as an argument
                        signature = inspect.signature(getattr(m, name))
                        if 'old_impl' not in signature.parameters:
                            raise ValueError(f'Hook {name} at {mixin_name} must accept old_impl as an argument.')
                        # -------------
                        if name in hooks:
                            old_impl = hooks[name]
                        elif name == 'attention_fn': # the only hook without self
                            old_impl = HOOKS_DEFAULT[name]
                        else:
                            old_impl = partial(HOOKS_DEFAULT[name], self) # relax! `partial` does not affect the signature
                        old_origin = hook_origins.get(name, 'default')
                        hooks[name] = partial(getattr(m, name), old_impl=old_impl)
                        hook_origins[name] = mixin_name + ' -> ' + old_origin
                    elif name in hooks and not hasattr(hooks[name], 'replacable'): # if this hook name is already registered
                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                    else: # new hook
                        if name in hooks and hasattr(hooks[name], 'replacable'):
                            warnings.warn(f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.')
                        hooks[name] = getattr(m, name)
                        hook_origins[name] = mixin_name

        self.hooks = hooks
        self.hook_origins = hook_origins
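        # e.g. after stacking a non_conflict attention mixin on the default implementation,
        # hook_origins may read {'attention_fn': 'prefix-tuning -> default'} (illustrative names only)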
        return hooks

    def disable_untrainable_params(self):
        pass

    @classmethod
    def add_model_specific_args(cls, parser):
        # recorded in arguments.py: add_model_config_args
        return parser

    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Load a pretrained checkpoint of the current model.
            Args:
                name: the identifier of the pretrained model.
                args: argparse.Namespace. The loaded args are merged into it; None creates a new model-only namespace with defaults.
                home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
                url: the url of the model. Default: SAT_URL.
                prefix: the prefix of the checkpoint. Default: ''.
            Returns:
                model: the loaded model.
                args: the loaded args.
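            Example (a sketch; `MyModel` and 'some-checkpoint' are placeholders):
                model, args = MyModel.from_pretrained_base('some-checkpoint')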
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        # create a new args if not provided
        if args is None:
            args = cls.get_args()
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        specific_iteration = kwargs.pop('specific_iteration', None)
        model = get_model(args, cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix, specific_iteration=specific_iteration)
        return model, args
    
    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
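        # Example: load a checkpoint saved with model_parallel_size=1 and re-split it across
        # 2 model-parallel ranks (a sketch; `MyModel` and 'some-checkpoint' are placeholders):
        #   model, args = MyModel.from_pretrained('some-checkpoint', args,
        #                                         overwrite_args={'model_parallel_size': 2})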
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx,overwrite_args={'model_parallel_size':1})` first if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference': # For multi-node inference, prevent rank 0 from eagerly printing some info.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args
    
    @classmethod
    def list_avail_args(cls, print=True):
        '''List all available args of the current model.'''
        parser = argparse.ArgumentParser()
        from sat.arguments import add_model_config_args
        add_model_config_args(parser)
        # add args of the current model
        if hasattr(cls, 'add_model_specific_args'):
            cls.add_model_specific_args(parser)
        if print:
            from sat.helpers import print_parser
            print_parser(parser)
        return parser

    @classmethod
    def get_args(cls, **kwargs):
        '''Get the parsed args of the current model.
            Args:
                **kwargs: will override the default args.
            Returns:
                args: the parsed args.
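            Example (a sketch; the overrides shown are arbitrary):
                args = MyModel.get_args(num_layers=12, hidden_size=768, fp16=True)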
        '''
        parser = cls.list_avail_args(print=False)
        # use parser to parse kwargs
        args = parser.parse_args([])
        for k, v in kwargs.items():
            if hasattr(args, k) or k in ['fp16']: # non-architecture args that still affect model building
                setattr(args, k, v)
            else:
                print_rank0(f'warning: Unknown arg {k} for class {cls.__name__}.', level='DEBUG')
                setattr(args, k, v)
        return args

class AutoModel():
    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Automatically find the class and instantiate it. Auto-download.
            Args:
                name: the identifier of the pretrained model.
                args: argparse.Namespace. The loaded args are merged into it.
                home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
                url: manually specified url for the `name` model.
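            Example (a sketch; 'some-checkpoint' is a placeholder identifier):
                model, args = AutoModel.from_pretrained_base('some-checkpoint')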
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        if args is None:
            args = argparse.Namespace() # null, fill later
            null_args = True
        else:
            null_args = False
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        if not hasattr(args, 'model_class'):
            raise ValueError('model_config.json must have key "model_class" for AutoModel.from_pretrained.')
        model_cls = model_registry.get(args.model_class)
        if null_args:
            # fill args with default values, if not provided
            model_default_args = model_cls.get_args()
            for k, v in model_default_args.__dict__.items():
                if not hasattr(args, k):
                    setattr(args, k, v)
        model = get_model(args, model_cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix)
        return model, args
    
    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx,overwrite_args={'model_parallel_size':1})` first if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference': # For multi-node inference, prevent rank 0 from eagerly printing some info.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args
    
def get_model(args, model_cls, **kwargs):
    """Build the model."""
    from sat.helpers import print_rank0, print_all
    from sat import mpu

    print_rank0(f'building {model_cls.__name__} model ...')
    if 'params_dtype' not in kwargs:
        if hasattr(args, 'fp16') and args.fp16:
            params_dtype = torch.half
        elif hasattr(args, 'bf16') and args.bf16:
            params_dtype = torch.bfloat16
        else:
            params_dtype = torch.float32
    else:
        # pop params_dtype from kwargs
        params_dtype = kwargs.pop('params_dtype')

    from sat.helpers import check_if_zero3
    if check_if_zero3(args):
        import deepspeed
        with deepspeed.zero.Init():
            model = model_cls(args, params_dtype=params_dtype, **kwargs)
    else:
        model = model_cls(args, params_dtype=params_dtype, **kwargs)

    if mpu.get_data_parallel_rank() == 0:
        print_all(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)
    
    if hasattr(args, 'fp16') and args.fp16:
        model.half()
    elif hasattr(args, 'bf16') and args.bf16:
        model.bfloat16()

    try: # TODO: is this useful?
        if not hasattr(args, 'device'):
            args.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        model = model.to(args.device)
    except Exception as e:
        print_all(e)
    
    return model
