bot/gemini_model.py (263 lines of code) (raw):

import json import logging import os import time import pathlib from collections import defaultdict, deque from typing import Literal import google.auth import hikari import magic from google import genai from google.genai.types import ( Part, Content, GenerateContentConfig, GenerateContentResponse, Tool, GoogleSearch, ) from hikari import OwnUser import discord_cache VERTEX_TOS = "https://developers.google.com/terms" GEMINI_TOS = "https://ai.google.dev/gemini-api/terms" GEMINI_MODEL_NAME = "gemini-2.0-flash-exp" _MAX_HISTORY_TOKEN_SIZE = 1000000 # 1M to keep things simple, real limit is 1,048,576 _LOGGER = logging.getLogger("bot.gemini") ACCEPTED_MIMES = { "application/pdf", "audio/mpeg", "audio/mp3", "audio/wav", "image/gif", "image/png", "image/jpeg", "image/webp", "text/plain", "video/mov", "video/mpeg", "video/mp4", "video/mpg", "video/avi", "video/wmv", "video/mpegps", "video/flv", } GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", None) GEMINI_PROJECT = os.getenv("GOOGLE_CLOUD_PROJECT", google.auth.default()[1]) if GEMINI_API_KEY is None: _client = genai.Client( # Need to use us-central1 as it's the only region with gemini-2.0-flash-exp model vertexai=True, project=GEMINI_PROJECT, location="us-central1" ) else: _client = genai.Client(api_key=GEMINI_API_KEY) _system_instructions_file = pathlib.Path("system_instructions.txt") if not _system_instructions_file.is_file(): _system_instructions_file = pathlib.Path("default_system_instructions.txt") _gen_config = GenerateContentConfig( temperature=0, candidate_count=1, max_output_tokens=1800, system_instruction=_system_instructions_file.read_text(), tools=[Tool(google_search=GoogleSearch())], # for Gemini 2 response_modalities=["TEXT"], ) class ChatPart: """ ChatPart is used to internally store the history of communication in various Discord channels. We save the chat message as `vertexai.generative_models.Part` object (ready to use in communication with Gemini), `role` tells us if given message was made by our bot ("model") or a user ("user"). The object also stores the token count, so we don't need to re-calculate it. Currently, our code handles only text interactions, however the Part objects can represent images, videos and audio files, which will be useful in the future. """ def __init__( self, chat_part: Part, role: Literal["user", "model"], token_count: int=None ): self.part = chat_part self.role = role self.token_count = token_count or self._count_tokens(chat_part) def __str__(self): return f"<{self.role}: {self.part} [{self.token_count}]>" def __repr__(self): return str(self) @staticmethod def _count_tokens(part: Part) -> int: if hasattr(part, "text") and part.text is not None: return int(len(part.text) * 0.3) if hasattr(part, "inline_data") and part.inline_data is not None: if part.inline_data.mime_type.startswith("image"): return 258 elif part.inline_data.mime_type.startswith("video"): return int(len(part.inline_data.data) / 1000) elif part.inline_data.mime_type.startswith("audio"): return int(len(part.inline_data.data) / 800) # 1000 bytes per token video, 800 bytes per token for audio - those are rough estimates _LOGGER.debug(f"Counting tokens for {part}"[:200]) start = time.time() count = _client.models.count_tokens( model=GEMINI_MODEL_NAME, contents=part ).total_tokens _LOGGER.debug( f"Counted tokens for {part} in {time.time() - start}s with token count: {count}" ) return count @classmethod def _parse_embed(cls, sender: str, embed: hikari.Embed) -> list[Part]: title = embed.title description = embed.description image = embed.image.proxy_url if embed.image else None author = embed.author.name if embed.author else None footer = embed.footer.text if embed.footer else None fields = [{'title': f.name, 'text': f.value} for f in embed.fields] embed_json = {'title': title, 'description': description, 'fields': fields, 'sender': sender, 'author': author, 'footer': footer, 'type': 'embed}'} embed_part = Part.from_text(text=json.dumps(embed_json)) if image: data = discord_cache.get_from_cache(image) embed_img = Part.from_bytes(data=data, mime_type=magic.from_buffer(data, mime=True)) return [embed_part, embed_img] else: return [embed_part] @classmethod def from_user_chat_message(cls, message: hikari.Message) -> list["ChatPart"]: """ Create a user ChatPart object from hikari.Message. Stores the text content of the message as JSON encoded object and assigns the `role` as "user". This method also calculates and saves the token count. """ author = getattr(message.member, "display_name", False) msg = json.dumps( { "author": author or message.author.username, "content": message.content, } ) text_part = Part.from_text(text=msg) parts = [(text_part, cls._count_tokens(text_part))] for e in message.embeds: for embed_part in cls._parse_embed(author, e): parts.append((embed_part, cls._count_tokens(embed_part))) for a in message.attachments: if a.media_type not in ACCEPTED_MIMES: part = Part.from_text( text=f"Here user uploaded a file in unsupported {a.media_type} type." ) else: data = discord_cache.get_from_cache(a.url) part = Part.from_bytes(data=data, mime_type=a.media_type) parts.append((part, cls._count_tokens(part))) return [cls(part, "user", tokens) for part, tokens in parts] @classmethod def from_bot_chat_message(cls, message: hikari.Message) -> list["ChatPart"]: """ Create a model ChatPart object from hikari.Message. Stores the text content of the message and assigns the `role` as "model". This method also calculates and saves the token count. """ part = Part.from_text(text=message.content) tokens = cls._count_tokens(part) parts = [(part, tokens)] for a in message.attachments: data = discord_cache.get_from_cache(a.url) part = Part.from_bytes(data=data, mime_type=a.media_type) parts.append((part, cls._count_tokens(part))) return [cls(part, "model", tokens) for part, tokens in parts] @classmethod def from_ai_reply(cls, response: GenerateContentResponse | Part) -> "ChatPart": """ Create a model ChatPart object from Gemini response. Stores the text content of the message and assigns the `role` as "model". Saves the token count from the model response. """ part = Part.from_text(text=response.text) if isinstance(response, GenerateContentResponse): tokens = response.usage_metadata.candidates_token_count else: tokens = cls._count_tokens(part) return cls(part, "model", tokens) @classmethod def from_raw_part( cls, part: Part, role: Literal["user", "model"] = "model" ) -> "ChatPart": """ Create a model ChatPart object from a raw Part. Stores the whole part and assigns the `role` as "model". Saves the token count using model call. """ return cls(part, role, cls._count_tokens(part)) @classmethod def from_bytes(cls, data: bytes, mime_type: str) -> "ChatPart": """ Create a model ChatPart object to represent attached image. Stores the bytes content and assigns the `role` as "model". Saves the token count by querying Gemini model. """ part = Part.from_bytes(bytes=data, mime_type=mime_type) return cls(part, "model", cls._count_tokens(part)) class ChatHistory: """ Object of this class keeps track of the chat history in a single Discord channel by storing a deque of ChatPart objects. """ def __init__(self): self._history: deque[ChatPart] = deque() async def add_message(self, message: hikari.Message) -> None: """ Create a new ChatPart object and append it to the chat history. """ self._history.extend(ChatPart.from_user_chat_message(message)) async def load_history(self, channel: hikari.GuildTextChannel, bot_id: int) -> None: """ Reads chat history of a given channel and stores it internally. The history will be read until the history exceeds the token limit for the Gemini model. :param channel: Guild channel that will be read. :param bot_id: The ID of the bot that we are running as. Needed to properly recognize responses from previous sessions. """ _LOGGER.info(f"Loading history for: {channel} {type(channel)}") guild = channel.get_guild() member_cache = {} tokens = 0 messages = 0 async for message in channel.fetch_history(): messages += 1 if messages > 50: # To speed up the starting process, just read only the last 50 messages break if message.author.id not in member_cache: member_cache[message.author.id] = guild.get_member(message.author.id) message.member = member_cache[message.author.id] if message.author.id == bot_id: self._history.extendleft( reversed(ChatPart.from_bot_chat_message(message)) ) else: self._history.extendleft( reversed(ChatPart.from_user_chat_message(message)) ) tokens += self._history[0].token_count or 0 if tokens > _MAX_HISTORY_TOKEN_SIZE: break _LOGGER.info(f"History loaded for {channel}.") def _build_content(self) -> (list[Content], int): """ Prepare the whole Content structure to be sent to the AI model, containing the whole chat history so far (or as much as the token limit allows). The Gemini model accepts a sequence of Content objects, each Content object contains one or more Part objects. Content objects have the `role` attribute that tells Gemini who's the author of a given piece of conversation history. The model expects that the sequence of incoming Content objects is a conversation between "model" and "user" - in our case, we combine all user messages into single Content object, with proper attribution, so that Gemini can recognize who said what. Model Content objects are sent as regular text. """ # Buffer keeps tuples of (part, role) buffer = deque() contents = deque() tokens = 0 parts_count = 0 for part in reversed(self._history): parts_count += 1 if buffer and buffer[0][1] != part.role: content = Content(role=buffer[0][1], parts=list(b[0] for b in buffer)) contents.appendleft(content) buffer.clear() buffer.appendleft((part.part, part.role)) tokens += part.token_count or 0 if tokens > _MAX_HISTORY_TOKEN_SIZE: _LOGGER.info("Memory full, will purge now.") break # We fit whole _history in the contents, no need to clear memory if buffer: user_content = Content(role=buffer[0][1], parts=list(b[0] for b in buffer)) contents.appendleft(user_content) # We need to forget the tail of history, so we don't waste memory for _ in range(len(self._history) - parts_count): self._history.popleft() while contents[0].role == "model" or len(contents[0].parts) == 0: # Can't have model start the conversation contents.popleft() return list(contents), tokens async def trigger_answer(self) -> (list[str], list[hikari.Bytes]): """ Uses AI to generate answer to the current chat history. Will handle function calling if the model requests functions to be called. Note: The last message in the chat history has to be from a user. """ if self._history[-1].role == "model": raise RuntimeError( "Last message in chat history needs to be from a user to generate a reply." ) content, tokens = self._build_content() _LOGGER.info(f"Generating answer for estimated {tokens} tokens...") start = time.time() response = await _client.aio.models.generate_content( model=GEMINI_MODEL_NAME, contents=content, config=_gen_config ) _LOGGER.info( f"Generated response for estimated {tokens} tokens in {time.time()-start}s." ) self._history.append(ChatPart.from_ai_reply(response)) return [response.text], [] class GeminiBot: """ Class representing the state of current instance of our Bot. It keeps track of its own identity, Gemini model configuration and chat history for all the channels it interacted with. """ def __init__(self, me: OwnUser): self.me = me self.memory = defaultdict(ChatHistory) async def handle_message(self, event: hikari.GuildMessageCreateEvent) -> None: """ Handle an incoming message. This method will save the message in bot's chat history and generate a reply if one is needed (bot was mentioned in the message). """ # Do not respond to bots nor webhooks pinging us, only user accounts if event.author_id == self.me.id: return message = event.message channel = await message.fetch_channel() if message.channel_id not in self.memory: chat_history = ChatHistory() await chat_history.load_history(channel, self.me.id) self.memory[message.channel_id] = chat_history else: # Loading history would catch this message anyway await self.memory[message.channel_id].add_message(message) if self.me.id not in event.message.user_mentions_ids: # Reply only when mentioned in the message. return # The bot has been pinged, we need to reply await channel.trigger_typing() try: text_responses, attachments = await self.memory[ message.channel_id ].trigger_answer() except Exception as e: await event.message.respond( "Sorry, there was an error processing an answer for you :(" ) raise e for text_response in text_responses[:-1]: await event.message.respond(text_response[:2000]) await event.message.respond(text_responses[-1][:2000], attachments=attachments)