# pyrit/prompt_target/openai/openai_realtime_target.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
import base64
import json
import logging
import wave
from typing import Literal, Optional
from urllib.parse import urlencode
import websockets
from pyrit.models import PromptRequestResponse
from pyrit.models.data_type_serializer import data_serializer_factory
from pyrit.models.prompt_request_response import construct_response_from_request
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute
logger = logging.getLogger(__name__)
RealTimeVoice = Literal["alloy", "echo", "shimmer"]
class RealtimeTarget(OpenAITarget):
def __init__(
self,
*,
api_version: str = "2024-10-01-preview",
system_prompt: Optional[str] = "You are a helpful AI assistant",
voice: Optional[RealTimeVoice] = None,
        existing_convo: Optional[dict] = None,
**kwargs,
) -> None:
"""
RealtimeTarget class for Azure OpenAI Realtime API.
Read more at https://learn.microsoft.com/en-us/azure/ai-services/openai/realtime-audio-reference
and https://platform.openai.com/docs/guides/realtime-websocket
Args:
model_name (str, Optional): The name of the model.
endpoint (str, Optional): The target URL for the OpenAI service.
api_key (str, Optional): The API key for accessing the Azure OpenAI service.
                Defaults to the OPENAI_REALTIME_API_KEY environment variable.
headers (str, Optional): Headers of the endpoint (JSON).
use_aad_auth (bool, Optional): When set to True, user authentication is used
instead of API Key. DefaultAzureCredential is taken for
https://cognitiveservices.azure.com/.default . Please run `az login` locally
to leverage user AuthN.
max_requests_per_minute (int, Optional): Number of requests the target can handle per
minute before hitting a rate limit. The number of requests sent to the target
will be capped at the value provided.
api_version (str, Optional): The version of the Azure OpenAI API. Defaults to "2024-10-01-preview".
system_prompt (str, Optional): The system prompt to use. Defaults to "You are a helpful AI assistant".
            voice (literal str, Optional): The voice to use. Defaults to None.
                The only voices supported by the Azure OpenAI Realtime API are "alloy", "echo", and "shimmer".
existing_convo (dict[str, websockets.WebSocketClientProtocol], Optional): Existing conversations.
httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
httpx.AsyncClient() constructor.
                For example, to specify a 3-minute timeout: httpx_client_kwargs={"timeout": 180}
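
        Example (illustrative sketch; the endpoint, key, and model below are
        placeholders, not working values):
            >>> target = RealtimeTarget(  # doctest: +SKIP
            ...     endpoint="wss://example.openai.azure.com/openai/realtime",
            ...     api_key="sk-placeholder",
            ...     model_name="gpt-4o-realtime-preview",
            ...     voice="alloy",
            ... )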
"""
super().__init__(api_version=api_version, **kwargs)
self.system_prompt = system_prompt
self.voice = voice
        self._existing_conversation = existing_convo if existing_convo is not None else {}
    def _set_openai_env_configuration_vars(self):
        """Sets the environment variable names used to resolve the model, endpoint, and API key."""
self.model_name_environment_variable = "OPENAI_REALTIME_MODEL"
self.endpoint_environment_variable = "AZURE_OPENAI_REALTIME_ENDPOINT"
self.api_key_environment_variable = "OPENAI_REALTIME_API_KEY"
async def connect(self):
"""
Connects to Realtime API Target using websockets.
Returns the WebSocket connection.
"""
logger.info(f"Connecting to WebSocket: {self._endpoint}")
query_params = {
"deployment": self._model_name,
"OpenAI-Beta": "realtime=v1",
}
self._add_auth_param_to_query_params(query_params)
if self._api_version is not None:
query_params["api-version"] = self._api_version
url = f"{self._endpoint}?{urlencode(query_params)}"
websocket = await websockets.connect(url)
logger.info("Successfully connected to AzureOpenAI Realtime API")
return websocket
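
    # For orientation, connect() builds a URL of roughly this shape (resource
    # name and deployment are placeholders; urlencode percent-encodes the
    # OpenAI-Beta value):
    #   wss://<resource>.openai.azure.com/openai/realtime?deployment=<model>&OpenAI-Beta=realtime%3Dv1&api-version=2024-10-01-preview
    # with an api-key or access_token query parameter appended by
    # _add_auth_param_to_query_params() below.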
def _add_auth_param_to_query_params(self, query_params: dict) -> None:
"""
        Adds the authentication parameter to the query parameters. The Realtime
        API passes authentication via query parameters rather than headers.
Args:
query_params (dict): The query parameters.
"""
if self._api_key:
query_params["api-key"] = self._api_key
if self._azure_auth:
query_params["access_token"] = self._azure_auth.refresh_token()
    def _set_system_prompt_and_config_vars(self) -> dict:
        """Builds the session configuration dictionary sent to the target in a session.update event."""
session_config = {
"modalities": ["audio", "text"],
"instructions": self.system_prompt,
"input_audio_format": "pcm16",
"output_audio_format": "pcm16",
"turn_detection": None,
}
if self.voice:
session_config["voice"] = self.voice
return session_config
async def send_event(self, event: dict, conversation_id: str):
"""
Sends an event to the WebSocket server.
Args:
event (dict): Event to send in dictionary format.
conversation_id (str): Conversation ID
"""
websocket = self._existing_conversation.get(conversation_id)
if websocket is None:
logger.error("WebSocket connection is not established")
raise Exception("WebSocket connection is not established")
await websocket.send(json.dumps(event))
logger.debug(f"Event sent - type: {event['type']}")
async def send_config(self, conversation_id: str):
"""
Sends the session configuration to the WebSocket server.
Args:
conversation_id (str): Conversation ID
"""
config_variables = self._set_system_prompt_and_config_vars()
await self.send_event(
event={"type": "session.update", "session": config_variables}, conversation_id=conversation_id
)
logger.info("Session set up")
@limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """Sends a text or audio prompt to the target and returns both the transcript and audio responses."""
convo_id = prompt_request.request_pieces[0].conversation_id
if convo_id not in self._existing_conversation:
websocket = await self.connect()
self._existing_conversation[convo_id] = websocket
self.set_system_prompt(
system_prompt=self.system_prompt,
conversation_id=convo_id,
orchestrator_identifier=self.get_identifier(),
)
websocket = self._existing_conversation[convo_id]
self._validate_request(prompt_request=prompt_request)
await self.send_config(conversation_id=convo_id)
request = prompt_request.request_pieces[0]
response_type = request.converted_value_data_type
# Order of messages sent varies based on the data format of the prompt
if response_type == "audio_path":
output_audio_path, events = await self.send_audio_async(
filename=request.converted_value, conversation_id=convo_id
)
elif response_type == "text":
output_audio_path, events = await self.send_text_async(
text=request.converted_value, conversation_id=convo_id
)
text_response_piece = construct_response_from_request(
request=request, response_text_pieces=[events[1]], response_type="text"
).request_pieces[0]
audio_response_piece = construct_response_from_request(
request=request, response_text_pieces=[output_audio_path], response_type="audio_path"
).request_pieces[0]
response_entry = PromptRequestResponse(request_pieces=[text_response_piece, audio_response_piece])
return response_entry
async def save_audio(
self,
audio_bytes: bytes,
num_channels: int = 1,
sample_width: int = 2,
sample_rate: int = 16000,
        output_filename: Optional[str] = None,
) -> str:
"""
Saves audio bytes to a WAV file.
Args:
audio_bytes (bytes): Audio bytes to save.
num_channels (int): Number of audio channels. Defaults to 1 for the PCM16 format
sample_width (int): Sample width in bytes. Defaults to 2 for the PCM16 format
sample_rate (int): Sample rate in Hz. Defaults to 16000 Hz for the PCM16 format
output_filename (str): Output filename. If None, a UUID filename will be used.
Returns:
str: The path to the saved audio file.
"""
data = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path")
await data.save_formatted_audio(
data=audio_bytes,
output_filename=output_filename,
num_channels=num_channels,
sample_width=sample_width,
sample_rate=sample_rate,
)
return data.value
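
    # Rule of thumb for the PCM16 defaults above: duration in seconds is
    # len(audio_bytes) / (sample_rate * sample_width * num_channels), e.g.
    # 64000 bytes at 16000 Hz with 2-byte mono samples is about 2 seconds.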
async def cleanup_target(self):
"""
        Closes every open WebSocket connection, cleaning up all existing conversations.
"""
for conversation_id, websocket in self._existing_conversation.items():
if websocket:
await websocket.close()
logger.info(f"Disconnected from {self._endpoint} with conversation ID: {conversation_id}")
self._existing_conversation = {}
async def cleanup_conversation(self, conversation_id: str):
"""
Disconnects from the WebSocket server for a specific conversation
"""
websocket = self._existing_conversation.get(conversation_id)
if websocket:
await websocket.close()
logger.info(f"Disconnected from {self._endpoint} with conversation ID: {conversation_id}")
del self._existing_conversation[conversation_id]
async def send_response_create(self, conversation_id: str):
"""
Sends response.create message to the WebSocket server.
"""
await self.send_event(event={"type": "response.create"}, conversation_id=conversation_id)
async def receive_events(self, conversation_id: str) -> list:
"""
Continuously receive events from the WebSocket server.
Args:
            conversation_id: conversation ID

        Returns:
            list: Messages collected from the server (response audio bytes followed by the transcript).
        """
        websocket = self._existing_conversation.get(conversation_id)
        if websocket is None:
            logger.error("WebSocket connection is not established")
            raise Exception("WebSocket connection is not established")
audio_transcript = None
audio_buffer = b""
conversation_messages = []
try:
async for message in websocket:
event = json.loads(message)
msg_response_type = event.get("type")
if msg_response_type:
if msg_response_type == "response.done":
logger.debug(f"event is: {json.dumps(event, indent=2)}")
audio_transcript = event["response"]["output"][0]["content"][0]["transcript"]
conversation_messages.append(audio_transcript)
break
elif msg_response_type == "error":
logger.error(f"Error, event is: {json.dumps(event, indent=2)}")
break
elif msg_response_type == "response.audio.delta":
# Append audio data to buffer
audio_data = base64.b64decode(event["delta"])
audio_buffer += audio_data
logger.debug("Audio data appended to buffer")
elif msg_response_type == "response.audio.done":
logger.debug(f"event is: {json.dumps(event, indent=2)}")
conversation_messages.append(audio_buffer)
else:
logger.debug(f"event is: {json.dumps(event, indent=2)}")
except websockets.ConnectionClosed as e:
logger.error(f"WebSocket connection closed: {e}")
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
return conversation_messages
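
    # Ordering note: the loop above appends the buffered audio bytes when
    # "response.audio.done" arrives and then the transcript when
    # "response.done" arrives, so callers can rely on messages[0] being audio
    # bytes and messages[1] being the transcript string (send_prompt_async
    # depends on this with events[0] and events[1]).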
async def send_text_async(self, text: str, conversation_id: str):
"""
Sends text prompt to the WebSocket server.
Args:
text: prompt to send.
conversation_id: conversation ID
"""
        # Listen for responses
        receive_tasks = asyncio.create_task(self.receive_events(conversation_id=conversation_id))
        logger.info(f"Sending text message: {text}")
        event = {
            "type": "conversation.item.create",
            "item": {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]},
        }
        await self.send_event(event=event, conversation_id=conversation_id)
        await self.send_response_create(conversation_id=conversation_id)  # Ask the model to respond to the new item
events = await receive_tasks # Wait for all responses to be received
output_audio_path = await self.save_audio(events[0])
return output_audio_path, events
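
    # Text prompts therefore produce this wire sequence:
    #   conversation.item.create (user text) -> response.create ->
    #   response.audio.delta* -> response.audio.done -> response.done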
async def send_audio_async(self, filename: str, conversation_id: str):
"""
Send an audio message to the WebSocket server.
        Args:
            filename (str): The path to the audio file.
            conversation_id (str): Conversation ID
        """
with wave.open(filename, "rb") as wav_file:
# Read WAV parameters
num_channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth() # Should be 2 bytes for PCM16
frame_rate = wav_file.getframerate()
num_frames = wav_file.getnframes()
audio_content = wav_file.readframes(num_frames)
receive_tasks = asyncio.create_task(self.receive_events(conversation_id=conversation_id))
        try:
            audio_base64 = base64.b64encode(audio_content).decode("utf-8")
            event = {"type": "input_audio_buffer.append", "audio": audio_base64}
            await self.send_event(event=event, conversation_id=conversation_id)
        except Exception as e:
            logger.error(f"Error sending audio: {e}")
            raise
event = {"type": "input_audio_buffer.commit"}
await asyncio.sleep(0.1)
await self.send_event(event, conversation_id=conversation_id)
await self.send_response_create(conversation_id=conversation_id) # Sends response.create message
responses = await receive_tasks
output_audio_path = await self.save_audio(responses[0], num_channels, sample_width, frame_rate)
return output_audio_path, responses
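
    # Audio prompts follow a similar sequence, with the input staged first:
    #   input_audio_buffer.append (base64 PCM16) -> input_audio_buffer.commit ->
    #   response.create -> streamed response events handled by receive_events()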
def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
"""Validates the structure and content of a prompt request for compatibility of this target.
Args:
prompt_request (PromptRequestResponse): The prompt request response object.
Raises:
            ValueError: If the number of request pieces is not exactly one.
ValueError: If any of the request pieces have a data type other than 'text' or 'audio_path'.
"""
# Check the number of request pieces
if len(prompt_request.request_pieces) != 1:
raise ValueError("This target only supports one request piece.")
if prompt_request.request_pieces[0].converted_value_data_type not in ["text", "audio_path"]:
raise ValueError("This target only supports text and audio_path prompt input.")
def is_json_response_supported(self) -> bool:
"""Indicates that this target supports JSON response format."""
return False