# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Dict, List, Optional, TypeVar

import google.auth
import google.auth.transport.requests
import redis
import requests
from google.cloud.aiplatform.matching_engine import matching_engine_index_endpoint
from services.multimodal_embedding_client import MultimodalEmbeddingPredictionClient

import storage_helper
import tracer_helper
from services.match_service import (
    CodeInfo,
    Item,
    MatchResult,
    VertexAIMatchingEngineMatchService,
)

# Module-level tracer; the service methods in this file are wrapped with its
# start_as_current_span decorator.
tracer = tracer_helper.get_tracer(__name__)

# Blob name passed to storage_helper.upload_blob for every query image uploaded
# by convert_image_to_embeddings. The name is fixed, so presumably each upload
# replaces the previous one -- TODO confirm against storage_helper.
DESTINATION_BLOB_NAME = "multimodal_text_to_image"


def get_access_token() -> str:
    """Return an access token for the application default credentials.

    Raises:
        RuntimeError: If the refreshed credentials carry no token.
    """
    credentials, _ = google.auth.default()

    # Freshly loaded default credentials have no token yet (valid is False,
    # token is None); refreshing them is what populates the token.
    credentials.refresh(google.auth.transport.requests.Request())

    token = credentials.token
    if not token:
        raise RuntimeError("No access token found")

    return token


# Type of the item a concrete service returns from get_by_id; the subclasses in
# this file bind it to Dict[str, str] (Mercari) and str (Rooms).
T = TypeVar("T")


class MultimodalTextToImageMatchService(VertexAIMatchingEngineMatchService[T]):
    """Matching Engine service whose queries are embedded by the multimodal client.

    Text and image queries are converted to embeddings through
    ``MultimodalEmbeddingPredictionClient``. Subclasses supply the item lookup
    (``get_by_id``) and the conversion of neighbors into match results.
    """

    def __init__(
        self,
        id: str,
        name: str,
        description: str,
        allows_text_input: bool,
        allows_image_input: bool,
        index_endpoint_name: str,
        deployed_index_id: str,
        project_id: str,
        gcs_bucket: str,
        is_public_index_endpoint: bool,
        prompts_texts_file: Optional[str] = None,
        prompt_images_file: Optional[str] = None,
        code_info: Optional[CodeInfo] = None,
    ) -> None:
        """Store the configuration and create the endpoint and embedding client.

        Args:
            id: Identifier exposed through the ``id`` property.
            name: Name shown on the frontend.
            description: Description shown on the frontend.
            allows_text_input: Whether text queries (and text suggestions) are
                offered.
            allows_image_input: Whether image queries (and image suggestions)
                are offered.
            index_endpoint_name: Name passed to ``MatchingEngineIndexEndpoint``.
            deployed_index_id: Stored on the instance; presumably used by the
                base class when querying the endpoint -- TODO confirm.
            project_id: Project passed to the embedding client.
            gcs_bucket: Bucket that uploaded query images are written to.
            is_public_index_endpoint: Stored on the instance; presumably read
                by the base class -- TODO confirm.
            prompts_texts_file: Optional path to a file with one text prompt
                per line.
            prompt_images_file: Optional path to a file with one image URL per
                line.
            code_info: Optional info about the code used to generate the index.
        """
        self._id = id
        self._name = name
        self._description = description
        self._code_info = code_info
        self.project_id = project_id
        self._allows_text_input = allows_text_input
        self._allows_image_input = allows_image_input
        self.gcs_bucket = gcs_bucket

        self.prompt_texts = self._read_prompt_file(prompts_texts_file)
        self.prompt_images = self._read_prompt_file(prompt_images_file)

        self.index_endpoint = (
            matching_engine_index_endpoint.MatchingEngineIndexEndpoint(
                index_endpoint_name=index_endpoint_name
            )
        )
        self.deployed_index_id = deployed_index_id
        self.client = MultimodalEmbeddingPredictionClient(project_id=self.project_id)
        self.is_public_index_endpoint = is_public_index_endpoint

    @staticmethod
    def _read_prompt_file(path: Optional[str]) -> List[str]:
        """Return the stripped, non-empty lines of *path* ([] when no path is given)."""
        if not path:
            return []

        # Explicit encoding keeps decoding independent of the host locale.
        with open(path, "r", encoding="utf-8") as f:
            stripped_lines = (line.strip() for line in f)
            # Blank lines would otherwise surface as empty suggestions.
            return [prompt for prompt in stripped_lines if prompt]

    @property
    def id(self) -> str:
        """Identifier of this service."""
        return self._id

    @property
    def name(self) -> str:
        """Name for this service that is shown on the frontend."""
        return self._name

    @property
    def description(self) -> str:
        """Description for this service that is shown on the frontend."""
        return self._description

    @property
    def allows_text_input(self) -> bool:
        """If true, this service allows text input."""
        return self._allows_text_input

    @property
    def allows_image_input(self) -> bool:
        """If true, this service allows image input."""
        return self._allows_image_input

    @property
    def code_info(self) -> Optional[CodeInfo]:
        """Info about code used to generate index."""
        return self._code_info

    @tracer.start_as_current_span("get_suggestions")
    def get_suggestions(self, num_items: int = 60) -> List[Item]:
        """Return a random sample of at most ``num_items`` suggested queries.

        Text prompts are included only when text input is allowed, image
        prompts only when image input is allowed. Sampling is without
        replacement, so no prompt appears twice in one result.
        """
        text_prompts = (
            [Item(id=word, text=word, image=None) for word in self.prompt_texts]
            if self.allows_text_input
            else []
        )
        image_prompts = (
            [
                Item(id=image_url, text="", image=image_url)
                for image_url in self.prompt_images
            ]
            if self.allows_image_input
            else []
        )

        prompts = text_prompts + image_prompts

        # Cap the sample size: random.sample raises if asked for more items
        # than the population holds.
        return random.sample(
            prompts,
            min(num_items, len(prompts)),
        )

    def encode_image_to_embeddings(self, image_uri: str) -> List[float]:
        """Return the image embedding the client produces for ``image_uri``.

        Raises:
            RuntimeError: If the embedding client fails; the original error is
                chained as ``__cause__``.
        """
        try:
            return self.client.get_embedding(
                text=None, image_file=image_uri
            ).image_embedding
        except Exception as ex:
            raise RuntimeError("Error getting embedding.") from ex

    def encode_text_to_embeddings(self, text: str) -> List[float]:
        """Return the text embedding the client produces for ``text``.

        Raises:
            RuntimeError: If the embedding client fails; the original error is
                chained as ``__cause__``.
        """
        try:
            return self.client.get_embedding(text=text, image_file=None).text_embedding
        except Exception as ex:
            raise RuntimeError("Error getting embedding.") from ex

    @tracer.start_as_current_span("convert_text_to_embeddings")
    def convert_text_to_embeddings(self, target: str) -> Optional[List[float]]:
        """Convert a text query to an embedding representation."""
        return self.encode_text_to_embeddings(text=target)

    @tracer.start_as_current_span("convert_image_to_embeddings")
    def convert_image_to_embeddings(
        self, image_file_local_path: str
    ) -> Optional[List[float]]:
        """Upload a local image to the GCS bucket and return its embedding."""
        # Upload image file
        image_uri = storage_helper.upload_blob(
            source_file_name=image_file_local_path,
            bucket_name=self.gcs_bucket,
            destination_blob_name=DESTINATION_BLOB_NAME,
        )

        # Convert GCS path to HTTP path by dropping the leading scheme.
        # NOTE(review): assumes upload_blob returns a "gs://..." URI -- confirm
        # against storage_helper.
        gcs_path = image_uri[len("gs://") :]
        image_uri_http = f"https://storage.googleapis.com/{gcs_path}"

        return self.encode_image_to_embeddings(image_uri=image_uri_http)

    @tracer.start_as_current_span("convert_image_to_embeddings_remote")
    def convert_image_to_embeddings_remote(
        self, image_file_remote_path: str
    ) -> Optional[List[float]]:
        """Return the embedding of an image that is already remotely reachable."""
        return self.encode_image_to_embeddings(
            image_uri=image_file_remote_path,
        )


class MercariTextToImageMatchService(MultimodalTextToImageMatchService[Dict[str, str]]):
    """Multimodal match service whose item details are stored as Redis hashes."""

    def __init__(
        self,
        id: str,
        name: str,
        description: str,
        allows_text_input: bool,
        allows_image_input: bool,
        index_endpoint_name: str,
        deployed_index_id: str,
        project_id: str,
        redis_host: str,  # Redis host to get data about a match id
        redis_port: int,  # Redis port to get data about a match id
        gcs_bucket: str,
        is_public_index_endpoint: bool,
        prompts_texts_file: Optional[str] = None,
        prompt_images_file: Optional[str] = None,
        code_info: Optional[CodeInfo] = None,
    ) -> None:
        """Initialise the base service and create the Redis client.

        ``redis_host`` and ``redis_port`` locate the Redis instance queried by
        ``get_by_id``; every other argument is forwarded unchanged to the base
        class.
        """
        super().__init__(
            id=id,
            name=name,
            description=description,
            code_info=code_info,
            project_id=project_id,
            allows_text_input=allows_text_input,
            allows_image_input=allows_image_input,
            gcs_bucket=gcs_bucket,
            prompts_texts_file=prompts_texts_file,
            prompt_images_file=prompt_images_file,
            index_endpoint_name=index_endpoint_name,
            deployed_index_id=deployed_index_id,
            is_public_index_endpoint=is_public_index_endpoint,
        )
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port)

    @tracer.start_as_current_span("get_by_id")
    def get_by_id(self, id: str) -> Optional[Dict[str, str]]:
        """Get an item by id.

        Returns:
            The item's hash fields as a ``str`` to ``str`` mapping, or ``None``
            when Redis holds no hash under ``id``.
        """
        retrieved = self.redis_client.hgetall(str(id))

        # HGETALL answers with an empty mapping (never None) for a missing
        # key, so emptiness is the "not found" signal.
        if not retrieved:
            return None

        # Convert the byte strings to regular strings
        return {key.decode(): value.decode() for key, value in retrieved.items()}

    @tracer.start_as_current_span("convert_match_neighbors_to_result")
    def convert_match_neighbors_to_result(
        self, matches: List[matching_engine_index_endpoint.MatchNeighbor]
    ) -> List[Optional[MatchResult]]:
        """Convert neighbors to match results, positionally aligned with ``matches``.

        Each neighbor id is looked up with ``get_by_id``; a neighbor whose id
        is unknown produces ``None`` at its position. The reported distance is
        ``1 - match.distance``, floored at 0.
        """
        items = [self.get_by_id(match.id) for match in matches]
        return [
            MatchResult(
                title=item["name"],
                description=item["description"],
                distance=max(0, 1 - match.distance),
                url=item["url"],
                image=item["img_url"],
            )
            if item is not None
            else None
            for item, match in zip(items, matches)
        ]


class RoomsTextToImageMatchService(MultimodalTextToImageMatchService[str]):
    """Multimodal match service whose items are identified by the match id itself."""

    @tracer.start_as_current_span("get_by_id")
    def get_by_id(self, id: str) -> Optional[str]:
        """Return ``id`` unchanged; the id is the item."""
        return id

    @tracer.start_as_current_span("convert_match_neighbors_to_result")
    def convert_match_neighbors_to_result(
        self, matches: List[matching_engine_index_endpoint.MatchNeighbor]
    ) -> List[Optional[MatchResult]]:
        """Build one result per neighbor, positionally aligned with ``matches``.

        Results carry only an image URL and a distance of ``1 - match.distance``
        floored at 0. A neighbor whose lookup yields ``None`` maps to ``None``.
        """
        # Resolve every id first, then assemble the results in a second pass.
        looked_up = [self.get_by_id(neighbor.id) for neighbor in matches]

        results: List[Optional[MatchResult]] = []
        for image_name, neighbor in zip(looked_up, matches):
            if image_name is None:
                results.append(None)
                continue

            results.append(
                MatchResult(
                    title=None,
                    description=None,
                    distance=max(0, 1 - neighbor.distance),
                    url=None,
                    image=f"https://storage.googleapis.com/ai-demos-us-central1/interior_images/mit_indoor/{image_name}",
                )
            )

        return results
