llm_demo/orchestrator/orchestrator.py (56 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod from typing import Any, Optional class classproperty: def __init__(self, func): self.fget = func def __get__(self, instance, owner): return self.fget(owner) class BaseOrchestrator(ABC): MODEL = "gemini-pro" @classproperty @abstractmethod def kind(cls): pass @abstractmethod def user_session_exist(self, uuid: str) -> bool: """Check if user session exist.""" raise NotImplementedError("Subclass should implement this!") @abstractmethod async def user_session_create(self, session: dict[str, Any]): """Create user session for orchestrator.""" raise NotImplementedError("Subclass should implement this!") @abstractmethod async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]: """Invoke user session and return a response from llm orchestrator.""" raise NotImplementedError("Subclass should implement this!") @abstractmethod def user_session_reset(self, session: dict[str, Any], uuid: str): """Reset and clear history from user session.""" raise NotImplementedError("Subclass should implement this!") @abstractmethod def get_user_session(self, uuid: str) -> Any: raise NotImplementedError("Subclass should implement this!") @abstractmethod async def user_session_insert_ticket(self, uuid: str, params: str) -> Any: raise NotImplementedError("Subclass should implement this!") @abstractmethod async def user_session_decline_ticket(self, uuid: str) -> Optional[dict[str, Any]]: raise NotImplementedError("Subclass should implement this!") @abstractmethod async def user_session_signout(self, uuid: str): """Sign out from user session. Clear and restart session.""" raise NotImplementedError("Subclass should implement this!") def set_user_session_header(self, uuid: str, user_id_token: str): user_session = self.get_user_session(uuid) user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}" def get_user_id_token(self, uuid: str) -> Optional[str]: if self.user_session_exist(uuid): user_session = self.get_user_session(uuid) if user_session.client and "User-Id-Token" in user_session.client.headers: token = user_session.client.headers["User-Id-Token"] parts = str(token).split(" ") if len(parts) != 2 or parts[0] != "Bearer": raise Exception("Invalid ID token") return parts[1] return None def createOrchestrator(orchestration_type: str) -> "BaseOrchestrator": for cls in BaseOrchestrator.__subclasses__(): s = f"{orchestration_type} == {cls.kind}" if orchestration_type == cls.kind: return cls() # type: ignore raise TypeError(f"No orchestration type of kind {orchestration_type}")