# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)

# This script is based on an original implementation by True Price.

import math
import os
import sqlite3
import sys
import tempfile
import typing as t
from struct import pack

import numpy as np
from opensfm import features
from opensfm import matching
from opensfm.dataset import DataSet

# 3x3 identity matrix; default F/E/H for COLMAPDatabase.add_two_view_geometry.
I_3 = np.eye(3)


def run_dataset(data: DataSet, binary: bool) -> None:
    """Export reconstruction to COLMAP format.

    Builds ``colmap_database.db`` (cameras, images, keypoints and verified
    matches) under ``<data_path>/colmap_export``. When a reconstruction
    exists, also writes ``project.ini`` and the cameras/images/points3D model
    files, as ``.bin`` when ``binary`` is true and ``.txt`` otherwise.
    """
    export_folder = os.path.join(data.data_path, "colmap_export")
    data.io_handler.mkdir_p(export_folder)

    database_path = os.path.join(export_folder, "colmap_database.db")
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The database is built with sqlite3 in a local temporary file and
        # copied to its destination through the io_handler afterwards.
        tmp_database_path = os.path.join(tmp_dir, "colmap_database.db")
        images_path = os.path.join(data.data_path, "images")

        db = COLMAPDatabase.connect(tmp_database_path)
        try:
            db.create_tables()

            images_map, camera_map = export_cameras(data, db)
            features_map = export_features(data, db, images_map)
            export_matches(data, db, features_map, images_map)

            if data.reconstruction_exists():
                export_ini_file(
                    export_folder, database_path, images_path, data.io_handler
                )
                export_cameras_reconstruction(data, export_folder, camera_map, binary)
                points_map = export_points_reconstruction(
                    data, export_folder, images_map, binary
                )
                export_images_reconstruction(
                    data,
                    export_folder,
                    camera_map,
                    images_map,
                    features_map,
                    points_map,
                    binary,
                )
            db.commit()
        finally:
            # Release the connection even when an export step raises, so the
            # handle is not leaked past the temporary directory's lifetime.
            db.close()

        data.io_handler.rm_if_exist(database_path)
        # tmp_database_path is a local file created above, so read it with the
        # builtin open(); only the destination goes through the io_handler.
        with open(tmp_database_path, "rb") as f:
            with data.io_handler.open(database_path, "wb") as fwb:
                fwb.write(f.read())


# True when the interpreter's major version is 3 or later.
IS_PYTHON3: bool = int(sys.version_info[0]) >= 3

# Image ids must lie in [0, MAX_IMAGE_ID) (see image_id_check below); the same
# value is the multiplier that packs two image ids into one pair_id.
MAX_IMAGE_ID = 2 ** 31 - 1

# Schema statements for the COLMAP database: one CREATE TABLE per table plus a
# unique index on image names. All use IF NOT EXISTS.

# One row per camera; params is filled by COLMAPDatabase.add_camera with a
# float64 blob.
CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

# Per-image descriptor matrix (rows x cols) stored as a blob.
CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

# Images with optional pose priors; image_id is range-checked against
# MAX_IMAGE_ID so that image_ids_to_pair_id stays collision-free.
CREATE_IMAGES_TABLE: str = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
""".format(
    MAX_IMAGE_ID
)

# Geometrically verified matches per image pair, keyed by pair_id, together
# with the F/E/H matrices and a config value.
CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB)
"""

# Per-image keypoint matrix (rows x cols) stored as a blob.
CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
"""

# Raw (unverified) matches per image pair, keyed by pair_id.
CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

CREATE_NAME_INDEX = "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

# Every statement above joined with "; " so the whole schema can be created by
# a single executescript() call (COLMAPDatabase.create_tables).
CREATE_ALL: str = "; ".join(
    [
        CREATE_CAMERAS_TABLE,
        CREATE_IMAGES_TABLE,
        CREATE_KEYPOINTS_TABLE,
        CREATE_DESCRIPTORS_TABLE,
        CREATE_MATCHES_TABLE,
        CREATE_TWO_VIEW_GEOMETRIES_TABLE,
        CREATE_NAME_INDEX,
    ]
)


def image_ids_to_pair_id(image_id1, image_id2):
    """Encode an unordered pair of image ids as a single integer pair_id.

    The smaller id is scaled by MAX_IMAGE_ID and the larger one added, so
    ``(a, b)`` and ``(b, a)`` produce the same value.
    """
    smaller, larger = (
        (image_id2, image_id1) if image_id1 > image_id2 else (image_id1, image_id2)
    )
    return smaller * MAX_IMAGE_ID + larger


def pair_id_to_image_ids(pair_id):
    """Decode a pair_id produced by image_ids_to_pair_id.

    Returns ``(image_id1, image_id2)`` with the smaller id first.
    """
    image_id1, image_id2 = divmod(pair_id, MAX_IMAGE_ID)
    return image_id1, image_id2


def array_to_blob(array):
    """Serialize a NumPy array to raw bytes (C order) for a SQLite BLOB column."""
    # The former Python 2 branch (np.getbuffer) was unreachable: this module
    # uses f-strings and therefore only runs on Python 3.
    return array.tobytes()


def blob_to_array(blob, dtype, shape=(-1,)):
    """Decode a SQLite BLOB written by array_to_blob back into a NumPy array.

    Args:
        blob: raw bytes of the array data.
        dtype: element type the bytes were written with.
        shape: target shape passed to ``reshape`` (flat by default).

    Returns:
        A writable array that owns its data.
    """
    # np.fromstring's binary mode is deprecated; frombuffer is the supported
    # replacement. frombuffer returns a read-only view of ``blob``, so copy it
    # to keep returning a writable, independent array as before.
    return np.frombuffer(blob, dtype=dtype).reshape(*shape).copy()


class COLMAPDatabase(sqlite3.Connection):
    """``sqlite3.Connection`` with helpers for writing a COLMAP database.

    Create instances with :meth:`connect`. The ``create_*`` attributes execute
    the corresponding schema statements; the ``add_*`` methods insert rows.
    """

    @staticmethod
    def connect(database_path):
        """Open ``database_path`` using COLMAPDatabase as the connection factory."""
        return sqlite3.connect(database_path, factory=COLMAPDatabase)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Schema helpers, exposed as instance attributes (callables taking no
        # arguments) that run the matching DDL through executescript().
        self.create_tables = lambda: self.executescript(CREATE_ALL)
        self.create_cameras_table = lambda: self.executescript(CREATE_CAMERAS_TABLE)
        self.create_descriptors_table = lambda: self.executescript(
            CREATE_DESCRIPTORS_TABLE
        )
        self.create_images_table = lambda: self.executescript(CREATE_IMAGES_TABLE)
        self.create_two_view_geometries_table = lambda: self.executescript(
            CREATE_TWO_VIEW_GEOMETRIES_TABLE
        )
        self.create_keypoints_table = lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
        self.create_matches_table = lambda: self.executescript(CREATE_MATCHES_TABLE)
        self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)

    def add_camera(
        self, model, width, height, params, prior_focal_length=False, camera_id=None
    ):
        """Insert a camera row and return its camera_id.

        ``params`` is stored as a float64 blob. With ``camera_id=None`` the id
        is assigned by SQLite (AUTOINCREMENT column).
        """
        params = np.asarray(params, np.float64)
        cursor = self.execute(
            "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
            (
                camera_id,
                model,
                width,
                height,
                array_to_blob(params),
                prior_focal_length,
            ),
        )
        return cursor.lastrowid

    def add_image(
        self, name, camera_id, prior_q=(0, 0, 0, 0), prior_t=(0, 0, 0), image_id=None
    ):
        """Insert an image row and return its image_id.

        ``prior_q`` (4 values) and ``prior_t`` (3 values) fill the prior_q*/
        prior_t* columns. With ``image_id=None`` the id is assigned by SQLite.
        """
        cursor = self.execute(
            "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                image_id,
                name,
                camera_id,
                prior_q[0],
                prior_q[1],
                prior_q[2],
                prior_q[3],
                prior_t[0],
                prior_t[1],
                prior_t[2],
            ),
        )
        return cursor.lastrowid

    def add_keypoints(self, image_id, keypoints):
        """Store an (N, 2|4|6) keypoint matrix for ``image_id`` as float32."""
        # Convert first so array-likes (e.g. nested lists) have a .shape.
        keypoints = np.asarray(keypoints, np.float32)
        assert len(keypoints.shape) == 2
        assert keypoints.shape[1] in [2, 4, 6]

        self.execute(
            "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
            (image_id,) + keypoints.shape + (array_to_blob(keypoints),),
        )

    def add_descriptors(self, image_id, descriptors):
        """Store the descriptor matrix for ``image_id`` as contiguous uint8."""
        descriptors = np.ascontiguousarray(descriptors, np.uint8)
        self.execute(
            "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
            (image_id,) + descriptors.shape + (array_to_blob(descriptors),),
        )

    def add_matches(self, image_id1, image_id2, matches):
        """Store raw (N, 2) feature-index matches between two images.

        The pair is stored under ``image_ids_to_pair_id``, i.e. with the
        smaller image id first; the match columns are swapped accordingly.
        """
        matches = np.asarray(matches)
        assert len(matches.shape) == 2
        assert matches.shape[1] == 2

        if image_id1 > image_id2:
            matches = matches[:, ::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        self.execute(
            "INSERT INTO matches VALUES (?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches),),
        )

    def add_two_view_geometry(
        self, image_id1, image_id2, matches, F=I_3, E=I_3, H=I_3, config=2
    ):
        """Store verified (N, 2) matches plus the pair's F, E and H matrices.

        ``matches`` and F/E/H are interpreted in the caller's order
        (image_id1, image_id2). The row is keyed by ``image_ids_to_pair_id``,
        which puts the smaller id first, so when the ids are given in
        descending order the geometry is re-expressed for the swapped pair.
        ``config`` is stored verbatim in the ``config`` column.
        """
        matches = np.asarray(matches)
        assert len(matches.shape) == 2
        assert matches.shape[1] == 2

        F = np.asarray(F, dtype=np.float64)
        E = np.asarray(E, dtype=np.float64)
        H = np.asarray(H, dtype=np.float64)

        if image_id1 > image_id2:
            # Swapping the image order swaps the match columns, transposes the
            # fundamental and essential matrices and inverts the homography.
            matches = matches[:, ::-1]
            F = F.T
            E = E.T
            H = np.linalg.inv(H)

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        self.execute(
            "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (pair_id,)
            + matches.shape
            + (
                array_to_blob(matches),
                config,
                array_to_blob(F),
                array_to_blob(E),
                array_to_blob(H),
            ),
        )


# OpenSfM projection type -> (COLMAP camera model name, COLMAP model id).
# Both public lookup tables below are derived from this single source.
_COLMAP_CAMERA_MODELS = {
    "brown": ("FULL_OPENCV", 6),
    "perspective": ("RADIAL", 3),
    "fisheye": ("RADIAL_FISHEYE", 9),
    "fisheye_opencv": ("OPENCV_FISHEYE", 5),
}
COLMAP_TYPES_MAP = {
    projection: model_name
    for projection, (model_name, _) in _COLMAP_CAMERA_MODELS.items()
}
COLMAP_ID_MAP = {
    projection: model_id
    for projection, (_, model_id) in _COLMAP_CAMERA_MODELS.items()
}


def camera_to_colmap_params(camera) -> t.Tuple[float, ...]:
    """Convert an OpenSfM camera to the parameter tuple of its COLMAP model.

    Focal length and principal point offsets are scaled by
    ``max(width, height)`` to obtain pixel units; the principal point is an
    offset from the image center.

    Returns:
        - perspective / fisheye: ``(f, cx, cy, k1, k2)``
        - brown: ``(fx, fy, cx, cy, k1, k2, p1, p2, k3, 0, 0, 0)``
        - fisheye_opencv: ``(fx, fy, cx, cy, k1, k2, k3, k4)``

    Raises:
        ValueError: for any other ``camera.projection_type``.
    """
    w = camera.width
    h = camera.height
    normalizer = max(w, h)
    f = camera.focal * normalizer
    if camera.projection_type in ("perspective", "fisheye"):
        # These models have no principal point: it is the image center.
        cx = w * 0.5
        cy = h * 0.5
        return f, cx, cy, camera.k1, camera.k2
    elif camera.projection_type == "brown":
        fy = f * camera.aspect_ratio
        cx = w * 0.5 + normalizer * camera.principal_point[0]
        cy = h * 0.5 + normalizer * camera.principal_point[1]
        return (
            f,
            fy,
            cx,
            cy,
            camera.k1,
            camera.k2,
            camera.p1,
            camera.p2,
            camera.k3,
            0.0,
            0.0,
            0.0,
        )
    elif camera.projection_type == "fisheye_opencv":
        fy = f * camera.aspect_ratio
        # Same normalized principal-point convention as "brown": scale the
        # offset to pixels before adding it to the image center.
        cx = w * 0.5 + normalizer * camera.principal_point[0]
        cy = h * 0.5 + normalizer * camera.principal_point[1]
        return f, fy, cx, cy, camera.k1, camera.k2, camera.k3, camera.k4
    else:
        raise ValueError(f"Can't convert {camera.projection_type} to COLMAP")


def export_cameras(data, db):
    """Add every camera model and every image to the COLMAP database.

    Camera models come from ``data.load_camera_models()``; an entry in the
    camera-model overrides, when present, replaces the model with the same id.
    Each image is attached to the COLMAP camera of the model named by its
    EXIF ``"camera"`` field.

    Returns:
        ``(images_map, camera_map)``: image name -> COLMAP image id, and
        OpenSfM camera model id -> COLMAP camera id.
    """
    # The overrides do not depend on the camera being processed; load them
    # once instead of re-checking and re-loading them for every model.
    overrides = {}
    if data.camera_models_overrides_exists():
        overrides = data.load_camera_models_overrides()

    camera_map = {}
    for camera_model, camera in data.load_camera_models().items():
        if camera_model in overrides:
            camera = overrides[camera_model]

        parameters = camera_to_colmap_params(camera)
        camera_map[camera_model] = db.add_camera(
            COLMAP_ID_MAP[camera.projection_type],
            camera.width,
            camera.height,
            np.array(parameters),
        )

    images_map = {}
    for image in data.images():
        camera_model = data.load_exif(image)["camera"]
        images_map[image] = db.add_image(image, camera_map[camera_model])

    return images_map, camera_map


def export_features(data, db, images_map):
    """Write each image's keypoints to the database.

    Keypoints are converted with ``features.denormalized_image_coordinates``
    using the image size from EXIF. Images for which ``load_features``
    returns nothing are skipped.

    Returns:
        Map from image name to its denormalized keypoint array.
    """
    features_map = {}
    for image in data.images():
        features_data = data.load_features(image)
        if not features_data:
            continue
        # One EXIF load per image (previously loaded twice: width, height).
        exif = data.load_exif(image)
        feat = features.denormalized_image_coordinates(
            features_data.points, exif["width"], exif["height"]
        )
        features_map[image] = feat
        db.add_keypoints(images_map[image], feat)
    return features_map


def export_matches(data, db, features_map, images_map) -> None:
    """Geometrically verify pairwise matches and store them in the database.

    Matches loaded in either direction are merged per name-ordered image pair
    and de-duplicated. Pairs with fewer than 10 candidate matches are skipped.
    The rest go through ``matching.robust_match_fundamental`` with a
    threshold of 8, and pairs keeping more than 10 inliers are written as
    both a two-view geometry (with F) and raw matches.
    """
    # pair (name-sorted) -> insertion-ordered set of (feat in pair[0],
    # feat in pair[1]); a dict keeps the order deterministic for RANSAC.
    matches_per_pair = {}
    for image1 in data.images():
        for image2, image_matches in data.load_matches(image1).items():
            pair_key = (min(image1, image2), max(image1, image2))
            pair_matches = matches_per_pair.setdefault(pair_key, {})
            for match in image_matches:
                if image1 < image2:
                    pair_matches[(match[0], match[1])] = True
                else:
                    pair_matches[(match[1], match[0])] = True

    # Use a private copy so the export's threshold override does not leak
    # into the dataset's shared configuration.
    config = dict(data.config)
    config["robust_matching_threshold"] = 8
    for (image1, image2), matches in matches_per_pair.items():
        if len(matches) < 10:
            continue
        if image1 not in features_map or image2 not in features_map:
            # export_features skips images without features; there is
            # nothing to verify these matches against.
            continue
        matches_numpy = np.array(list(matches))
        F, inliers = matching.robust_match_fundamental(
            features_map[image1], features_map[image2], matches_numpy, config
        )
        if len(inliers) > 10:
            db.add_two_view_geometry(
                images_map[image1], images_map[image2], inliers, F=F
            )
            db.add_matches(images_map[image1], images_map[image2], inliers)


def export_cameras_reconstruction(data, path, camera_map, binary: bool=False) -> None:
    """Write COLMAP ``cameras.bin`` (binary) or ``cameras.txt`` under ``path``.

    Cameras from all reconstructions are merged by id (a later reconstruction
    wins on duplicates). ``camera_map`` maps OpenSfM camera ids to the COLMAP
    camera ids used in the database.
    """
    reconstructions = data.load_reconstruction()
    cameras = {}
    for reconstruction in reconstructions:
        for camera_id, camera in reconstruction.cameras.items():
            cameras[camera_id] = camera

    if binary:
        fout = data.io_handler.open(os.path.join(path, "cameras.bin"), "wb")
    else:
        fout = data.io_handler.open_wt(os.path.join(path, "cameras.txt"))

    # `with` guarantees the file is closed even if a camera fails to convert.
    with fout:
        if binary:
            fout.write(pack("<Q", len(cameras)))

        for camera_id, camera in cameras.items():
            colmap_id = camera_map[camera_id]
            colmap_type = COLMAP_TYPES_MAP[camera.projection_type]
            w = camera.width
            h = camera.height
            params = camera_to_colmap_params(camera)
            if binary:
                fout.write(pack("<2i", colmap_id, COLMAP_ID_MAP[camera.projection_type]))
                fout.write(pack("<2Q", w, h))
                fout.write(pack(f"<{len(params)}d", *params))
            else:
                # CAMERA_ID MODEL WIDTH HEIGHT PARAMS[]
                fields = ["%d" % colmap_id, colmap_type, "%d" % w, "%d" % h]
                fields += ["%f" % param for param in params]
                fout.write(" ".join(fields) + "\n")


def _shot_points2d(shot_id, tracks_manager, shot_features, points_map):
    """Return ``(x, y, colmap_point3d_id)`` for every feature of a shot, in order.

    The id is -1 for features that are not observations of an exported point.
    """
    track_per_feature = {
        obs.id: track_id
        for track_id, obs in tracks_manager.get_shot_observations(shot_id).items()
    }
    points2d = []
    for feature_id, (x, y) in enumerate(shot_features):
        track_id = track_per_feature.get(feature_id)
        colmap_point_id = -1
        if track_id is not None and track_id in points_map:
            colmap_point_id = points_map[track_id]
        points2d.append((x, y, colmap_point_id))
    return points2d


def export_images_reconstruction(
    data, path, camera_map, images_map, features_map, points_map, binary: bool=False
) -> None:
    """Write COLMAP ``images.bin`` (binary) or ``images.txt`` under ``path``.

    For each reconstructed shot: pose (quaternion + translation), camera id,
    name, and its 2D points. Every feature of the shot is listed, in feature
    order, with point3D id -1 when it has no exported 3D point. This keeps
    the POINT2D_IDX values written by export_points_reconstruction (which are
    OpenSfM feature ids) valid indices into this list.
    """
    reconstructions = data.load_reconstruction()
    tracks_manager = data.load_tracks_manager()

    if binary:
        fout = data.io_handler.open(os.path.join(path, "images.bin"), "wb")
    else:
        fout = data.io_handler.open_wt(os.path.join(path, "images.txt"))

    # `with` guarantees the file is closed even if a shot fails to export.
    with fout:
        if binary:
            fout.write(pack("<Q", sum(len(r.shots) for r in reconstructions)))

        for reconstruction in reconstructions:
            for shot_id, shot in reconstruction.shots.items():
                colmap_camera_id = camera_map[shot.camera.id]
                colmap_shot_id = images_map[shot_id]

                tvec = shot.pose.translation
                qvec = angle_axis_to_quaternion(shot.pose.rotation)
                points2d = _shot_points2d(
                    shot_id, tracks_manager, features_map[shot_id], points_map
                )

                if binary:
                    fout.write(pack("<I", colmap_shot_id))
                    fout.write(pack("<7d", *(list(qvec) + list(tvec))))
                    fout.write(pack("<I", colmap_camera_id))
                    # Null-terminated UTF-8 name; encoding the whole string at
                    # once handles multi-byte characters.
                    fout.write(shot_id.encode("utf-8") + b"\x00")
                    fout.write(pack("<Q", len(points2d)))
                    for x, y, colmap_point_id in points2d:
                        # Signed 64-bit id so that -1 marks "no 3D point".
                        fout.write(pack("<2dq", x, y, colmap_point_id))
                else:
                    # Line 1: IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME
                    # Line 2: POINTS2D[] as (X, Y, POINT3D_ID)
                    fout.write(
                        "%d %f %f %f %f %f %f %f %d %s\n"
                        % (
                            colmap_shot_id,
                            qvec[0],
                            qvec[1],
                            qvec[2],
                            qvec[3],
                            tvec[0],
                            tvec[1],
                            tvec[2],
                            colmap_camera_id,
                            shot_id,
                        )
                    )
                    fout.write(
                        "".join("%f %f %d " % point for point in points2d) + "\n"
                    )


def export_points_reconstruction(data, path, images_map, binary: bool=False):
    """Write COLMAP ``points3D.bin`` (binary) or ``points3D.txt`` under ``path``.

    Points of all reconstructions get sequential COLMAP ids starting at 0.
    A point's track lists ``(COLMAP image id, OpenSfM feature id)`` for each
    observation whose image belongs to the same reconstruction. The
    reprojection error is written as 0.0.

    Returns:
        Map from OpenSfM point id to the assigned COLMAP point3D id.
    """
    reconstructions = data.load_reconstruction()
    tracks_manager = data.load_tracks_manager()

    points_map = {}

    if binary:
        fout = data.io_handler.open(os.path.join(path, "points3D.bin"), "wb")
    else:
        fout = data.io_handler.open_wt(os.path.join(path, "points3D.txt"))

    # `with` guarantees the file is closed even if a point fails to export.
    with fout:
        if binary:
            fout.write(pack("<Q", sum(len(r.points) for r in reconstructions)))

        colmap_point_id = 0
        for reconstruction in reconstructions:
            for point in reconstruction.points.values():
                c = point.coordinates
                rgb = [int(channel) for channel in point.color]

                track = [
                    (images_map[image], obs.id)
                    for image, obs in tracks_manager.get_track_observations(
                        point.id
                    ).items()
                    if image in reconstruction.shots
                ]

                if binary:
                    fout.write(pack("<Q", colmap_point_id))
                    fout.write(pack("<3d", c[0], c[1], c[2]))  # Position
                    fout.write(pack("<3B", *rgb))  # Color
                    fout.write(pack("<d", 0.0))  # Error
                    fout.write(pack("<Q", len(track)))  # Track length
                    for image_id, feature_id in track:
                        fout.write(pack("<2i", image_id, feature_id))  # Track
                else:
                    # POINT3D_ID X Y Z R G B ERROR TRACK[] as (IMAGE_ID, POINT2D_IDX)
                    line = "%d %f %f %f %d %d %d %f " % (
                        colmap_point_id,
                        c[0],
                        c[1],
                        c[2],
                        rgb[0],
                        rgb[1],
                        rgb[2],
                        0.0,
                    )
                    line += "".join("%d %d " % element for element in track)
                    fout.write(line + "\n")

                points_map[point.id] = colmap_point_id
                colmap_point_id += 1
    return points_map


def angle_axis_to_quaternion(angle_axis):
    """Convert an angle-axis rotation vector to a unit quaternion.

    Args:
        angle_axis: 3-vector whose direction is the rotation axis and whose
            norm is the rotation angle in radians.

    Returns:
        ``[qw, qx, qy, qz]`` (scalar first).
    """
    angle = float(np.linalg.norm(angle_axis))
    if angle > 0.0:
        half_angle = 0.5 * angle
        qw = math.cos(half_angle)
        # Use sin(angle/2) directly instead of sqrt(1 - qw^2): the latter
        # collapses to 0 for tiny angles because cos() rounds to exactly 1.
        scale = math.sin(half_angle) / angle
    else:
        # Zero rotation: the axis is undefined (dividing by the angle would
        # give NaN). sin(angle/2)/angle tends to 1/2 as angle -> 0.
        qw = 1.0
        scale = 0.5
    return [qw, angle_axis[0] * scale, angle_axis[1] * scale, angle_axis[2] * scale]


def export_ini_file(path, db_path, images_path, io_handler) -> None:
    """Write ``project.ini`` in ``path`` pointing COLMAP at the database and images.

    Logging settings are fixed; ``db_path`` and ``images_path`` are written
    verbatim. The file is opened through ``io_handler.open_wt``.
    """
    ini_contents = (
        "log_to_stderr=false\nlog_level=2\n"
        + "database_path=%s\n" % db_path
        + "image_path=%s\n" % images_path
    )
    ini_path = os.path.join(path, "project.ini")
    with io_handler.open_wt(ini_path) as ini_file:
        ini_file.write(ini_contents)
