# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

"""Functions for evaluating results computed for a json dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import json
import logging
import os
import uuid

import numpy as np
import six
from pycocotools.cocoeval import COCOeval


# Module-level logger shared by the evaluation helpers in this file for
# progress, warning and metric messages.
logger = logging.getLogger(__name__)


def evaluateMasks(
    json_dataset, all_boxes, all_segms, output_dir, use_salt=True, cleanup=False
):
    """Write segmentation results to a COCO json file and evaluate them.

    Args:
        json_dataset: dataset wrapper exposing ``name``, ``classes``,
            ``category_to_id_map`` and a ``COCO`` api object.
        all_boxes: per-class, per-image detections
            (``all_boxes[cls_ind][img_ind]``).
        all_segms: per-class, per-image masks, indexed like ``all_boxes``.
        output_dir: directory that receives the results json file.
        use_salt: append a random uuid to the file name to make it unique.
        cleanup: delete the results json file afterwards. The file is removed
            even if evaluation raises.

    Returns:
        The ``COCOeval`` object, or ``None`` when the dataset name contains
        "test" (no evaluation is run in that case).
    """
    res_file = os.path.join(
        output_dir, "segmentations_" + json_dataset.name + "_results"
    )
    if use_salt:
        res_file += "_{}".format(uuid.uuid4())
    res_file += ".json"
    _writeCocoSegmsResultsFile(json_dataset, all_boxes, all_segms, res_file)
    try:
        # Only do evaluation on non-test sets (annotations are undisclosed on test)
        if "test" not in json_dataset.name:
            coco_eval = _doSegmentationEval(json_dataset, res_file, output_dir)
        else:
            logger.warning(
                "{} eval ignored as annotations are undisclosed on test: {} ignored".format(
                    "Segmentation", json_dataset.name
                )
            )
            coco_eval = None
    finally:
        # Optionally cleanup results json file, also when evaluation failed,
        # so an exception does not leave the temporary file behind.
        if cleanup:
            os.remove(res_file)
    return coco_eval


def _writeCocoSegmsResultsFile(json_dataset, all_boxes, all_segms, res_file):
    """Collect per-class segmentation results and dump them to ``res_file``.

    The file holds a COCO results list, e.g.::

        [{"image_id": 42,
          "category_id": 18,
          "segmentation": [...],
          "score": 0.236}, ...]

    Args:
        json_dataset: dataset exposing ``classes`` and ``category_to_id_map``
            (plus whatever ``_cocoSegmsResultsOneCategory`` reads).
        all_boxes: per-class detections, indexed by position in
            ``json_dataset.classes``.
        all_segms: per-class masks, indexed like ``all_boxes``.
        res_file: path of the json file to write.

    The caller's ``all_segms`` is not modified.
    """
    results = []
    for cls_ind, cls in enumerate(json_dataset.classes):
        if cls == "__background__":
            continue
        # Classes past the end of all_boxes have no results; stop there.
        if cls_ind >= len(all_boxes):
            break
        cat_id = json_dataset.category_to_id_map[cls]
        results.extend(
            _cocoSegmsResultsOneCategory(
                json_dataset, all_boxes[cls_ind], all_segms[cls_ind], cat_id
            )
        )
    # RLE "counts" may be a byte string (mask encoding produces bytes), which
    # Python 3's json writer cannot serialize, so decode it as utf-8 (as
    # pycocotools/_mask.pyx does). Decode into a copy of the RLE dict so the
    # caller's all_segms is not mutated. Values that are already text are left
    # alone, so calling this twice on the same data does not crash on
    # str.decode. This is done before opening the file, so a failure here does
    # not leave a truncated json behind.
    for r in results:
        rle = r["segmentation"]
        if "counts" in rle and isinstance(rle["counts"], bytes):
            r["segmentation"] = dict(rle, counts=rle["counts"].decode("utf8"))
    logger.info(
        "Writing segmentation results json to: {}".format(os.path.abspath(res_file))
    )
    with open(res_file, "w") as fid:
        json.dump(results, fid)


def _cocoSegmsResultsOneCategory(json_dataset, boxes, segms, cat_id):
    """Build COCO segmentation result entries for a single category.

    Args:
        json_dataset: object exposing ``COCO.getImgIds()``.
        boxes: per-image detections aligned with the *sorted* image ids. Each
            entry is empty (no detections) or a 2-D array-like whose last
            column is the detection score.
        segms: per-image lists of masks aligned with ``boxes``. ``segms[i][k]``
            belongs to detection ``k`` of image ``i``.
        cat_id: category id stamped on every entry.

    Returns:
        A list of dicts with keys ``image_id``, ``category_id``,
        ``segmentation`` and ``score``.
    """
    # sorted() returns a copy rather than reordering the list handed back by
    # the COCO api.
    image_ids = sorted(json_dataset.COCO.getImgIds())
    assert len(boxes) == len(image_ids)
    assert len(segms) == len(image_ids)
    results = []
    for image_id, dets, rles in zip(image_ids, boxes, segms):
        # Skip images with no detections. len() covers both an empty list and
        # a zero-length ndarray; the latter would fail on dets[:, -1].
        if len(dets) == 0:
            continue
        # np.float was removed in NumPy 1.24; it was an alias of float64.
        scores = np.asarray(dets, dtype=np.float64)[:, -1].tolist()
        results.extend(
            {
                "image_id": image_id,
                "category_id": cat_id,
                "segmentation": rles[k],
                "score": score,
            }
            for k, score in enumerate(scores)
        )
    return results


def _doSegmentationEval(json_dataset, res_file, output_dir):
    """Evaluate the segmentation results in ``res_file`` with COCO "segm" metrics.

    ``output_dir`` is accepted for interface compatibility and is not used.
    Returns the ``COCOeval`` object after ``evaluate()`` and ``accumulate()``
    have run and the metrics have been logged.
    """
    ground_truth = json_dataset.COCO
    detections = ground_truth.loadRes(str(res_file))
    evaluator = COCOeval(ground_truth, detections, "segm")
    evaluator.evaluate()
    evaluator.accumulate()
    _logDetectionEvalMetrics(json_dataset, evaluator)
    return evaluator


def evaluateBoxes(json_dataset, all_boxes, output_dir, use_salt=True, cleanup=False):
    """Write box detections to a COCO json file and evaluate them.

    Args:
        json_dataset: dataset wrapper exposing ``name``, ``classes``,
            ``category_to_id_map`` and a ``COCO`` api object.
        all_boxes: per-class, per-image detections
            (``all_boxes[cls_ind][img_ind]``).
        output_dir: directory that receives the results json file.
        use_salt: append a random uuid to the file name to make it unique.
        cleanup: delete the results json file afterwards. The file is removed
            even if evaluation raises.

    Returns:
        The ``COCOeval`` object, or ``None`` when the dataset name contains
        "test" (no evaluation is run in that case).
    """
    res_file = os.path.join(output_dir, "bbox_" + json_dataset.name + "_results")
    if use_salt:
        res_file += "_{}".format(uuid.uuid4())
    res_file += ".json"
    _writeCocoBboxResultsFile(json_dataset, all_boxes, res_file)
    try:
        # Only do evaluation on non-test sets (annotations are undisclosed on test)
        if "test" not in json_dataset.name:
            coco_eval = _doDetectionEval(json_dataset, res_file, output_dir)
        else:
            logger.warning(
                "{} eval ignored as annotations are undisclosed on test: {} ignored".format(
                    "Bbox", json_dataset.name
                )
            )
            coco_eval = None
    finally:
        # Optionally cleanup results json file, also when evaluation failed,
        # so an exception does not leave the temporary file behind.
        if cleanup:
            os.remove(res_file)
    return coco_eval


def _writeCocoBboxResultsFile(json_dataset, all_boxes, res_file):
    """Serialize every non-background class's boxes to ``res_file`` as json.

    The output is a COCO results list, for example::

        [{"image_id": 42,
          "category_id": 18,
          "bbox": [258.15,41.29,348.26,243.78],
          "score": 0.236}, ...]

    ``all_boxes`` is indexed by position in ``json_dataset.classes``. Iteration
    stops at the first non-background class whose index is past its end.
    """
    available = len(all_boxes)
    results = []
    for cls_ind, cls in enumerate(json_dataset.classes):
        if cls == "__background__":
            continue
        if cls_ind >= available:
            break
        per_class_boxes = all_boxes[cls_ind]
        category_id = json_dataset.category_to_id_map[cls]
        results += _cocoBboxEesultsOneCategory(
            json_dataset, per_class_boxes, category_id
        )
    target = os.path.abspath(res_file)
    logger.info("Writing bbox results json to: {}".format(target))
    with open(res_file, "w") as out:
        json.dump(results, out)


def xyxy_to_xywh(xyxy):
    """Convert [x1 y1 x2 y2] box format to [x1 y1 w h] format.

    Widths and heights follow the inclusive-pixel convention
    ``w = x2 - x1 + 1`` and ``h = y2 - y1 + 1``.

    Args:
        xyxy: one of
            - a single box as a 4-element list or tuple;
            - a single box as a 1-D numpy array (elements past the 4th are
              ignored);
            - many boxes as a 2-D numpy array of shape ``(N, >=4)`` (columns
              past the 4th are ignored).

    Returns:
        A 4-tuple for list/tuple input, a length-4 array for 1-D array input,
        or an ``(N, 4)`` array for 2-D array input.

    Raises:
        TypeError: if ``xyxy`` is not a list, tuple or numpy array.
    """
    if isinstance(xyxy, (list, tuple)):
        # Single box given as a list of coordinates
        assert len(xyxy) == 4
        x1, y1, x2, y2 = xyxy
        return (x1, y1, x2 - x1 + 1, y2 - y1 + 1)
    if isinstance(xyxy, np.ndarray):
        if xyxy.ndim == 1:
            # Single box given as a 1-D array (previously an IndexError).
            return np.hstack((xyxy[0:2], xyxy[2:4] - xyxy[0:2] + 1))
        # Multiple boxes given as a 2D ndarray
        return np.hstack((xyxy[:, 0:2], xyxy[:, 2:4] - xyxy[:, 0:2] + 1))
    raise TypeError("Argument xyxy must be a list, tuple, or numpy array.")


def _cocoBboxEesultsOneCategory(json_dataset, boxes, cat_id):
    """Build COCO bbox result entries for a single category.

    Args:
        json_dataset: object exposing ``COCO.getImgIds()``.
        boxes: per-image detections aligned with the *sorted* image ids. Each
            entry is empty (no detections) or a 2-D array-like whose first four
            columns are x1, y1, x2, y2 and whose last column is the score.
        cat_id: category id stamped on every entry.

    Returns:
        A list of dicts with keys ``image_id``, ``category_id``, ``bbox``
        (``[x, y, w, h]`` as produced by ``xyxy_to_xywh``) and ``score``.
    """
    # sorted() returns a copy rather than reordering the list handed back by
    # the COCO api.
    image_ids = sorted(json_dataset.COCO.getImgIds())
    assert len(boxes) == len(image_ids)
    results = []
    for image_id, dets in zip(image_ids, boxes):
        # Skip images with no detections. len() covers both an empty list and
        # a zero-length ndarray; the latter would fail on dets[:, -1].
        if len(dets) == 0:
            continue
        # np.float was removed in NumPy 1.24; it was an alias of float64.
        dets = np.asarray(dets, dtype=np.float64)
        scores = dets[:, -1].tolist()
        xywh_rows = xyxy_to_xywh(dets[:, 0:4]).tolist()
        results.extend(
            {
                "image_id": image_id,
                "category_id": cat_id,
                "bbox": bbox,
                "score": score,
            }
            for bbox, score in zip(xywh_rows, scores)
        )
    return results


def _doDetectionEval(json_dataset, res_file, output_dir):
    """Evaluate the box detections in ``res_file`` with COCO "bbox" metrics.

    ``output_dir`` is accepted for interface compatibility and is not used.
    Returns the ``COCOeval`` object after ``evaluate()`` and ``accumulate()``
    have run and the metrics have been logged.
    """
    ground_truth = json_dataset.COCO
    detections = ground_truth.loadRes(str(res_file))
    evaluator = COCOeval(ground_truth, detections, "bbox")
    evaluator.evaluate()
    evaluator.accumulate()
    _logDetectionEvalMetrics(json_dataset, evaluator)
    return evaluator


def _logDetectionEvalMetrics(json_dataset, coco_eval):
    """Log mean and per-category AP over IoU 0.50..0.95, then COCO's summary.

    Args:
        json_dataset: dataset exposing ``classes``. Every class except
            "__background__" gets one AP line, in list order.
        coco_eval: an evaluated and accumulated ``COCOeval`` whose
            ``eval["precision"]`` is indexed as
            (iou, recall, cls, area range, max dets).
    """
    iou_thrs = coco_eval.params.iouThrs

    def _index_of(thr):
        # IoU thresholds are floats, so match within a small tolerance.
        hits = np.where((iou_thrs > thr - 1e-5) & (iou_thrs < thr + 1e-5))[0]
        found = hits[0]
        assert np.isclose(iou_thrs[found], thr)
        return found

    def _mean_ap(prec):
        # Entries equal to -1 are treated as missing and excluded.
        return np.mean(prec[prec > -1])

    lo, hi = 0.5, 0.95
    first_t = _index_of(lo)
    last_t = _index_of(hi)
    # area range index 0: all area ranges; max dets index 2: 100 per image
    per_iou = coco_eval.eval["precision"][first_t : (last_t + 1), :, :, 0, 2]
    overall = _mean_ap(per_iou)
    logger.info(
        "~~~~ Mean and per-category AP @ IoU=[{:.2f},{:.2f}] ~~~~".format(lo, hi)
    )
    logger.info("{:.1f}".format(100 * overall))
    for cls_ind, cls in enumerate(json_dataset.classes):
        if cls != "__background__":
            # The precision class axis has no background slot, hence the -1.
            class_ap = _mean_ap(per_iou[:, :, cls_ind - 1])
            logger.info("{:.1f}".format(100 * class_ap))
    logger.info("~~~~ Summary metrics ~~~~")
    coco_eval.summarize()
