# 3_optimization-design-ptn/02_caching/02_semantic_caching.py
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Annotated
from uuid import uuid4
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextEmbedding
from semantic_kernel.connectors.memory.in_memory.in_memory_store import InMemoryVectorStore
from semantic_kernel.data import (
VectorizedSearchMixin,
VectorSearchOptions,
VectorStore,
VectorStoreRecordCollection,
VectorStoreRecordDataField,
VectorStoreRecordKeyField,
VectorStoreRecordVectorField,
vectorstoremodel,
)
from semantic_kernel.filters import FilterTypes, FunctionInvocationContext, PromptRenderContext
from semantic_kernel.functions import FunctionResult
COLLECTION_NAME = "llm_responses"
RECORD_ID_KEY = "cache_record_id"
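# COLLECTION_NAME is the vector-store collection that holds the cached prompt/response pairs.
# RECORD_ID_KEY is the metadata key used to mark a FunctionResult that was served from the
# cache, so the invocation filter below knows not to store it a second time.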
# Define a simple data model to store the prompt, the result, and the prompt embedding.
@vectorstoremodel
@dataclass
class CacheRecord:
prompt: Annotated[str, VectorStoreRecordDataField(embedding_property_name="prompt_embedding")]
result: Annotated[str, VectorStoreRecordDataField(is_full_text_searchable=True)]
prompt_embedding: Annotated[list[float], VectorStoreRecordVectorField(dimensions=1536)] = field(
default_factory=list
)
id: Annotated[str, VectorStoreRecordKeyField] = field(default_factory=lambda: str(uuid4()))
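# Note: @vectorstoremodel derives the record definition from the annotations above. The prompt
# field is tied to prompt_embedding via embedding_property_name, and dimensions=1536 matches the
# output size of OpenAI's text-embedding-ada-002 / text-embedding-3-small models; adjust the
# value if a different embedding model is used.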
# Define the cache filter: one method serves results from the cache at prompt-render time,
# and one stores new results in the cache after function invocation.
class PromptCacheFilter:
"""A filter to cache the results of the prompt rendering and function invocation."""
def __init__(
self,
embedding_service: EmbeddingGeneratorBase,
vector_store: VectorStore,
collection_name: str = COLLECTION_NAME,
score_threshold: float = 0.2,
):
self.embedding_service = embedding_service
self.vector_store = vector_store
self.collection: VectorStoreRecordCollection[str, CacheRecord] = vector_store.get_collection(
collection_name, data_model_type=CacheRecord
)
self.score_threshold = score_threshold
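        # The threshold is compared against cosine *distance* (1 - cosine similarity), which
        # runs from 0.0 (identical direction) to 2.0 (opposite), so the default of 0.2 only
        # treats near-duplicate prompts as cache hits.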
async def on_prompt_render(
self, context: PromptRenderContext, next: Callable[[PromptRenderContext], Awaitable[None]]
):
"""Filter to cache the rendered prompt and the result of the function.
It uses the score threshold to determine if the result should be cached.
The direction of the comparison is based on the default distance metric for
the in memory vector store, which is cosine distance, so the closer to 0 the
closer the match.
"""
await next(context)
assert context.rendered_prompt # nosec
prompt_embedding = await self.embedding_service.generate_raw_embeddings([context.rendered_prompt])
await self.collection.create_collection_if_not_exists()
assert isinstance(self.collection, VectorizedSearchMixin) # nosec
results = await self.collection.vectorized_search(
vector=prompt_embedding[0], options=VectorSearchOptions(vector_field_name="prompt_embedding", top=1)
)
async for result in results.results:
if result.score < self.score_threshold:
context.function_result = FunctionResult(
function=context.function.metadata,
value=result.record.result,
rendered_prompt=context.rendered_prompt,
metadata={RECORD_ID_KEY: result.record.id},
)
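    # Setting context.function_result during prompt rendering short-circuits the pipeline:
    # the AI service is never called and the cached value is returned instead, which is what
    # makes a cache hit fast.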
async def on_function_invocation(
self, context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
):
"""Filter to store the result in the cache if it is new."""
await next(context)
result = context.result
if result and result.rendered_prompt and RECORD_ID_KEY not in result.metadata:
prompt_embedding = await self.embedding_service.generate_embeddings([result.rendered_prompt])
cache_record = CacheRecord(
prompt=result.rendered_prompt,
result=str(result),
prompt_embedding=prompt_embedding[0],
)
await self.collection.create_collection_if_not_exists()
await self.collection.upsert(cache_record)
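    # Results that carry RECORD_ID_KEY in their metadata were produced by on_prompt_render
    # above (i.e. served from the cache), so they are skipped here to avoid storing duplicates.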
async def execute_async(kernel: Kernel, title: str, prompt: str):
"""Helper method to execute and log time."""
print(f"{title}: {prompt}")
start = time.time()
result = await kernel.invoke_prompt(prompt)
elapsed = time.time() - start
print(f"\tElapsed Time: {elapsed:.3f}")
return result
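# The elapsed time is the simplest signal of a cache hit: a cached answer comes back in
# milliseconds, while a real model round trip typically takes on the order of seconds.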
async def main():
# create the kernel and add the chat service and the embedding service
kernel = Kernel()
    chat = OpenAIChatCompletion(service_id="default")
    embedding = OpenAITextEmbedding(service_id="embedder")
kernel.add_service(chat)
kernel.add_service(embedding)
# create the in-memory vector store
vector_store = InMemoryVectorStore()
# create the cache filter and add the filters to the kernel
cache = PromptCacheFilter(embedding_service=embedding, vector_store=vector_store)
kernel.add_filter(FilterTypes.PROMPT_RENDERING, cache.on_prompt_render)
kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, cache.on_function_invocation)
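    # Both filters share the same PromptCacheFilter instance: the prompt-render filter checks
    # the cache before the model is called, and the function-invocation filter stores any
    # genuinely new result afterwards.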
# Run the sample
print("\nIn-memory cache sample:")
r1 = await execute_async(kernel, "First run", "What's the tallest building in New York?")
print(f"\tResult 1: {r1}")
r2 = await execute_async(kernel, "Second run", "How are you today?")
print(f"\tResult 2: {r2}")
r3 = await execute_async(kernel, "Third run", "What is the highest building in New York City?")
print(f"\tResult 3: {r3}")
if __name__ == "__main__":
asyncio.run(main())
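# A stricter cache can be configured by lowering the threshold; the values below are only
# illustrative:
#
#     strict_cache = PromptCacheFilter(
#         embedding_service=embedding,
#         vector_store=vector_store,
#         score_threshold=0.05,  # only near-identical prompts count as hits
#     )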