# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mmap
import os
import tempfile
from collections import OrderedDict
from urllib.parse import urlparse

import numpy as np
from mars.lib.filesystem import FileSystem

from ...core import ODPS
from ...compat import futures, BytesIO
from ...errors import NoSuchObject
from ...tunnel.volumetunnel import VolumeTunnel
from ...utils import to_binary, to_str


DIR_ERROR_MSG = (
    "Argument `path` illegal, expect to be "
    "odps:///**project**/volumes/**volume**/, "
    "got {}"
)

FILE_ERROR_MSG = (
    "Argument `path` illegal, expect to be "
    "odps:///**project**/volumes/**volume**/**file**, "
    "got {}"
)

DEFAULT_CHUNK_SIZE = 512 * 1024**2

META_FILE_NAME = "_meta_"


class _VolumeFileObject(object):
    def __init__(self, volume_parted, mode, temp_dir=None, chunk_size=None):
        self._volume_parted = volume_parted
        self._mode = mode
        self._temp_dir = temp_dir
        self._chunk_size = chunk_size

        # writer
        self._index = 0
        self._writer = None
        self._write_buffer = None
        self._meta = dict()

        # reader
        self._reader = None
        self._read_filename = None

    def _read_multi_volume_files(self):
        # get meta file first
        with self._volume_parted.open_reader(META_FILE_NAME) as meta_reader:
            content = to_str(meta_reader.read())
            meta = OrderedDict()
            length = 0
            for line in content.split("\n"):
                if "," in line:
                    filename, size = line.split(",", 1)
                    size = int(size)
                    meta[filename] = size
                    length += size

        temp_dir = self._temp_dir or tempfile.mkdtemp()
        write_filename = os.path.join(
            temp_dir, self._volume_parted.volume.name, self._volume_parted.name
        )
        os.makedirs(os.path.dirname(write_filename), exist_ok=True)

        with open(write_filename, "wb") as f:
            for size in meta.values():
                f.write(b"\0" * size)

        with open(write_filename, "rb+") as f:
            m = mmap.mmap(f.fileno(), length)
            offsets = [0] + np.cumsum(list(meta.values())).tolist()

            volume_part = self._volume_parted
            volume_tunnel = VolumeTunnel(
                client=volume_part._client,
                project=volume_part.project,
                endpoint=volume_part.project._tunnel_endpoint,
            )

            def _write(filename, start, end):
                with volume_tunnel.create_download_session(
                    volume_part.volume.name, volume_part.name, filename
                ).open() as reader:
                    content_part = reader.read()
                    m[start:end] = content_part

            executor = futures.ThreadPoolExecutor(8)
            fs = []
            for i, filename in enumerate(meta):
                future = executor.submit(_write, filename, offsets[i], offsets[i + 1])
                fs.append(future)
            [f.result() for f in fs]

        self._read_filename = write_filename
        self._reader = open(write_filename, "rb")

    def read(self, size=None):
        self._volume_parted.reload()
        if self._volume_parted.file_number == 1:
            if self._reader is None:
                self._reader = self._volume_parted.open_reader(
                    list(self._volume_parted.files)[0].name
                )
            return self._reader.read(size=size)
        else:
            if self._reader is None:
                self._read_multi_volume_files()
            return self._reader.read(size)

    def _ensure_writer_open(self):
        if self._writer is None:
            self._writer = self._volume_parted.open_writer()
            self._write_buffer = BytesIO()

    def _flush_to_volume_file(self):
        # write buffer full, write with a file
        filename = str(self._index)
        buffer = self._write_buffer
        content_part = buffer.getvalue()
        self._writer.write(filename, content_part)
        self._meta[filename] = buffer.tell()
        self._index += 1
        # clear write buffer
        self._write_buffer = BytesIO()

    def write(self, content):
        self._ensure_writer_open()

        content_left = len(content)
        while content_left > 0:
            buffer_left_size = self._chunk_size - self._write_buffer.tell()
            content_part = content[:buffer_left_size]
            content = content[buffer_left_size:]
            self._write_buffer.write(content_part)
            content_left -= len(content_part)
            if self._write_buffer.tell() >= self._chunk_size:
                # write buffer full
                self._flush_to_volume_file()

    def close(self):
        if self._writer is not None:
            if self._write_buffer.tell() > 0:
                self._flush_to_volume_file()
                self._index = 0
            if len(self._meta) > 1:
                meta_content = "\n".join(
                    "{},{}".format(filename, size)
                    for filename, size in self._meta.items()
                )
                # write meta file if file length > 1
                self._writer.write(META_FILE_NAME, to_binary(meta_content))
            self._writer.__exit__(None, None, None)
        if self._reader is not None:
            if self._read_filename is not None:
                # remove read file name
                os.remove(self._read_filename)
            self._reader.__exit__(None, None, None)

    def __enter__(self):
        if self._writer or self._reader:
            (self._writer or self._reader).__enter__()
        return self

    def __exit__(self, *_):
        self.close()


class VolumeFileSystem(FileSystem):
    """
    Schema follows:

    /**project**/volumes/**volume**
    """

    def __init__(self, odps=None, path=None, temp_dir=None, chunk_size=None):
        if odps is None:
            # try to get from environments
            odps = ODPS.from_environments()
        if odps is None:
            # try to get from global config
            odps = ODPS.from_global()
        if odps is None:
            # still None, raise error
            raise ValueError("`odps` should be specified for VolumeFileSystem")

        self._odps = odps
        self._path = path
        self._temp_dir = temp_dir
        self._chunk_size = chunk_size or DEFAULT_CHUNK_SIZE

    @staticmethod
    def parse_from_path(uri):
        parsed = urlparse(uri)
        options = dict()
        if parsed.path:
            options["path"] = parsed.path.strip("/")
        return options

    @staticmethod
    def _extract_info(path, is_file=False):
        options = VolumeFileSystem.parse_from_path(path)
        error_msg = DIR_ERROR_MSG if not is_file else FILE_ERROR_MSG

        if "path" not in options:
            raise ValueError(error_msg.format(path))

        path = options["path"]
        splits = path.split("/")
        expect_size = 3 if not is_file else 4
        if len(splits) != expect_size:
            raise ValueError(error_msg.format(path))

        volumes = splits[1]
        if volumes.lower() != "volumes":
            raise ValueError(error_msg.format(path))

        return splits[:1] + splits[2:]

    def ls(self, path):
        path = path.strip("/") or self._path
        project, volume = self._extract_info(path)
        try:
            files = self._odps.list_volume_partitions(volume, project=project)
        except NoSuchObject:
            raise FileNotFoundError("File {} not found".format(path))
        return [
            "odps:///{}/volumes/{}/{}".format(project, volume, f.name) for f in files
        ]

    def delete(self, path, recursive: bool = False):
        path = path.strip("/") or self._path
        if "/" not in path:
            path = "{}/{}".format(self._path, path)
        options = VolumeFileSystem.parse_from_path(path)
        fpath = options["path"]

        splits = fpath.split("/")
        if len(splits) == 3:
            is_file = False
        elif len(splits) == 4:
            is_file = True
        else:
            raise ValueError(
                "Argument `path` illegal, expect to be "
                "odps:///**project**/volumes/**volume**/ or"
                "odps:///**project**/volumes/**volume**/**partition**/, "
                "got {}".format(path)
            )

        if is_file:
            project, volume, partition = self._extract_info(path, is_file=True)
            if self._odps.exist_volume_partition(
                volume, partition=partition, project=project
            ):
                self._odps.delete_volume_partition(
                    volume, partition=partition, project=project
                )
        else:
            project, volume = self._extract_info(path)
            if self._odps.exist_volume(volume, project=project):
                self._odps.delete_volume(volume, project=project)

    def open(self, path, mode="rb"):
        if mode not in ("rb", "wb"):
            raise ValueError("`mode` can be `rb` or `wb` only")

        path = path.strip("/") or self._path
        if "/" not in path:
            path = "{}/{}".format(self._path, path)
        project, volume, partition = self._extract_info(path, is_file=True)
        volume_parted = self._odps.get_volume_partition(
            volume, partition=partition, project=project
        )

        return _VolumeFileObject(
            volume_parted, mode, temp_dir=self._temp_dir, chunk_size=self._chunk_size
        )

    def mkdir(self, path, create_parents=True):
        assert create_parents

        path = path.strip("/") or self._path
        project, volume = self._extract_info(path)

        if not self._odps.exist_volume(volume, project=project):
            self._odps.create_parted_volume(volume, project=project)

    def stat(self, path):
        path = path.strip("/") or self._path
        if "/" not in path:
            path = "{}/{}".format(self._path, path)
        options = VolumeFileSystem.parse_from_path(path)
        fpath = options["path"]

        splits = fpath.split("/")
        if len(splits) == 3:
            is_file = False
        elif len(splits) == 4:
            is_file = True
        else:
            raise ValueError(
                "Argument `path` illegal, expect to be "
                "odps:///**project**/volumes/**volume**/ or"
                "odps:///**project**/volumes/**volume**/**partition**/, "
                "got {}".format(path)
            )

        if is_file:
            project, volume, partition = self._extract_info(path, is_file=True)
            volume_parted = self._odps.get_volume_partition(
                volume, partition=partition, project=project
            )
            stat = dict(
                name=path,
                size=volume_parted.length,
                created=volume_parted.creation_time,
            )
        else:
            project, volume = self._extract_info(path)
            vol = self._odps.get_volume(volume, project=project)
            stat = dict(name=path, size=vol.length, created=vol.creation_time)
        return stat

    def _isfilestore(self):
        return True

    def cat(self, path):
        raise NotImplementedError

    def exists(self, path):
        raise NotImplementedError

    def glob(self, path, recursive=False):
        raise NotImplementedError

    def isdir(self, path):
        raise NotImplementedError

    def isfile(self, path):
        raise NotImplementedError

    def rename(self, path, new_path):
        raise NotImplementedError

    def walk(self, path):
        raise NotImplementedError
