import binascii
import hashlib
import logging
import math
import os
from tempfile import NamedTemporaryFile
from typing import Any

from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session

from securedrop_client.api_jobs.base import SingleObjectApiJob
from securedrop_client.crypto import CryptoError, GpgHelper
from securedrop_client.db import DownloadError, DownloadErrorCodes, File, Message, Reply
from securedrop_client.sdk import API, BaseError
from securedrop_client.sdk import Reply as SdkReply
from securedrop_client.sdk import Submission as SdkSubmission
from securedrop_client.storage import (
    mark_as_decrypted,
    mark_as_downloaded,
    set_message_or_reply_content,
)
from securedrop_client.utils import safe_move

logger = logging.getLogger(__name__)


class DownloadException(Exception):
    """
    Raised when a Reply, Message or File cannot be downloaded or decrypted.

    Besides the human-readable message, the exception records which item failed:
    ``object_type`` is the database model class of the item and ``uuid`` is its UUID.
    """

    def __init__(
        self, message: str, object_type: type[Reply] | type[Message] | type[File], uuid: str
    ):
        super().__init__(message)
        # Keep enough context for a handler to identify the affected database object.
        self.uuid = uuid
        self.object_type = object_type


class DownloadChecksumMismatchException(DownloadException):
    """Raised when a downloaded file's hash differs from the checksum the server reported."""


class DownloadDecryptionException(DownloadException):
    """
    Raised when fetching or decrypting a downloaded item fails.

    Despite the name, this is not limited to decryption errors: DownloadJob._download also
    raises it (chained from the underlying error) when the download API call, the integrity
    check or the move into the data directory fails with ValueError, FileNotFoundError,
    RuntimeError or an SDK BaseError. DownloadJob._decrypt raises it when decryption fails
    with a CryptoError.
    """


class DownloadJob(SingleObjectApiJob):
    """
    Download and decrypt a file that contains either a message, reply, or file submission.

    Subclasses provide the object-specific steps by overriding ``get_db_object``,
    ``call_download_api`` and ``call_decrypt``. ``call_api`` drives the shared workflow:
    already-decrypted objects are skipped, already-downloaded objects are only decrypted, and
    everything else is downloaded, integrity-checked, moved under ``data_dir`` and decrypted.
    """

    # Number of bytes read per iteration when hashing a downloaded file.
    CHUNK_SIZE = 4096

    def __init__(self, data_dir: str, uuid: str) -> None:
        """
        data_dir is the base directory handed to ``db_object.location`` to compute where the
        downloaded file is stored; uuid identifies the database object this job operates on.
        """
        super().__init__(uuid)
        self.data_dir = data_dir

    def _get_realistic_timeout(self, size_in_bytes: int) -> int:
        """
        Return a realistic timeout in seconds based on the size of the download.

        This simply scales the timeouts per file so that it increases as the file size increases.
        The result is ``ceil(size_in_bytes / 50000 * 1.5) + 25``; for example a 1 MiB
        (1048576-byte) file gets ``ceil(31.46) + 25 = 57`` seconds.

        Note that:

        * The size of the file provided by server is in bytes (the server computes it using
          os.stat.ST_SIZE).

        * The following times are reasonable estimations for how long it should take to fetch a
          file over Tor according to Tor metrics given a recent three month period in 2022-2023:

          50 KiB  (51200 bytes)   =  6   seconds  (~ 8,533 bytes/second)
          1  MiB  (1048576 bytes) =  15  seconds  (~ 69,905 bytes/second)

          For more information, see:
          https://metrics.torproject.org/torperf.html?start=2022-12-06&end=2023-03-06&server=onion

        * As you might expect, this method returns timeouts that are larger than the expected
          download time, which is why the rates below are slower than what you see above with the
          Tor metrics, e.g. instead of setting TIMEOUT_BYTES_PER_SECOND to 69905 bytes/second, we
          set it to 50000 bytes/second.

        * Minimum timeout (for a non-negative size) is 25 seconds.
        """
        TIMEOUT_BYTES_PER_SECOND = 50_000.0
        TIMEOUT_ADJUSTMENT_FACTOR = 1.5
        TIMEOUT_BASE = 25
        timeout = math.ceil((size_in_bytes / TIMEOUT_BYTES_PER_SECOND) * TIMEOUT_ADJUSTMENT_FACTOR)
        return timeout + TIMEOUT_BASE

    def call_download_api(self, api: API, db_object: File | Message | Reply) -> tuple[str, str]:
        """
        Method for making the actual API call to download the file and handling the result.

        This MUST return the (etag, filepath) tuple response from the server and MUST raise an
        exception if and only if the download fails.
        """
        raise NotImplementedError

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Method for decrypting the file and storing the plaintext result.

        Returns the original filename.

        This MUST raise an exception if and only if the decryption fails.
        """
        raise NotImplementedError

    def get_db_object(self, session: Session) -> File | Message | Reply:
        """
        Get the database object associated with this job; may raise
        DownloadException if not found
        """
        raise NotImplementedError

    def call_api(self, api_client: API, session: Session) -> Any:
        """
        Override ApiJob.

        Download and decrypt the file associated with the database object.

        Returns the uuid of the database object in every successful path.
        """
        db_object = self.get_db_object(session)

        # Nothing left to do for an item that has already been decrypted.
        if db_object.is_decrypted:
            logger.debug(f"item with uuid {self.uuid} already decrypted, returning")
            return db_object.uuid

        # The encrypted file is already on disk; only the decryption step is needed.
        if db_object.is_downloaded:
            logger.debug(f"item with uuid {self.uuid} already downloaded, now decrypting")
            self._decrypt(db_object.location(self.data_dir), db_object, session)
            return db_object.uuid

        destination = self._download(api_client, db_object, session)
        self._decrypt(destination, db_object, session)
        return db_object.uuid

    def _download(self, api: API, db_object: File | Message | Reply, session: Session) -> str:
        """
        Download the encrypted file. Check file integrity and move it to the data directory before
        marking it as downloaded.

        Returns the path of the encrypted file inside the data directory.

        Raises DownloadChecksumMismatchException (after recording a CHECKSUM_ERROR on the object)
        when the checksum check fails, and DownloadDecryptionException chained from any
        ValueError, FileNotFoundError, RuntimeError or SDK BaseError raised along the way.

        Note: On Qubes OS, files are downloaded to /home/user/Downloads
        """
        try:
            etag, download_path = self.call_download_api(api, db_object)

            if not self._check_file_integrity(etag, download_path):
                download_error = (
                    session.query(DownloadError)
                    .filter_by(name=DownloadErrorCodes.CHECKSUM_ERROR.name)
                    .one()
                )
                db_object.download_error = download_error
                session.commit()
                # Not in the except tuple below, so it propagates to the caller unchanged.
                exception = DownloadChecksumMismatchException(
                    "Downloaded file had an invalid checksum.", type(db_object), db_object.uuid
                )
                raise exception

            destination = db_object.location(self.data_dir)
            safe_move(download_path, destination, self.data_dir)
            # Clear any error left over from a previous failed attempt.
            db_object.download_error = None
            mark_as_downloaded(type(db_object), db_object.uuid, session)
            logger.info(f"File downloaded to {destination}")
            return destination
        except (ValueError, FileNotFoundError, RuntimeError, BaseError) as e:
            # Only a generic message at error level; the exception text goes to debug.
            logger.error("Download failed")
            logger.debug(f"Download failed: {e}")
            raise DownloadDecryptionException(
                f"Failed to download {db_object.uuid}", type(db_object), db_object.uuid
            ) from e

    def _decrypt(self, filepath: str, db_object: File | Message | Reply, session: Session) -> None:
        """
        Decrypt the file located at the given filepath and mark it as decrypted.

        On a CryptoError the object is marked as not decrypted, a DECRYPTION_ERROR is recorded
        and committed, and DownloadDecryptionException is raised chained from the CryptoError.
        Other exceptions from call_decrypt propagate unchanged.
        """
        try:
            original_filename = self.call_decrypt(filepath, session)
            db_object.download_error = None
            mark_as_decrypted(
                type(db_object), db_object.uuid, session, original_filename=original_filename
            )
            logger.info(f"File decrypted to {os.path.dirname(filepath)}")
        except CryptoError as e:
            logger.error("Decryption failed")
            logger.debug(f"Decryption failed: {e}")
            mark_as_decrypted(type(db_object), db_object.uuid, session, is_decrypted=False)
            download_error = (
                session.query(DownloadError)
                .filter_by(name=DownloadErrorCodes.DECRYPTION_ERROR.name)
                .one()
            )
            db_object.download_error = download_error
            session.commit()
            raise DownloadDecryptionException(
                f"Failed to decrypt file: {os.path.basename(filepath)}",
                type(db_object),
                db_object.uuid,
            ) from e

    @classmethod
    def _check_file_integrity(cls, etag: str, file_path: str) -> bool:
        """
        Return True if file checksum is valid or unknown, otherwise return False.

        The ETag is expected in the form "<algorithm>:<hex digest>". An empty ETag or an algorithm
        other than sha256 skips the check. An ETag that does not contain exactly one ':' raises
        ValueError from the unpacking below, which _download turns into a download failure.
        """
        if not etag:
            logger.debug(f"No ETag. Skipping integrity check for file at {file_path}")
            return True

        alg, checksum = etag.split(":")

        if alg == "sha256":
            hasher = hashlib.sha256()
        else:
            logger.debug(
                f"Unknown hash algorithm ({alg}). Skipping integrity check for file at {file_path}"
            )
            return True

        # Hash in fixed-size chunks so large files are never read into memory all at once.
        with open(file_path, "rb") as f:
            while read_bytes := f.read(cls.CHUNK_SIZE):
                hasher.update(read_bytes)

        calculated_checksum = binascii.hexlify(hasher.digest()).decode("utf-8")
        return calculated_checksum == checksum


class ReplyDownloadJob(DownloadJob):
    """
    Download and decrypt a reply associated with a source.
    """

    def __init__(self, uuid: str, data_dir: str, gpg: GpgHelper) -> None:
        # Note: the argument order differs from DownloadJob.__init__(data_dir, uuid).
        super().__init__(data_dir, uuid)
        self.gpg = gpg

    def get_db_object(self, session: Session) -> Reply:
        """
        Override DownloadJob.

        Return the Reply with this job's uuid, raising DownloadException if none exists.
        """
        try:
            return session.query(Reply).filter_by(uuid=self.uuid).one()
        except NoResultFound:
            raise DownloadException("Reply not found in database", Reply, self.uuid)

    def call_download_api(self, api: API, db_object: Reply) -> tuple[str, str]:
        """
        Override DownloadJob.

        Returns the (etag, filepath) tuple from ``api.download_reply``.
        """
        sdk_object = SdkReply(uuid=db_object.uuid, filename=db_object.filename)
        sdk_object.source_uuid = db_object.source.uuid

        # TODO: Once https://github.com/freedomofpress/securedrop-sdk/issues/108 is implemented, we
        # will want to pass the default request timeout to download_reply instead of setting it on
        # the api object directly.
        # NOTE(review): this mutates the api object itself, so the 20-second value stays in effect
        # for any later request made with the same (possibly shared) client.
        api.default_request_timeout = 20
        return api.download_reply(sdk_object)

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Override DownloadJob.

        Decrypt the file located at the given filepath and store its plaintext content in the local
        database.

        The plaintext is decoded as UTF-8 rather than with the platform's locale encoding. The
        temporary plaintext file is removed when the NamedTemporaryFile context exits.

        The return value is an empty string; replies have no original filename.
        """
        with NamedTemporaryFile("w+", encoding="utf-8") as plaintext_file:
            try:
                # The decrypted output is expected at plaintext_file.name; read it back through
                # the already-open handle.
                self.gpg.decrypt_submission_or_reply(filepath, plaintext_file.name, is_doc=False)
                set_message_or_reply_content(
                    model_type=Reply, uuid=self.uuid, session=session, content=plaintext_file.read()
                )
            finally:
                # Best-effort cleanup: os.rmdir only removes an empty directory, so a failure is
                # logged and otherwise ignored.
                try:
                    os.rmdir(os.path.dirname(filepath))
                except OSError:
                    msg = f"Could not delete decryption directory: {os.path.dirname(filepath)}"
                    logger.debug(msg)

        return ""


class MessageDownloadJob(DownloadJob):
    """
    Download and decrypt a message from a source.
    """

    def __init__(self, uuid: str, data_dir: str, gpg: GpgHelper) -> None:
        # Note: the argument order differs from DownloadJob.__init__(data_dir, uuid). The base
        # constructor already makes self.uuid available, as for the sibling job classes.
        super().__init__(data_dir, uuid)
        self.gpg = gpg

    def get_db_object(self, session: Session) -> Message:
        """
        Override DownloadJob.

        Return the Message with this job's uuid, raising DownloadException if none exists.
        """
        try:
            return session.query(Message).filter_by(uuid=self.uuid).one()
        except NoResultFound:
            raise DownloadException("Message not found in database", Message, self.uuid)

    def call_download_api(self, api: API, db_object: Message) -> tuple[str, str]:
        """
        Override DownloadJob.

        Returns the (etag, filepath) tuple from ``api.download_submission``, using a timeout scaled
        to the message size.
        """
        sdk_object = SdkSubmission(uuid=db_object.uuid)
        sdk_object.source_uuid = db_object.source.uuid
        sdk_object.filename = db_object.filename
        return api.download_submission(
            sdk_object, timeout=self._get_realistic_timeout(db_object.size)
        )

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Override DownloadJob.

        Decrypt the file located at the given filepath and store its plaintext content in the local
        database.

        The plaintext is decoded as UTF-8 rather than with the platform's locale encoding. The
        temporary plaintext file is removed when the NamedTemporaryFile context exits.

        The return value is an empty string; messages have no original filename.
        """
        with NamedTemporaryFile("w+", encoding="utf-8") as plaintext_file:
            try:
                # The decrypted output is expected at plaintext_file.name; read it back through
                # the already-open handle.
                self.gpg.decrypt_submission_or_reply(filepath, plaintext_file.name, is_doc=False)
                set_message_or_reply_content(
                    model_type=Message,
                    uuid=self.uuid,
                    session=session,
                    content=plaintext_file.read(),
                )
            finally:
                # Best-effort cleanup: os.rmdir only removes an empty directory, so a failure is
                # logged and otherwise ignored.
                try:
                    os.rmdir(os.path.dirname(filepath))
                except OSError:
                    msg = f"Could not delete decryption directory: {os.path.dirname(filepath)}"
                    logger.debug(msg)

        return ""


class FileDownloadJob(DownloadJob):
    """
    Download and decrypt a file submitted by a source.
    """

    def __init__(self, uuid: str, data_dir: str, gpg: GpgHelper) -> None:
        # DownloadJob takes (data_dir, uuid), the reverse of this constructor's order.
        super().__init__(data_dir, uuid)
        self.gpg = gpg

    def get_db_object(self, session: Session) -> File:
        """
        Override DownloadJob.

        Look up the File with this job's uuid; a missing row becomes a DownloadException.
        """
        try:
            file_record = session.query(File).filter_by(uuid=self.uuid).one()
        except NoResultFound:
            raise DownloadException("File not found in database", File, self.uuid)
        return file_record

    def call_download_api(self, api: API, db_object: File) -> tuple[str, str]:
        """
        Override DownloadJob.

        Fetch the encrypted submission, allowing a timeout proportional to its size, and return
        the (etag, filepath) tuple produced by the SDK.
        """
        submission = SdkSubmission(uuid=db_object.uuid)
        submission.source_uuid = db_object.source.uuid
        submission.filename = db_object.filename
        size_based_timeout = self._get_realistic_timeout(db_object.size)
        return api.download_submission(submission, timeout=size_based_timeout)

    def call_decrypt(self, filepath: str, session: Session | None = None) -> str:
        """
        Override DownloadJob.

        Decrypt the encrypted file at ``filepath`` into a plaintext file in the same directory.
        The plaintext file is named after the download with up to two trailing extensions
        stripped, e.g. 1-impractical_thing-doc.gz.gpg -> 1-impractical_thing-doc.

        Returns the value produced by ``gpg.decrypt_submission_or_reply`` (the original filename,
        per the DownloadJob.call_decrypt contract).
        """
        directory = os.path.dirname(filepath)
        encrypted_name = os.path.basename(filepath)
        # Drop the outer extension (e.g. ".gpg"), then the next one (e.g. ".gz").
        without_outer_ext = os.path.splitext(encrypted_name)[0]
        stem = os.path.splitext(without_outer_ext)[0]
        plaintext_filepath = os.path.join(directory, stem)
        return self.gpg.decrypt_submission_or_reply(filepath, plaintext_filepath, is_doc=True)
