def retrieve()

in experiments/legacy/backend/category.py [0:0]


def retrieve(
    desc: str,
    image: Optional[str] = None,
    base64: bool = False,
    num_neighbors: int = config.NUM_NEIGHBORS,
    filters: Optional[list[str]] = None) -> list[dict]:
    """Returns list of categories based on nearest neighbors.

    This is a 'greedy' retrieval approach that embeds the provided desc and
    (optionally) image and returns the categories corresponding to the closest
    products in embedding space.

    Args:
        desc: user provided description of product
        image: can be local file path, GCS URI or base64 encoded image
        base64: True indicates image is base64. False (default) will be
          interpreted as image path (either local or GCS)
        num_neighbors: number of nearest neighbors to return for EACH embedding.
          NOTE(review): this value is currently not forwarded to
          nearest_neighbors.get_nn, so it has no effect here -- confirm
          whether get_nn should receive it.
        filters: category prefix to restrict results to. None (default) is
          treated as an empty list, i.e. no filters are passed on.

    Returns:
        List of candidates sorted by embedding distance (ascending). Empty
        list if the nearest neighbor lookup returns nothing. Each candidate
        is a dict with the following keys:
            id: neighbor ID as returned by the nearest neighbor lookup, i.e.
              including the 2-character suffix that is stripped before the
              category lookup
            category: category in list form e.g. ['level 1 category', 'level 2 category']
            distance: embedding distance in range [0,1], 0 being the closest match
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    if filters is None:
        filters = []

    res = embeddings.embed(desc, image, base64)
    # Always query with the text embedding; add the image embedding only
    # when one was produced.
    embeds = [res.text_embedding]
    if res.image_embedding:
        embeds.append(res.image_embedding)

    neighbors = nearest_neighbors.get_nn(embeds, filters)
    if not neighbors:
        return []

    # The last 2 chars of a neighbor ID are not part of the product ID, so
    # they are stripped before looking up categories.
    product_ids = [n.id[:-2] for n in neighbors]
    categories = join_categories(product_ids)
    candidates = [
        {'category': categories[pid], 'id': n.id, 'distance': n.distance}
        for n, pid in zip(neighbors, product_ids)
    ]
    return sorted(candidates, key=lambda d: d['distance'])