pyrit/ui/scorer.py (83 lines of code) (raw):

# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import gradio as gr import webview from connection_status import ConnectionStatusHandler from rpc_client import RPCClient GRADIO_POLLING_RATE = 0.5 # Polling Rate by the Gradio UI class GradioApp: def __init__(self): self.i = 0 self.rpc_client = RPCClient(self._disconnected_rpc_callback) self.connect_status = None self.url = "" def start_gradio(self, open_browser=False): with gr.Blocks() as demo: is_connected = gr.State(False) next_prompt_state = gr.State("") self.connect_status = ConnectionStatusHandler(is_connected, self.rpc_client) with gr.Column(visible=False) as main_interface: prompt = gr.Markdown("Prompt: ") prompt.height = "200px" with gr.Row(): safe = gr.Button("Safe") unsafe = gr.Button("Unsafe") safe.click( fn=lambda: [gr.update(interactive=False)] * 2 + [""], outputs=[safe, unsafe, next_prompt_state] ).then(fn=self._safe_clicked, outputs=next_prompt_state) unsafe.click( fn=lambda: [gr.update(interactive=False)] * 2 + [""], outputs=[safe, unsafe, next_prompt_state] ).then(fn=self._unsafe_clicked, outputs=next_prompt_state) with gr.Row() as loading_animation: loading_text = gr.Markdown("Connecting to PyRIT") timer = gr.Timer(GRADIO_POLLING_RATE) timer.tick(fn=self._loading_dots, outputs=loading_text) next_prompt_state.change( fn=self._on_next_prompt_change, inputs=[next_prompt_state], outputs=[prompt, safe, unsafe] ) self.connect_status.setup( main_interface=main_interface, loading_animation=loading_animation, next_prompt_state=next_prompt_state ) demo.load( fn=self._main_interface_loaded, outputs=[main_interface, loading_animation, next_prompt_state, is_connected], ) if open_browser: demo.launch(inbrowser=True) else: _, url, _ = demo.launch(prevent_thread_lock=True) self.url = url print("Gradio launched") webview.create_window("PyRIT - Scorer", self.url) webview.start() print("Webview closed!") if self.rpc_client: self.rpc_client.stop() def _safe_clicked(self): return self._send_prompt_response(True) def _unsafe_clicked(self): return self._send_prompt_response(False) def _send_prompt_response(self, value): self.rpc_client.send_prompt_response(value) prompt_request = self.rpc_client.wait_for_prompt() return str(prompt_request.converted_value) def _on_next_prompt_change(self, next_prompt): if next_prompt == "": return [ gr.Markdown("Waiting for next prompt..."), gr.update(interactive=False), gr.update(interactive=False), ] return [gr.Markdown("Prompt: " + next_prompt), gr.update(interactive=True), gr.update(interactive=True)] def _loading_dots(self): self.i = (self.i + 1) % 4 return gr.Markdown("Connecting to PyRIT" + "." * self.i) def _disconnected_rpc_callback(self): self.connect_status.set_disconnected() def _main_interface_loaded(self): print("Showing main interface") self.rpc_client.start() prompt_request = self.rpc_client.wait_for_prompt() next_prompt = str(prompt_request.converted_value) self.connect_status.set_next_prompt(next_prompt) self.connect_status.set_ready() print("PyRIT connected") return [gr.Column(visible=True), gr.Row(visible=False), next_prompt, True]