# Copyright (c) Alibaba, Inc. and its affiliates.
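"""Batch feature extraction tool.

Runs a model in ``extract`` mode over the dataset configured at
``cfg.data.extract`` and saves the collected features as ``.npy`` files,
split according to ``cfg.split_name`` and ``cfg.split_at``.

Illustrative invocation (file and config names are examples, not a
documented interface):

    python extract.py extract_config.py --checkpoint epoch_100.pth --work_dir ./features
"""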
import argparse
import json
import os
import os.path as osp
import time

import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from easycv.apis import set_random_seed
from easycv.datasets import build_dataloader, build_dataset
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.models import build_model
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.logger import get_root_logger


class ExtractProcess(object):
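    """Run a model in ``extract`` mode over a dataloader and collect the
    outputs of the layers named in ``extract_list``."""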

    def __init__(self, extract_list=('neck', )):
        # Copy to a list so a caller-supplied sequence is not shared/mutated.
        self.extract_list = list(extract_list)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def _forward_func(self, model, **kwargs):
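        """Run one forward pass and return a dict mapping ``feat1``,
        ``feat2``, ... to CPU tensors, plus ``label`` when the batch
        carries ``gt_labels``."""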
        if hasattr(model.module, 'update_extract_list'):
            for k in self.extract_list:
                model.module.update_extract_list(k)

        feats = model(mode='extract', **kwargs)

        # Normalize every requested output to a list of tensors.
        for k in self.extract_list:
            if isinstance(feats[k], torch.Tensor):
                feats[k] = [feats[k]]
        # Flatten the requested outputs into feat1, feat2, ... in order.
        feat_dict = {}
        idx = 0
        for k in self.extract_list:
            for feat in feats[k]:
                idx += 1
                feat_dict['feat{}'.format(idx)] = feat.cpu()

        if 'gt_labels' in kwargs:
            feat_dict['label'] = kwargs['gt_labels']
        return feat_dict

    def extract(self, model, data_loader, distributed=False):
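        """Extract features for the whole dataloader, dispatching to the
        distributed or non-distributed collection helper."""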
        model.eval()
        func = lambda **x: self._forward_func(model, **x)

        if hasattr(data_loader, 'dataset'):
            length = len(data_loader.dataset)
        else:
            length = data_loader.data_length

        if distributed:
            rank, world_size = get_dist_info()
            results = dist_forward_collect(func, data_loader, rank, length)
        else:
            results = nondist_forward_collect(func, data_loader, length)
        return results


def parse_args():
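    """Parse command-line arguments for the extraction tool."""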
    parser = argparse.ArgumentParser(
        description='EVTORCH batch feature extraction (uses a dataloader)')
    parser.add_argument('config', help='config file path', type=str)
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--pretrained',
        default='random',
        help='pretrained model file, mutually exclusive with --checkpoint')
    parser.add_argument(
        '--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and extracted features')
    parser.add_argument(
        '--extract_list',
        nargs='+',
        type=str,
        default=['neck'],
        help='names of the model outputs to extract, e.g. neck backbone')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--port',
        type=int,
        default=29500,
        help='port only works when launcher=="slurm"')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
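    """Build the model and dataloader from the config, run extraction and
    save the split feature arrays under the work dir."""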
    args = parse_args()
    cfg = mmcv_config_fromfile(args.config)

    if cfg.get('oss_io_config', None):
        io.access_oss(**cfg.oss_io_config)

    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # checkpoint and pretrained are exclusive
    assert args.pretrained == 'random' or args.checkpoint is None, \
        'Checkpoint and pretrained are exclusive.'

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'extract_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    datasets = [build_dataset(cfg.data.extract)]
    seed = 0
    set_random_seed(seed)
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            replace=getattr(cfg.data, 'sampling_replace', False),
            seed=seed,
            drop_last=getattr(cfg.data, 'drop_last', False)) for ds in datasets
    ]

    # specify pretrained model
    if args.pretrained != 'random':
        assert isinstance(args.pretrained, str)
        cfg.model.pretrained = args.pretrained

    assert args.checkpoint is not None and os.path.exists(args.checkpoint), \
        'a checkpoint must be set when using the extractor!'
    ckpt_meta = torch.load(
        args.checkpoint, map_location='cpu').get('meta', None)
    cfg_model = cfg.get('model', None)

    if cfg_model is not None:
        logger.info('build model from the cfg')
        model = build_model(cfg_model)
    else:
        assert ckpt_meta is not None, \
            'extraction needs either a model config or a checkpoint with meta!'
        logger.info('build model from the checkpoint meta')
        ckpt_cfg = json.loads(ckpt_meta['config'])
        if 'model' not in ckpt_cfg:
            raise ValueError(
                'cannot build model from %s: the checkpoint meta has no '
                'model config, use a checkpoint saved after export' %
                args.checkpoint)
        model = build_model(ckpt_cfg['model'])

    # load checkpoint weights into the model
    if args.checkpoint is not None:
        logger.info('Use checkpoint: {} to extract features'.format(
            args.checkpoint))
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    elif args.pretrained != 'random':
        logger.info('Use pretrained model: {} to extract features'.format(
            args.pretrained))
    else:
        logger.info('No checkpoint or pretrained model is given, '
                    'using random initialization.')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    # build extraction processor
    extractor = ExtractProcess(extract_list=args.extract_list)

    # run
    outputs = extractor.extract(
        model, data_loaders[0], distributed=distributed)

    rank, _ = get_dist_info()

    if rank == 0:
        # Split each collected feature array into the subsets named in
        # cfg.split_name, using cfg.split_at as the boundary indices.
        split_num = len(cfg.split_name)
        split_at = cfg.split_at
        logger.info('split_num: {}, split_at: {}'.format(split_num, split_at))
        for key, val in outputs.items():
            for ss in range(split_num):
                output_file = '{}/{}_{}.npy'.format(cfg.work_dir,
                                                    cfg.split_name[ss], key)
                if ss == 0:
                    np.save(output_file, val[:split_at[0]])
                elif ss == split_num - 1:
                    np.save(output_file, val[split_at[-1]:])
                else:
                    np.save(output_file, val[split_at[ss - 1]:split_at[ss]])


if __name__ == '__main__':
    main()
