# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions related to attribute generation."""
import json
import logging
from typing import Union, Optional

import config
import embeddings
import nearest_neighbors
import utils

# Shared clients, created once at import time: the BigQuery client is used
# for reference-table lookups and the LLM handle for attribute generation.
bq_client, llm = utils.get_bq_client(), utils.get_llm()

def join_attributes_desc(
    ids: list[str]) -> dict[str, dict]:
    """Gets the attributes and description for given product IDs.

    Args:
        ids: The product IDs to get the attributes for.

    Returns:
        dict mapping product IDs to attributes and descriptions. Each ID found
        in the reference table maps to a dict with the following keys:
            attributes: e.g. {'color': 'green', 'pattern': 'striped'}
            description: e.g. 'This is a description'
        IDs with no matching row are absent from the result. An empty `ids`
        list returns an empty dict without querying BigQuery.
    """
    if not ids:
        # `IN ()` is a syntax error in BigQuery, and there is nothing to look up.
        return {}
    # De-duplicate (preserving order) so the IN-list stays small.
    unique_ids = list(dict.fromkeys(ids))
    # json.dumps renders each string as a double-quoted, backslash-escaped
    # literal, which is also a valid BigQuery string literal. Unlike the
    # previous str(list).replace('[', '(') trick, this neither corrupts IDs
    # containing '[' / ']' nor lets quotes inside an ID break the query.
    id_list = ', '.join(json.dumps(i, ensure_ascii=False) for i in unique_ids)
    query = f"""
    SELECT
        {config.COLUMN_ID},
        {config.COLUMN_ATTRIBUTES},
        {config.COLUMN_DESCRIPTION}
    FROM
        `{config.PRODUCT_REFERENCE_TABLE}`
    WHERE
        {config.COLUMN_ID} IN ({id_list})
    """
    rows = bq_client.query(query).result()
    attributes = {}
    for row in rows:
        # The attributes column is stored as a JSON string.
        attributes[row[config.COLUMN_ID]] = {
            'attributes': json.loads(row[config.COLUMN_ATTRIBUTES]),
            'description': row[config.COLUMN_DESCRIPTION],
        }
    return attributes

def retrieve(
    desc: str,
    category: Optional[str] = None,
    image: Optional[str] = None,
    base64: bool = False,
    num_neighbors: int = config.NUM_NEIGHBORS,
    filters: Optional[list[str]] = None) -> list[dict]:
    """Returns list of attributes based on nearest neighbors.

    Embeds the provided desc and (optionally) image and returns the attributes
    corresponding to the closest products in embedding space.

    Args:
        desc: user provided description of product
        category: category of the product. NOTE(review): currently unused by
          this function.
        image: can be local file path, GCS URI or base64 encoded image
        base64: True indicates image is base64. False (default) will be
          interpreted as image path (either local or GCS)
        num_neighbors: number of nearest neighbors to return for EACH
          embedding. NOTE(review): currently not forwarded to
          nearest_neighbors.get_nn.
        filters: category prefix to restrict results to. None (default) means
          no restriction and is passed to get_nn as an empty list.

    Returns:
        List of candidates sorted by embedding distance. Each candidate is a
        dict with the following keys:
            id: the raw neighbor ID returned by get_nn (the product ID plus
              a 2-character suffix)
            attributes: attributes in dict form e.g. {'color':'green', 'pattern': 'striped'}
            description: string describing product
            distance: embedding distance in range [0,1], 0 being the closest match
        Neighbors with no matching row in the reference table are skipped.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if filters is None:
        filters = []
    res = embeddings.embed(desc, image, base64)
    embeds = [res.text_embedding]
    if res.image_embedding:
        embeds.append(res.image_embedding)
    neighbors = nearest_neighbors.get_nn(embeds, filters)
    if not neighbors:
        return []
    # The last 2 chars of a neighbor ID are not part of the product ID.
    # NOTE(review): confirm the suffix length against the index's ID format.
    product_ids = [n.id[:-2] for n in neighbors]
    attributes_desc = join_attributes_desc(product_ids)
    candidates = []
    for n in neighbors:
        info = attributes_desc.get(n.id[:-2])
        if info is None:
            # Index and reference table are out of sync; skip rather than
            # failing the whole request with a KeyError.
            logging.warning('No reference row for neighbor %s; skipping', n.id)
            continue
        candidates.append({
            'attributes': info['attributes'],
            'description': info['description'],
            'id': n.id,
            'distance': n.distance,
        })
    return sorted(candidates, key=lambda d: d['distance'])

def generate_prompt(desc: str, candidates: list[dict]) -> str:
    """Populate LLM prompt template.

    Args:
        desc: product description
        candidates: list of dicts with the following keys:
            attributes: attributes in dict form e.g. {'color':'green', 'pattern': 'striped'}
            description: string describing product

    Returns: prompt to feed to LLM. Each candidate is rendered as a few-shot
        example ("Description: ..." then "Attributes:" with '|'-joined
        key:value pairs), followed by the instructions and `desc`.
    """
    example_blocks = []
    for candidate in candidates:
        # f-string formatting accepts non-string values (e.g. numbers or
        # booleans decoded from JSON), which `k + ':' + v` rejected.
        attrs = '|'.join(f'{k}:{v}' for k, v in candidate['attributes'].items())
        example_blocks.append(
            f"Description: {candidate['description']}\n"
            f"Attributes:\n{attrs}\n\n")
    examples = ''.join(example_blocks)

    prompt = f"""
Here are examples of Product Descriptions followed by Attributes:

{examples}

INSTRUCTIONS:
Generate attributes based on the description below.
Each attribute should be a key:value pair.
Do not write any values that contain "NA" on the list. Examples "Material: NA" or "Type: NA"
Use a pipe separator "|" to separate attributes.

Description: {desc}
Attributes:
    """
    return prompt

def parse_answer(ans: str) -> dict[str,str]:
    """Translate LLM response into dict.

    Args:
        ans: '|' separated key value pairs e.g. 'color:red|size:large'.
          Only the first ':' in each pair separates key from value, so
          values may themselves contain ':' (e.g. 'time:10:30'). Blank
          segments (e.g. from a trailing '|') are ignored.
    Returns:
        ans as a dictionary, with keys and values whitespace-stripped
    Raises:
        ValueError: if a non-blank segment has no ':' or no pairs are found.
    """
    d = {}
    for segment in ans.split('|'):
        if not segment.strip():
            # Tolerate stray separators such as a trailing '|'.
            continue
        key, sep, value = segment.partition(':')
        if not sep:
            raise ValueError(f'Attribute {segment!r} is not a key:value pair')
        d[key.strip()] = value.strip()
    if not d:
        raise ValueError(f'No key:value pairs found in {ans!r}')
    return d

def generate_attributes(
    desc: str,
    candidates: list[dict]
) -> dict[str,str]:
    """Use an LLM to determine attributes given nearest neighbor candidates

    Args:
        desc: product description
        candidates: list of dicts with the following keys:
            attributes: attributes in dict form e.g. {'color':'green', 'pattern': 'striped'}
            description: string describing product

    Returns: attributes in dict form e.g. {'color':'green', 'pattern': 'striped'}

    Raises:
        ValueError: if the LLM returns an empty response, or a response that
            parse_answer cannot parse (the parse error is chained as the cause).
    """
    prompt = generate_prompt(desc, candidates)
    # Deterministic decoding; the output is a short '|'-separated list.
    llm_parameters = {
        "max_output_tokens": 256,
        "temperature": 0.0,
    }
    response = llm.predict(
        prompt,
        **llm_parameters
    )
    res = response.text
    if not res:
        raise ValueError('ERROR: No LLM response returned. This seems to be an intermittent bug')
    try:
        formatted_res = parse_answer(res)
    except Exception as e:
        logging.error(e)
        # Chain the cause so the original parse failure stays in the traceback.
        raise ValueError(f'LLM Response: {res} is not in the expected format') from e
    return formatted_res

def retrieve_and_generate_attributes(
    desc: str,
    category: Optional[str] = None,
    image: Optional[str] = None,
    base64: bool = False,
    num_neighbors: int = config.NUM_NEIGHBORS,
    filters: Optional[list[str]] = None
) -> dict[str,str]:
    """RAG approach to generating product attributes.

    Since LLM answers are not always well formatted, if we fail to parse the
    LLM answer we fallback to a greedy retrieval approach (the attributes of
    the closest candidate).

    Args:
        desc: user provided description of product
        category: category of the product
        image: can be local file path, GCS URI or base64 encoded image
        base64: True indicates image is base64. False (default) will be
          interpreted as image path (either local or GCS)
        num_neighbors: number of nearest neighbors to return for EACH embedding
        filters: category prefix to restrict results to. None (default) means
          no restriction.

    Returns: attributes in dict form e.g. {'color':'green', 'pattern': 'striped'},
        or {'error': ...} when filters are given but no product matches them.

    Raises:
        ValueError: if the LLM answer is unusable and there is no retrieved
            candidate to fall back on.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if filters is None:
        filters = []
    candidates = retrieve(desc, category, image, base64, num_neighbors, filters)
    if filters and not candidates:
        return {'error':'ERROR: no existing products match that category'}
    try:
        return generate_attributes(desc, candidates)
    except ValueError as e:
        logging.error(e)
        if not candidates:
            # Nothing to fall back on; surface the LLM error rather than an
            # IndexError from candidates[0].
            raise
        logging.error('Falling back to greedy approach')
        return candidates[0]['attributes']

    