experiments/legacy/backend/attributes.py (92 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions related to attribute generation."""

import json
import logging
from typing import Union, Optional

import config
import embeddings
import nearest_neighbors
import utils

bq_client = utils.get_bq_client()
llm = utils.get_llm()


def join_attributes_desc(ids: list[str]) -> dict[str, dict]:
    """Gets the attributes and description for given product IDs.

    Args:
        ids: The product IDs to get the attributes for.

    Returns:
        Dict mapping product IDs to attributes and descriptions. Each ID
        maps to a dict with the following keys:
            attributes: e.g. {'color': 'green', 'pattern': 'striped'}
            description: e.g. 'This is a description'
    """
    if not ids:
        # An empty IN-list would be invalid SQL; short-circuit instead.
        return {}
    # NOTE(review): interpolating IDs directly into SQL is injection-prone.
    # IDs originate from the vector index here, but switching to BigQuery
    # query parameters (ArrayQueryParameter) would be safer — TODO.
    in_list = ', '.join(f"'{pid}'" for pid in ids)
    query = f"""
        SELECT
            {config.COLUMN_ID},
            {config.COLUMN_ATTRIBUTES},
            {config.COLUMN_DESCRIPTION}
        FROM `{config.PRODUCT_REFERENCE_TABLE}`
        WHERE {config.COLUMN_ID} IN ({in_list})
    """
    rows = bq_client.query(query).result()
    attributes = {}
    for row in rows:
        attributes[row[config.COLUMN_ID]] = {
            # Attributes are stored as a JSON string in BigQuery.
            'attributes': json.loads(row[config.COLUMN_ATTRIBUTES]),
            'description': row[config.COLUMN_DESCRIPTION],
        }
    return attributes


def retrieve(
        desc: str,
        category: Optional[str] = None,
        image: Optional[str] = None,
        base64: bool = False,
        num_neighbors: int = config.NUM_NEIGHBORS,
        filters: Optional[list[str]] = None) -> list[dict]:
    """Returns list of attributes based on nearest neighbors.

    Embeds the provided desc and (optionally) image and returns the
    attributes corresponding to the closest products in embedding space.

    Args:
        desc: user provided description of product
        category: category of the product
        image: can be local file path, GCS URI or base64 encoded image
        base64: True indicates image is base64. False (default) will be
            interpreted as image path (either local or GCS)
        num_neighbors: number of nearest neighbors to return for EACH
            embedding
        filters: category prefix to restrict results to

    Returns:
        List of candidates sorted by embedding distance. Each candidate is
        a dict with the following keys:
            id: product ID
            attributes: attributes in dict form e.g.
                {'color': 'green', 'pattern': 'striped'}
            description: string describing product
            distance: embedding distance in range [0,1], 0 being the
                closest match
    """
    # Avoid a mutable default argument; normalize None to an empty list.
    filters = filters if filters is not None else []
    # TODO(review): `category` and `num_neighbors` are accepted but never
    # used; confirm whether nearest_neighbors.get_nn should receive them.
    res = embeddings.embed(desc, image, base64)
    embeds = ([res.text_embedding, res.image_embedding]
              if res.image_embedding else [res.text_embedding])
    neighbors = nearest_neighbors.get_nn(embeds, filters)
    if not neighbors:
        return []
    # Neighbor IDs carry a suffix that is not part of the product ID.
    # NOTE(review): the original comment said "last 3 chars" but the code
    # strips 2 (`[:-2]`); preserving the code's behavior — confirm which
    # is correct against the index's ID scheme.
    ids = [n.id[:-2] for n in neighbors]
    attributes_desc = join_attributes_desc(ids)
    candidates = [
        {
            'attributes': attributes_desc[n.id[:-2]]['attributes'],
            'description': attributes_desc[n.id[:-2]]['description'],
            'id': n.id,
            'distance': n.distance,
        }
        # Skip neighbors with no row in the reference table instead of
        # raising KeyError.
        for n in neighbors if n.id[:-2] in attributes_desc
    ]
    return sorted(candidates, key=lambda d: d['distance'])


def generate_prompt(desc: str, candidates: list[dict]) -> str:
    """Populate LLM prompt template.

    Args:
        desc: product description
        candidates: list of dicts with the following keys:
            attributes: attributes in dict form e.g.
                {'color': 'green', 'pattern': 'striped'}
            description: string describing product

    Returns:
        prompt to feed to LLM
    """
    examples = ''
    for candidate in candidates:
        attrs = '|'.join(
            f'{k}:{v}' for k, v in candidate['attributes'].items())
        examples += 'Description: ' + candidate['description'] + '\n'
        examples += 'Attributes:\n' + attrs + '\n\n'
    prompt = f"""
Here are examples of Product Descriptions followed by Attributes:

{examples}

INSTRUCTIONS:
Generate attributes based on the description below.
Each attribute should be a key:value pair.
Do not write any values that contain "NA" on the list. Examples "Material: NA" or "Type: NA"
Use a pipe separator "|" to separate attributes.

Description: {desc}
Attributes:
"""
    return prompt


def parse_answer(ans: str) -> dict[str, str]:
    """Translate LLM response into dict.

    Args:
        ans: '|' separated key value pairs e.g. 'color:red|size:large'

    Returns:
        ans as a dictionary

    Raises:
        ValueError: if a non-empty segment contains no ':' separator.
    """
    parsed = {}
    for pair in ans.split('|'):
        if not pair.strip():
            # Tolerate empty segments, e.g. a trailing '|'.
            continue
        # Split only on the first ':' so values may themselves contain
        # colons (e.g. 'ratio:16:9').
        key, value = pair.split(':', 1)
        parsed[key.strip()] = value.strip()
    return parsed


def generate_attributes(
        desc: str,
        candidates: list[dict]) -> dict[str, str]:
    """Use an LLM to determine attributes given nearest neighbor candidates.

    Args:
        desc: product description
        candidates: list of dicts with the following keys:
            attributes: attributes in dict form e.g.
                {'color': 'green', 'pattern': 'striped'}
            description: string describing product

    Returns:
        attributes in dict form e.g. {'color': 'green', 'pattern': 'striped'}

    Raises:
        ValueError: if the LLM returns no text or the text cannot be parsed.
    """
    prompt = generate_prompt(desc, candidates)
    llm_parameters = {
        'max_output_tokens': 256,
        'temperature': 0.0,
    }
    response = llm.predict(prompt, **llm_parameters)
    res = response.text
    if not res:
        raise ValueError('ERROR: No LLM response returned. '
                         'This seems to be an intermittent bug')
    try:
        return parse_answer(res)
    except Exception as e:
        logging.error(e)
        # Chain the cause so the original parse failure is preserved.
        raise ValueError(
            f'LLM Response: {res} is not in the expected format') from e


def retrieve_and_generate_attributes(
        desc: str,
        category: Optional[str] = None,
        image: Optional[str] = None,
        base64: bool = False,
        num_neighbors: int = config.NUM_NEIGHBORS,
        filters: Optional[list[str]] = None) -> dict[str, str]:
    """RAG approach to generating product attributes.

    Since LLM answers are not always well formatted, if we fail to parse
    the LLM answer we fallback to a greedy retrieval approach.

    Args:
        desc: user provided description of product
        category: category of the product
        image: can be local file path, GCS URI or base64 encoded image
        base64: True indicates image is base64. False (default) will be
            interpreted as image path (either local or GCS)
        num_neighbors: number of nearest neighbors to return for EACH
            embedding
        filters: category prefix to restrict results to

    Returns:
        attributes in dict form e.g. {'color': 'green', 'pattern': 'striped'}
    """
    candidates = retrieve(desc, category, image, base64, num_neighbors,
                          filters)
    if filters and not candidates:
        return {'error': 'ERROR: no existing products match that category'}
    try:
        return generate_attributes(desc, candidates)
    except ValueError as e:
        logging.error(e)
        logging.error('Falling back to greedy approach')
        if not candidates:
            # Previously candidates[0] raised IndexError here; surface a
            # clear error instead.
            return {'error': 'ERROR: no candidates available for fallback'}
        return candidates[0]['attributes']