in experiments/legacy/backend/category.py [0:0]
def retrieve(
        desc: str,
        image: Optional[str] = None,
        base64: bool = False,
        num_neighbors: int = config.NUM_NEIGHBORS,
        filters: Optional[list[str]] = None) -> list[dict]:
    """Returns list of categories based on nearest neighbors.

    This is a 'greedy' retrieval approach that embeds the provided desc and
    (optionally) image and returns the categories corresponding to the closest
    products in embedding space.

    Args:
        desc: user provided description of product
        image: can be local file path, GCS URI or base64 encoded image
        base64: True indicates image is base64. False (default) will be
            interpreted as image path (either local or GCS)
        num_neighbors: number of nearest neighbors to return for EACH
            embedding. NOTE(review): this value is currently not forwarded
            to nearest_neighbors.get_nn, so it has no effect in this
            function; confirm get_nn's signature and pass it through.
        filters: list of category prefixes to restrict results to. None
            (default) is treated as an empty list.

    Returns:
        List of candidates sorted by ascending embedding distance (empty if
        no neighbors are found). Each candidate is a dict with the following
        keys:
            category: category in list form e.g. ['level 1 category', 'level 2 category']
            id: ID of the matched datapoint as returned by
                nearest_neighbors.get_nn; its last 2 characters are a suffix
                that is not part of the product ID (id[:-2] is the product ID)
            distance: embedding distance in range [0,1], 0 being the closest match
    """
    # A None default avoids one mutable list being shared across calls;
    # get_nn still receives an empty list when no filters are given.
    if filters is None:
        filters = []

    res = embeddings.embed(desc, image, base64)

    # Always query with the text embedding; add the image embedding only
    # when embed() produced a truthy one.
    embeds = [res.text_embedding]
    if res.image_embedding:
        embeds.append(res.image_embedding)

    neighbors = nearest_neighbors.get_nn(embeds, filters)
    if not neighbors:
        return []

    # Datapoint IDs carry a 2-character suffix that is not part of the
    # product ID (presumably distinguishing text vs. image datapoints --
    # TODO confirm). Compute the product IDs once so the category lookup
    # below uses exactly the same keys that were passed to join_categories.
    product_ids = [n.id[:-2] for n in neighbors]
    categories = join_categories(product_ids)

    candidates = [
        {'category': categories[pid], 'id': n.id, 'distance': n.distance}
        for n, pid in zip(neighbors, product_ids)
    ]
    return sorted(candidates, key=lambda d: d['distance'])