pyrit/ui/rpc.py (156 lines of code) (raw):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import time
from threading import Semaphore, Thread
from typing import Callable, Optional
from pyrit.models import PromptRequestPiece, Score
from pyrit.ui.app import is_app_running, launch_app
DEFAULT_PORT = 18812
logger = logging.getLogger(__name__)
# Exceptions
class RPCAppException(Exception):
    """
    Base class of the RPC-related exceptions declared in this module.

    Args:
        message: Human-readable description of the failure; becomes the exception's only argument.
    """

    def __init__(self, message: str):
        Exception.__init__(self, message)
class RPCAlreadyRunningException(RPCAppException):
    """
    Raised when the user tries to start an RPC server while one is already running.
    """

    def __init__(self):
        message = "RPC server is already running."
        super().__init__(message)
class RPCClientNotReadyException(RPCAppException):
    """
    Raised when a message is sent to the RPC client before the client is ready to receive it.
    """

    def __init__(self):
        message = "RPC client is not ready."
        super().__init__(message)
class RPCServerStoppedException(RPCAppException):
    """
    Raised when an operation finds that the RPC server has been stopped.
    """

    def __init__(self):
        message = "RPC server is stopped."
        super().__init__(message)
# RPC Server
class AppRPCServer:
    """
    RPC server that PyRIT's main process uses to exchange score prompts and scores with the Gradio UI process.

    The server listens on ``DEFAULT_PORT``. While it runs, a background thread watches the pings sent by the client
    and requests a shutdown as soon as a ping is missed.

    Args:
        open_browser: Forwarded to ``launch_app`` when ``start`` has to launch the Gradio UI.
    """

    # Imported at class scope so the module is reachable as ``self.rpyc`` (used by ``start``).
    import rpyc

    # RPC Service
    class RPCService(rpyc.Service):
        """
        RPC service is the service that RPyC is using. RPC (Remote Procedure Call) is a way to interact with code that
        is hosted in another process or on another machine. RPyC is a library that implements RPC and we are using
        it to exchange information between PyRIT's main process and Gradio's process. This way the interface is
        independent of which PyRIT code is running the process.

        Args:
            score_received_semaphore: Released every time the client delivers a score.
            client_ready_semaphore: Released when the client registers its score-prompt callback.
        """

        # Maximum number of seconds allowed between two pings before the client is considered gone.
        _PING_TIMEOUT_SECONDS = 2

        def __init__(self, *, score_received_semaphore: Semaphore, client_ready_semaphore: Semaphore):
            super().__init__()
            # Callback registered by the client; None until the client has called exposed_callback_score_prompt.
            self._callback_score_prompt: Optional[Callable[[PromptRequestPiece, Optional[str]], None]] = None
            # Timestamp (time.time()) of the most recent ping; None until the first ping arrives.
            self._last_ping: Optional[float] = None
            self._scores_received: list[Score] = []
            self._score_received_semaphore = score_received_semaphore
            self._client_ready_semaphore = client_ready_semaphore

        def on_connect(self, conn):
            """Log that a client connected."""
            logger.info("Client connected")

        def on_disconnect(self, conn):
            """Log that a client disconnected."""
            logger.info("Client disconnected")

        def exposed_receive_score(self, score: Score):
            """Store a score sent by the client and signal that a score is available."""
            logger.info(f"Score received: {score}")
            self._scores_received.append(score)
            self._score_received_semaphore.release()

        def exposed_receive_ping(self):
            """Record the time of the latest ping sent by the client."""
            # A ping should be received every 2s from the client. If a client misses a ping then the server should
            # be stopped.
            self._last_ping = time.time()
            logger.debug("Ping received")

        def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]):
            """Register the client's callback for score prompts and signal that the client is ready."""
            self._callback_score_prompt = callback
            self._client_ready_semaphore.release()

        def is_client_ready(self):
            """Return True once the client has registered its score-prompt callback."""
            return self._callback_score_prompt is not None

        def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None):
            """
            Forward a score prompt to the client through its registered callback.

            Raises:
                RPCClientNotReadyException: If the client has not registered a callback yet.
            """
            if not self.is_client_ready():
                raise RPCClientNotReadyException()
            self._callback_score_prompt(prompt, task)

        def is_ping_missed(self):
            """
            Return True when the last ping is older than the ping timeout.

            Always returns False before the first ping has been received.
            """
            if self._last_ping is None:
                return False
            return time.time() - self._last_ping > self._PING_TIMEOUT_SECONDS

        def pop_score_received(self) -> Score | None:
            """Remove and return the most recently received score, or None when no score is stored."""
            try:
                return self._scores_received.pop()
            except IndexError:
                return None

    def __init__(self, open_browser: bool = False):
        self._server = None
        self._server_thread = None
        self._rpc_service = None
        self._is_alive_thread = None
        # Flag polled by the _is_alive loop; set to True to make that thread exit.
        self._is_alive_stop = False
        self._score_received_semaphore = None
        self._client_ready_semaphore = None
        self._server_is_running = False
        self._open_browser = open_browser

    def start(self):
        """
        Attempt to start the RPC server. If the server is already running, this method will throw an exception.

        Starts the RPC server thread and the ping-watchdog thread, then launches the Gradio UI unless it is already
        running.

        Raises:
            RPCAlreadyRunningException: If something is already listening on ``DEFAULT_PORT``.
        """
        # Check if the server is already running by checking if the port is already in use.
        # If the port is already in use, throw an exception.
        if self._is_instance_running():
            raise RPCAlreadyRunningException()

        # Both semaphores start at 0 so that waiters block until the client signals.
        self._score_received_semaphore = Semaphore(0)
        self._client_ready_semaphore = Semaphore(0)

        # Start the RPC server.
        self._rpc_service = self.RPCService(
            score_received_semaphore=self._score_received_semaphore, client_ready_semaphore=self._client_ready_semaphore
        )
        self._server = self.rpyc.ThreadedServer(
            self._rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}
        )
        self._server_thread = Thread(target=self._server.start)
        self._server_thread.start()

        # Start a thread to check if the client is still alive
        self._is_alive_stop = False
        self._is_alive_thread = Thread(target=self._is_alive)
        self._is_alive_thread.start()

        self._server_is_running = True
        logger.info("RPC server started")

        if not is_app_running():
            logger.info("Launching Gradio UI")
            launch_app(open_browser=self._open_browser)
        else:
            logger.info("Gradio UI is already running. Will not launch another instance.")

    def stop(self):
        """
        Stop the RPC server and free up the listening port.

        Blocks until the server thread and the ping-watchdog thread have finished.
        """
        self.stop_request()

        # stop_request() resets self._server to None, so the decision to join must be based on the thread handle
        # itself; otherwise the server thread would never be joined.
        if self._server_thread is not None:
            self._server_thread.join()

        if self._is_alive_thread is not None:
            self._is_alive_thread.join()

        logger.info("RPC server stopped")

    def stop_request(self):
        """
        Request the RPC server to stop. This method does not block while waiting for the server to stop.
        """
        logger.info("RPC server stopping")
        if self._server is not None:
            self._server.close()
            self._server = None

        if self._is_alive_thread is not None:
            self._is_alive_stop = True

        self._server_is_running = False

        # Wake up any caller blocked in wait_for_client()/wait_for_score(); they re-check _server_is_running and
        # raise RPCServerStoppedException.
        if self._client_ready_semaphore is not None:
            self._client_ready_semaphore.release()

        if self._score_received_semaphore is not None:
            self._score_received_semaphore.release()

    def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None):
        """
        Send a score prompt to the client.

        Raises:
            RPCAppException: If the server has not been started.
            RPCClientNotReadyException: If the client has not registered its callback yet.
        """
        if self._rpc_service is None:
            raise RPCAppException("RPC server is not running.")

        self._rpc_service.send_score_prompt(prompt, task)

    def wait_for_score(self) -> Optional[Score]:
        """
        Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will
        return None.

        Raises:
            RPCAppException: If the server has not been started.
            RPCServerStoppedException: If the server was stopped while waiting.
        """
        if self._score_received_semaphore is None or self._rpc_service is None:
            raise RPCAppException("RPC server is not running.")

        self._score_received_semaphore.acquire()
        # The semaphore is also released by stop_request(), so check why we were woken up.
        if not self._server_is_running:
            raise RPCServerStoppedException()

        score_ref = self._rpc_service.pop_score_received()
        # Release the client-ready semaphore so that the next wait_for_client() call can proceed.
        self._client_ready_semaphore.release()
        if score_ref is None:
            return None

        # Pass instance variables of reflected RPyC Score object as args to PyRIT Score object
        score = Score(
            score_value=score_ref.score_value,
            score_type=score_ref.score_type,
            score_category=str(score_ref.score_category),
            score_value_description=score_ref.score_value_description,
            score_rationale=score_ref.score_rationale,
            score_metadata=score_ref.score_metadata,
            prompt_request_response_id=score_ref.prompt_request_response_id,
        )

        return score

    def wait_for_client(self):
        """
        Wait for the client to be ready to receive messages.

        Raises:
            RPCAppException: If the server has not been started.
            RPCServerStoppedException: If the server was stopped while waiting.
        """
        if self._client_ready_semaphore is None:
            raise RPCAppException("RPC server is not running.")

        logger.info("Waiting for client to be ready")
        self._client_ready_semaphore.acquire()

        # The semaphore is also released by stop_request(), so check why we were woken up.
        if not self._server_is_running:
            raise RPCServerStoppedException()

        logger.info("Client is ready")

    def _is_instance_running(self):
        """
        Check if the RPC server is running.

        Returns True when a TCP connection to ``localhost:DEFAULT_PORT`` succeeds.
        """
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", DEFAULT_PORT)) == 0

    def _is_alive(self):
        """
        Check if a ping has been missed. If a ping has been missed, stop the server.

        Runs in its own thread and polls once per second until ``_is_alive_stop`` is set.
        """
        while not self._is_alive_stop:
            if self._rpc_service.is_ping_missed():
                logger.error("Ping missed. Stopping server.")
                # Use the non-blocking variant: stop() would try to join this very thread.
                self.stop_request()
                break
            time.sleep(1)