pyrit/ui/rpc.py (156 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time
from threading import Semaphore, Thread
from typing import Callable, Optional

from pyrit.models import PromptRequestPiece, Score
from pyrit.ui.app import is_app_running, launch_app

DEFAULT_PORT = 18812

# Maximum age (in seconds) of the last ping before the client is considered gone.
_PING_TIMEOUT_SECONDS = 2

logger = logging.getLogger(__name__)


# Exceptions
class RPCAppException(Exception):
    """Base class for every error raised by the RPC server in this module."""

    def __init__(self, message: str):
        super().__init__(message)


class RPCAlreadyRunningException(RPCAppException):
    """
    This exception is thrown when an RPC server is already running and the user tries to start another one.
    """

    def __init__(self):
        super().__init__("RPC server is already running.")


class RPCClientNotReadyException(RPCAppException):
    """
    This exception is thrown when the RPC client is not ready to receive messages.
    """

    def __init__(self):
        super().__init__("RPC client is not ready.")


class RPCServerStoppedException(RPCAppException):
    """
    This exception is thrown when the RPC server is stopped.
    """

    def __init__(self):
        super().__init__("RPC server is stopped.")


# RPC Server
class AppRPCServer:
    """
    Owns the RPyC server, its service instance and the watchdog thread, and exposes blocking helpers
    (``wait_for_client``, ``wait_for_score``) to the code that drives the UI.
    """

    # Imported in the class body (rather than at module level) and reached through ``self.rpyc`` below.
    import rpyc

    # RPC Service
    class RPCService(rpyc.Service):
        """
        RPC service is the service that RPyC is using. RPC (Remote Procedure Call) is a way to interact with code
        that is hosted in another process or on another machine. RPyC is a library that implements RPC and we are
        using it to exchange information between PyRIT's main process and Gradio's process. This way the interface
        is independent of which PyRIT code is running the process.
        """

        def __init__(self, *, score_received_semaphore: Semaphore, client_ready_semaphore: Semaphore):
            """
            Args:
                score_received_semaphore: Released once for every score pushed by the client.
                client_ready_semaphore: Released when the client registers its score-prompt callback.
            """
            super().__init__()
            self._callback_score_prompt: Optional[Callable[[PromptRequestPiece, Optional[str]], None]] = None
            self._last_ping: Optional[float] = None
            self._scores_received: list[Score] = []
            self._score_received_semaphore = score_received_semaphore
            self._client_ready_semaphore = client_ready_semaphore

        def on_connect(self, conn):
            """Log that a client connected."""
            logger.info("Client connected")

        def on_disconnect(self, conn):
            """Log that a client disconnected."""
            logger.info("Client disconnected")

        def exposed_receive_score(self, score: Score):
            """Store a score sent by the client and wake up one waiter on the score semaphore."""
            logger.info(f"Score received: {score}")
            self._scores_received.append(score)
            self._score_received_semaphore.release()

        def exposed_receive_ping(self):
            """Record the time of the latest ping from the client."""
            # A ping should be received every 2s from the client. If a client misses a ping then the server
            # should be stopped (see AppRPCServer._is_alive).
            self._last_ping = time.time()
            logger.debug("Ping received")

        def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]):
            """Register the client's score-prompt callback and signal that the client is ready."""
            self._callback_score_prompt = callback
            self._client_ready_semaphore.release()

        def is_client_ready(self):
            """Return True once the client has registered its score-prompt callback."""
            return self._callback_score_prompt is not None

        def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None):
            """
            Forward a prompt (and optional task) to the client through the registered callback.

            Raises:
                RPCClientNotReadyException: If no callback has been registered yet.
            """
            if not self.is_client_ready():
                raise RPCClientNotReadyException()
            self._callback_score_prompt(prompt, task)

        def is_ping_missed(self):
            """
            Return True when the last ping is older than the ping timeout.

            Returns False while no ping has been received at all, so a client that has not started
            pinging yet is not treated as dead.
            """
            if self._last_ping is None:
                return False
            return time.time() - self._last_ping > _PING_TIMEOUT_SECONDS

        def pop_score_received(self) -> Optional[Score]:
            """Remove and return the most recently received score, or None if there is none."""
            try:
                return self._scores_received.pop()
            except IndexError:
                return None

    def __init__(self, open_browser: bool = False):
        """
        Args:
            open_browser: Forwarded to ``launch_app`` when ``start`` launches the UI.
        """
        self._server = None
        self._server_thread = None
        self._rpc_service = None
        self._is_alive_thread = None
        self._is_alive_stop = False
        self._score_received_semaphore = None
        self._client_ready_semaphore = None
        self._server_is_running = False
        self._open_browser = open_browser

    def start(self):
        """
        Attempt to start the RPC server. If the server is already running, this method will throw an exception.

        Raises:
            RPCAlreadyRunningException: If something is already listening on DEFAULT_PORT.
        """
        # Check if the server is already running by checking if the port is already in use.
        # If the port is already in use, throw an exception.
        if self._is_instance_running():
            raise RPCAlreadyRunningException()

        self._score_received_semaphore = Semaphore(0)
        self._client_ready_semaphore = Semaphore(0)

        # Start the RPC server.
        self._rpc_service = self.RPCService(
            score_received_semaphore=self._score_received_semaphore,
            client_ready_semaphore=self._client_ready_semaphore,
        )
        self._server = self.rpyc.ThreadedServer(
            self._rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}
        )
        self._server_thread = Thread(target=self._server.start)
        self._server_thread.start()

        # Start a thread to check if the client is still alive
        self._is_alive_stop = False
        self._is_alive_thread = Thread(target=self._is_alive)
        self._is_alive_thread.start()

        self._server_is_running = True
        logger.info("RPC server started")

        if not is_app_running():
            logger.info("Launching Gradio UI")
            launch_app(open_browser=self._open_browser)
        else:
            logger.info("Gradio UI is already running. Will not launch another instance.")

    def stop(self):
        """
        Stop the RPC server and free up the listening port.

        Blocks until both the server thread and the watchdog thread have finished.
        """
        self.stop_request()

        # stop_request() clears self._server, so the join must be keyed off the thread handle;
        # checking self._server here would always skip the join and return before the port is released.
        if self._server_thread is not None:
            self._server_thread.join()

        if self._is_alive_thread is not None:
            self._is_alive_thread.join()

        logger.info("RPC server stopped")

    def stop_request(self):
        """
        Request the RPC server to stop. This method does not block while waiting for the server to stop.
        """
        logger.info("RPC server stopping")

        # Take a local reference and clear the field first: this method can run both on the watchdog
        # thread and on the caller of stop(), and must not call close() on a field another thread
        # has just set to None.
        server = self._server
        self._server = None
        if server is not None:
            server.close()

        if self._is_alive_thread is not None:
            self._is_alive_stop = True

        self._server_is_running = False

        # Wake up anything blocked in wait_for_client / wait_for_score; those methods re-check
        # _server_is_running after acquiring and raise RPCServerStoppedException.
        if self._client_ready_semaphore is not None:
            self._client_ready_semaphore.release()

        if self._score_received_semaphore is not None:
            self._score_received_semaphore.release()

    def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None):
        """
        Send a score prompt to the client.

        Raises:
            RPCAppException: If the server has not been started.
            RPCClientNotReadyException: If the client has not registered its callback yet.
        """
        if self._rpc_service is None:
            raise RPCAppException("RPC server is not running.")

        self._rpc_service.send_score_prompt(prompt, task)

    def wait_for_score(self) -> Optional[Score]:
        """
        Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will
        return None.

        Raises:
            RPCAppException: If the server has not been started.
            RPCServerStoppedException: If the server was stopped while waiting.
        """
        if self._score_received_semaphore is None or self._rpc_service is None:
            raise RPCAppException("RPC server is not running.")

        self._score_received_semaphore.acquire()
        if not self._server_is_running:
            raise RPCServerStoppedException()

        score_ref = self._rpc_service.pop_score_received()
        # Re-arm the client-ready semaphore so the next wait_for_client() call does not block.
        self._client_ready_semaphore.release()
        if score_ref is None:
            return None

        # Pass instance variables of reflected RPyC Score object as args to PyRIT Score object
        score = Score(
            score_value=score_ref.score_value,
            score_type=score_ref.score_type,
            score_category=str(score_ref.score_category),
            score_value_description=score_ref.score_value_description,
            score_rationale=score_ref.score_rationale,
            score_metadata=score_ref.score_metadata,
            prompt_request_response_id=score_ref.prompt_request_response_id,
        )

        return score

    def wait_for_client(self):
        """
        Wait for the client to be ready to receive messages.

        Raises:
            RPCAppException: If the server has not been started.
            RPCServerStoppedException: If the server was stopped while waiting.
        """
        if self._client_ready_semaphore is None:
            raise RPCAppException("RPC server is not running.")

        logger.info("Waiting for client to be ready")
        self._client_ready_semaphore.acquire()

        if not self._server_is_running:
            raise RPCServerStoppedException()

        logger.info("Client is ready")

    def _is_instance_running(self):
        """
        Check if the RPC server is running.

        Returns True when a TCP connection to localhost:DEFAULT_PORT succeeds.
        """
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", DEFAULT_PORT)) == 0

    def _is_alive(self):
        """
        Check if a ping has been missed. If a ping has been missed, stop the server.

        Runs on the watchdog thread started by ``start`` and polls once per second until asked to stop.
        """
        while not self._is_alive_stop:
            if self._rpc_service.is_ping_missed():
                logger.error("Ping missed. Stopping server.")
                # Only request the stop: stop() joins this very thread and must not be called from it.
                self.stop_request()
                break
            time.sleep(1)