google/generativeai/caching.py (181 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import datetime
import textwrap
from typing import Iterable, Optional
from google.generativeai import protos
from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai.client import get_default_cache_client
from google.protobuf import field_mask_pb2
# Role written onto the last entry of `contents` in a create request when that
# entry has no role set (see `CachedContent._prepare_create_request`).
_USER_ROLE = "user"
# Presumably the role label for model-authored content; it is not referenced in
# the visible part of this module -- TODO confirm whether it is still needed.
_MODEL_ROLE = "model"
class CachedContent:
    """Cached content resource.

    Wraps a `protos.CachedContent` message, stored in `self._proto`, and exposes
    its fields as read-only properties. The `create`, `get`, `list`, `update`
    and `delete` operations are all issued through the client returned by
    `get_default_cache_client()`.
    """

    def __init__(self, name: str) -> None:
        """Fetches a `CachedContent` resource.

        Equivalent to `CachedContent.get`, except that the service response is
        stored directly instead of being copied into a fresh message.

        Args:
            name: The resource name referring to the cached content. The
                `cachedContents/` prefix is added when it is missing.
        """
        self._proto = self._fetch_proto(name)

    @staticmethod
    def _qualified_name(name: str) -> str:
        """Returns `name`, prefixed with `cachedContents/` when that text is absent."""
        # Deliberately a substring test (not `startswith`): this is the check the
        # public entry points have always applied, so it is kept unchanged.
        if "cachedContents/" not in name:
            return "cachedContents/" + name
        return name

    @staticmethod
    def _fetch_proto(name: str) -> protos.CachedContent:
        """Issues a `GetCachedContentRequest` for `name` and returns the response."""
        client = get_default_cache_client()
        request = protos.GetCachedContentRequest(name=CachedContent._qualified_name(name))
        return client.get_cached_content(request)

    @property
    def name(self) -> str:
        """Resource name stored on the underlying message."""
        return self._proto.name

    @property
    def model(self) -> str:
        """Model name stored on the underlying message."""
        return self._proto.model

    @property
    def display_name(self) -> str:
        """Display name stored on the underlying message."""
        return self._proto.display_name

    @property
    def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
        """Usage metadata stored on the underlying message."""
        return self._proto.usage_metadata

    @property
    def create_time(self) -> datetime.datetime:
        """Creation time stored on the underlying message."""
        return self._proto.create_time

    @property
    def update_time(self) -> datetime.datetime:
        """Last-update time stored on the underlying message."""
        return self._proto.update_time

    @property
    def expire_time(self) -> datetime.datetime:
        """Expiration time stored on the underlying message."""
        return self._proto.expire_time

    def __str__(self) -> str:
        """Returns a multi-line, constructor-like summary of this resource."""
        # The literal is indented to match the code and flattened by `dedent`;
        # `{{` / `}}` emit the literal braces around the usage metadata.
        return textwrap.dedent(
            f"""\
            CachedContent(
                name='{self.name}',
                model='{self.model}',
                display_name='{self.display_name}',
                usage_metadata={{
                    'total_token_count': {self.usage_metadata.total_token_count},
                }},
                create_time={self.create_time},
                update_time={self.update_time},
                expire_time={self.expire_time}
            )"""
        )

    __repr__ = __str__

    @classmethod
    def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
        """Creates an instance of CachedContent from an object, without calling `get`."""
        # `__new__` skips `__init__`, which would otherwise perform a fetch.
        self = cls.__new__(cls)
        self._proto = protos.CachedContent()
        self._update(obj)
        return self

    def _update(self, updates: CachedContent | protos.CachedContent | dict) -> None:
        """Updates this instance in place; does not call the API's `update` method.

        Args:
            updates: Another `CachedContent`, a message object, or a plain dict
                mapping field names to values. Anything that is not a dict is
                converted with its type's `to_dict`, leaving out default-valued
                fields, before the entries are copied onto `self._proto`.
        """
        if isinstance(updates, CachedContent):
            updates = updates._proto

        if not isinstance(updates, dict):
            updates = type(updates).to_dict(updates, including_default_value_fields=False)

        for key, value in updates.items():
            setattr(self._proto, key, value)

    @staticmethod
    def _prepare_create_request(
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> protos.CreateCachedContentRequest:
        """Prepares a CreateCachedContentRequest.

        Validates the arguments, converts each of them with the matching
        `content_types` / `caching_types` helper and assembles the request.
        Arguments have the same meaning as in `create`.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                `display_name` is longer than 128 characters.
        """
        if ttl and expire_time:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )

        # A bare model id (no "/") gets the "models/" prefix.
        if "/" not in model:
            model = "models/" + model

        if display_name and len(display_name) > 128:
            raise ValueError("`display_name` must be no more than 128 unicode characters.")

        if system_instruction:
            system_instruction = content_types.to_content(system_instruction)

        tools_lib = content_types.to_function_library(tools)
        if tools_lib:
            tools_lib = tools_lib.to_proto()

        if tool_config:
            tool_config = content_types.to_tool_config(tool_config)

        if contents:
            contents = content_types.to_contents(contents)
            # Only the final entry is given a default role when it has none.
            if not contents[-1].role:
                contents[-1].role = _USER_ROLE

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        cached_content = protos.CachedContent(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools_lib,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        return protos.CreateCachedContentRequest(cached_content=cached_content)

    @classmethod
    def create(
        cls,
        model: str,
        *,
        display_name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> CachedContent:
        """Creates `CachedContent` resource.

        Args:
            model: The name of the `model` to use for cached content creation.
                Any `CachedContent` resource can be only used with the
                `model` it was created for.
            display_name: The user-generated meaningful display name
                of the cached content. `display_name` must be no
                more than 128 unicode characters.
            system_instruction: Developer set system instruction.
            contents: Contents to cache.
            tools: A list of `Tools` the model may use to generate response.
            tool_config: Config to apply to all tools.
            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
                `ttl` and `expire_time` are exclusive arguments.
            expire_time: Expiration time for cached resource.
                `ttl` and `expire_time` are exclusive arguments.

        Returns:
            The newly created `CachedContent`, built from the service response.

        Raises:
            ValueError: If both `ttl` and `expire_time` are given, or if
                `display_name` is longer than 128 characters.
        """
        # Build (and thereby validate) the request before acquiring a client so
        # that argument errors surface regardless of client configuration.
        request = cls._prepare_create_request(
            model=model,
            display_name=display_name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools,
            tool_config=tool_config,
            ttl=ttl,
            expire_time=expire_time,
        )

        client = get_default_cache_client()
        response = client.create_cached_content(request)
        return cls._from_obj(response)

    @classmethod
    def get(cls, name: str) -> CachedContent:
        """Fetches required `CachedContent` resource.

        Args:
            name: The resource name referring to the cached content. The
                `cachedContents/` prefix is added when it is missing.

        Returns:
            `CachedContent` resource with specified `name`.
        """
        return cls._from_obj(cls._fetch_proto(name))

    @classmethod
    def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
        """Lists `CachedContent` objects associated with the project.

        This is a generator: no request is made until iteration starts.

        Args:
            page_size: The maximum number of `CachedContent` objects to return
                (per page). The service may return fewer `CachedContent` objects.

        Yields:
            One `CachedContent` per item produced by the client's
            `list_cached_contents` call.
        """
        client = get_default_cache_client()
        request = protos.ListCachedContentsRequest(page_size=page_size)
        for cached_content in client.list_cached_contents(request):
            yield cls._from_obj(cached_content)

    def delete(self) -> None:
        """Deletes `CachedContent` resource."""
        client = get_default_cache_client()
        request = protos.DeleteCachedContentRequest(name=self.name)
        client.delete_cached_content(request)

    def update(
        self,
        *,
        ttl: Optional[caching_types.TTLTypes] = None,
        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
    ) -> None:
        """Updates requested `CachedContent` resource.

        Exactly one of `ttl` / `expire_time` must be supplied. On success this
        instance is refreshed in place from the service response.

        Args:
            ttl: New TTL for cached resource (in seconds).
                `ttl` and `expire_time` are exclusive arguments.
            expire_time: New expiration time for cached resource.
                `ttl` and `expire_time` are exclusive arguments.

        Raises:
            ValueError: If both arguments are given, or if there is nothing
                to update.
        """
        nothing_to_update = (
            "Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
        )

        # Validate before acquiring a client so that argument mistakes are
        # reported as `ValueError` regardless of client configuration.
        if ttl and expire_time:
            raise ValueError(
                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
            )
        if ttl is None and expire_time is None:
            raise ValueError(nothing_to_update)

        ttl = caching_types.to_optional_ttl(ttl)
        expire_time = caching_types.to_optional_expire_time(expire_time)

        # The mask names the single field being changed.
        field_mask = field_mask_pb2.FieldMask()
        if ttl:
            field_mask.paths.append("ttl")
        elif expire_time:
            field_mask.paths.append("expire_time")
        else:
            raise ValueError(nothing_to_update)

        updates = protos.CachedContent(
            name=self.name,
            ttl=ttl,
            expire_time=expire_time,
        )
        request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)

        client = get_default_cache_client()
        updated_cc = client.update_cached_content(request)
        self._update(updated_cc)