# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import functools
import logging
from abc import abstractmethod
from multiprocessing import Pool, cpu_count

import cv2
import mmcv
import numpy as np
from tqdm import tqdm

from easycv.datasets.registry import DATASOURCES
from easycv.file.image import load_image as _load_img
from easycv.framework.errors import NotImplementedError, ValueError


def load_image(img_path):
    """Read an image in BGR mode and wrap it in a sample dict.

    Args:
        img_path: location of the image, forwarded to the project image loader.

    Returns:
        dict: ``img`` holds the image cast to float32; ``img_shape`` and
        ``ori_shape`` both hold the shape of the loaded array
        (h, w, c according to the original author's note).
    """
    bgr = _load_img(img_path, mode='BGR')
    as_float = bgr.astype(np.float32)
    shape = bgr.shape  # h, w, c
    return {
        'img': as_float,
        'img_shape': shape,
        'ori_shape': shape,
    }


def load_seg_map(seg_path, reduce_zero_label):
    """Read a segmentation map in 'P' mode and wrap it in a sample dict.

    Args:
        seg_path: location of the segmentation map, forwarded to the project
            image loader.
        reduce_zero_label (bool): when truthy, label 0 is remapped to 255 and
            every other label is shifted down by one.

    Returns:
        dict: the (possibly shifted) label map under ``gt_semantic_seg``.
    """
    label_map = _load_img(seg_path, mode='P')
    if not reduce_zero_label:
        return {'gt_semantic_seg': label_map}

    # Park label 0 at 255 first so the subtraction below never has to take
    # 0 - 1 (avoids relying on underflow conversion).
    label_map[label_map == 0] = 255
    label_map = label_map - 1
    # Entries that were 255 before the shift are now 254; restore them.
    label_map[label_map == 254] = 255
    return {'gt_semantic_seg': label_map}


def build_sample(source_item, classes, parse_fn, load_img, reduce_zero_label):
    """Build sample info from source item.

    Args:
        source_item: item of source iterator
        classes: classes list
        parse_fn: parse function to parse source_item, only accepts two
            params: source_item and classes
        load_img: load image or not, if true, the image and the segmentation
            map are loaded right away and stored in the returned dict
        reduce_zero_label: forwarded to ``load_seg_map`` when ``load_img``
            is true

    Returns:
        Whatever ``parse_fn`` returned, extended with the loaded image and
        segmentation map when ``load_img`` is true. A falsy parse result
        (``None`` or an empty dict) is returned untouched.
    """
    result_dict = parse_fn(source_item, classes)

    # A falsy parse result carries no 'filename'/'seg_filename' to load from;
    # hand it back as-is so the caller can drop it instead of crashing here.
    if not result_dict:
        return result_dict

    if load_img:
        result_dict.update(load_image(result_dict['filename']))
        result_dict.update(
            load_seg_map(result_dict['seg_filename'], reduce_zero_label))

    return result_dict


@DATASOURCES.register_module
class SegSourceBase(object):
    """Base data source for semantic segmentation.

    Subclasses implement ``get_source_iterator``. Every item of that iterator
    is passed to ``parse_fn`` together with the class list to build one sample
    dict. ``__getitem__`` reads ``filename`` and ``seg_filename`` from the
    sample to load the image and the segmentation map.

    Args:
        classes (str | list | tuple): classes list or file. When None, the
            class attribute ``CLASSES`` is used.
        reduce_zero_label (bool): whether to mark label zero as ignored
        palette (Sequence[Sequence[int]]] | np.ndarray | None):
            palette of segmentation map. When None, the class attribute
            ``PALETTE`` is used, and if that is None as well a random
            palette will be generated.
        parse_fn (callable): called as ``parse_fn(source_item, classes)`` and
            expected to return the sample dict. It is dispatched through a
            ``multiprocessing.Pool``, so it should be picklable.
        num_processes (int): number of processes to parse samples
        cache_at_init (bool): if set True, will cache in memory in __init__
            for faster training
        cache_on_the_fly (bool): if set True, will cache in memory during
            training
    """
    CLASSES = None
    PALETTE = None

    def __init__(self,
                 classes=None,
                 reduce_zero_label=False,
                 palette=None,
                 parse_fn=None,
                 num_processes=int(cpu_count() / 2),
                 cache_at_init=False,
                 cache_on_the_fly=False):

        if classes is not None:
            self.CLASSES = classes
        if palette is not None:
            self.PALETTE = palette

        # The signature default evaluates to 0 on a single-core machine and
        # Pool rejects a process count below 1, so clamp it here. None is
        # left alone so Pool can pick its own default.
        if num_processes is not None:
            num_processes = max(1, int(num_processes))

        self.reduce_zero_label = reduce_zero_label
        self.cache_at_init = cache_at_init
        self.cache_on_the_fly = cache_on_the_fly
        self.num_processes = num_processes

        if self.cache_at_init and self.cache_on_the_fly:
            raise ValueError(
                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
            )
        assert isinstance(self.CLASSES, (str, tuple, list)), \
            '`classes` must be a str, tuple or list'
        if isinstance(self.CLASSES, str):
            # Read from self.CLASSES rather than the `classes` argument so a
            # file path set as class attribute by a subclass works as well.
            self.CLASSES = mmcv.list_from_file(self.CLASSES)
        if self.PALETTE is None:
            self.PALETTE = self.get_random_palette()

        source_iter = self.get_source_iterator()

        process_fn = functools.partial(
            build_sample,
            parse_fn=parse_fn,
            classes=self.CLASSES,
            load_img=bool(self.cache_at_init),
            reduce_zero_label=self.reduce_zero_label)
        self.samples_list = self.build_samples(
            source_iter, process_fn=process_fn)
        self.num_samples = len(self.samples_list)
        # An error will be raised if failed to load _max_retry_num times in a row
        self._max_retry_num = self.num_samples
        self._retry_count = 0

    @abstractmethod
    def get_source_iterator(self):
        """Return data list iterator, source iterator will be passed to parse_fn,
        and parse_fn will receive params of item of source iter and classes for parsing.
        What does parse_fn need, what does source iterator returns.

        The returned object must support ``len()`` because ``build_samples``
        uses it to size the progress bar.
        """
        raise NotImplementedError

    def build_samples(self, iterable, process_fn):
        """Run ``process_fn`` over ``iterable`` in a process pool.

        Args:
            iterable: sized collection of source items.
            process_fn (callable): applied to every item in a worker process.

        Returns:
            list: the truthy results of ``process_fn``. Results are collected
            with ``imap_unordered``, so their order may differ from the
            order of ``iterable``.
        """
        samples_list = []
        with Pool(processes=self.num_processes) as p:
            with tqdm(total=len(iterable), desc='Scanning images') as pbar:
                for result_dict in p.imap_unordered(process_fn, iterable):
                    # Falsy results are dropped.
                    if result_dict:
                        samples_list.append(result_dict)
                    pbar.update()

        return samples_list

    def __getitem__(self, idx):
        """Return the post-processed sample at ``idx``.

        If loading or post-processing raises, the error is logged and the
        next sample (wrapping around) is tried instead.

        Raises:
            ValueError: when ``_max_retry_num`` loads have failed in a row.
        """
        result_dict = self.samples_list[idx]
        use_cache = self.cache_at_init or self.cache_on_the_fly
        load_success = True
        try:
            # avoid data cache from taking up too much memory: without
            # caching, load into a private copy of the sample info so the
            # stored entry never holds the image data
            if not use_cache:
                result_dict = copy.deepcopy(result_dict)

            if not self.cache_at_init:
                if result_dict.get('img', None) is None:
                    result_dict.update(load_image(result_dict['filename']))
                if result_dict.get('gt_semantic_seg', None) is None:
                    result_dict.update(
                        load_seg_map(
                            result_dict['seg_filename'],
                            reduce_zero_label=self.reduce_zero_label))
                if self.cache_on_the_fly:
                    self.samples_list[idx] = result_dict

            # With caching, result_dict is the cached entry itself, so copy
            # it before handing it out. Without caching it is already a
            # private copy and copying the loaded arrays again is not needed.
            if use_cache:
                result_dict = copy.deepcopy(result_dict)
            result_dict = self.post_process_fn(result_dict)
            self._retry_count = 0
        except Exception as e:
            logging.warning(e)
            load_success = False

        if not load_success:
            logging.warning(
                'Something wrong with current sample %s,Try load next sample...'
                % result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')

            result_dict = self[(idx + 1) % self.num_samples]

        return result_dict

    def post_process_fn(self, result_dict):
        """Fill in ``img_fields`` / ``seg_fields`` when missing or None."""
        if result_dict.get('img_fields', None) is None:
            result_dict['img_fields'] = ['img']
        if result_dict.get('seg_fields', None) is None:
            result_dict['seg_fields'] = ['gt_semantic_seg']

        return result_dict

    def get_random_palette(self):
        """Return a fixed-seed random palette with one RGB row per class."""
        # Get random state before set seed, and restore
        # random state later.
        # It will prevent loss of randomness, as the palette
        # may be different in each iteration if not specified.
        # See: https://github.com/open-mmlab/mmdetection/issues/5844
        state = np.random.get_state()
        np.random.seed(42)
        # random palette
        palette = np.random.randint(0, 255, size=(len(self.CLASSES), 3))
        np.random.set_state(state)

        return palette

    def __len__(self):
        """Number of samples kept after parsing."""
        return len(self.samples_list)
