pyrit/ui/rpc_client.py (89 lines of code) (raw):

# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import socket import time from threading import Event, Semaphore, Thread from typing import Callable, Optional import rpyc from pyrit.models import PromptRequestPiece, Score from pyrit.ui.rpc import RPCAppException DEFAULT_PORT = 18812 class RPCClientStoppedException(RPCAppException): """ This exception is thrown when the RPC client is stopped. """ def __init__(self): super().__init__("RPC client is stopped.") class RPCClient: def __init__(self, callback_disconnected: Optional[Callable] = None): self._c = None # type: Optional[rpyc.Connection] self._bgsrv = None # type: Optional[rpyc.BgServingThread] self._ping_thread = None # type: Optional[Thread] self._bgsrv_thread = None # type: Optional[Thread] self._is_running = False self._shutdown_event = None # type: Optional[Event] self._prompt_received_sem = None # type: Optional[Semaphore] self._prompt_received = None # type: Optional[PromptRequestPiece] self._callback_disconnected = callback_disconnected def start(self): # Check if the port is open self._wait_for_server_avaible() self._prompt_received_sem = Semaphore(0) self._c = rpyc.connect("localhost", DEFAULT_PORT, config={"allow_public_attrs": True}) self._is_running = True self._shutdown_event = Event() self._bgsrv_thread = Thread(target=self._bgsrv_lifecycle) self._bgsrv_thread.start() def wait_for_prompt(self) -> PromptRequestPiece: self._prompt_received_sem.acquire() if self._is_running: return self._prompt_received raise RPCClientStoppedException() def send_prompt_response(self, response: bool): score = Score( score_value=str(response), score_type="true_false", score_category="safety", score_value_description="Safe" if response else "Unsafe", score_rationale="The prompt was marked safe" if response else "The prompt was marked unsafe", score_metadata=None, prompt_request_response_id=self._prompt_received.id, ) self._c.root.receive_score(score) def _wait_for_server_avaible(self): # Wait for the server to be available while not self._is_server_running(): print("Server is not running. Waiting for server to start...") time.sleep(1) def stop(self): """ Stop the client. """ # Send a signal to the thread to stop self._shutdown_event.set() if self._bgsrv_thread is not None: self._bgsrv_thread.join() def reconnect(self): """ Reconnect to the server. """ self.stop() print("Reconnecting to server...") self.start() def _receive_prompt(self, prompt_request: PromptRequestPiece, task: Optional[str] = None): print(f"Received prompt: {prompt_request}") self._prompt_received = prompt_request self._prompt_received_sem.release() def _ping(self): try: while self._is_running: self._c.root.receive_ping() time.sleep(1.5) if not self._is_running: print("Connection closed") if self._callback_disconnected is not None: self._callback_disconnected() except EOFError: print("Connection closed") if self._callback_disconnected is not None: self._callback_disconnected() def _bgsrv_lifecycle(self): self._bgsrv = rpyc.BgServingThread(self._c) self._ping_thread = Thread(target=self._ping) self._ping_thread.start() # Register callback self._c.root.callback_score_prompt(self._receive_prompt) # Wait for the server to be disconnected self._shutdown_event.wait() self._is_running = False # Release the semaphore in case it was waiting self._prompt_received_sem.release() self._ping_thread.join() # Avoid calling stop() twice if the server is already stopped. This can happen if the server is stopped # by the ping request. if self._bgsrv._active: self._bgsrv.stop() def _is_server_running(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("localhost", DEFAULT_PORT)) == 0