pyrit/ui/connection_status.py (42 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import gradio as gr
from rpc_client import RPCClient


class ConnectionStatusHandler:
    """Track the connection to the RPC server and reflect it in the Gradio UI.

    The handler owns three pieces of mutable state:

    * ``state`` - the ``gr.State`` holding the "is connected" boolean,
    * ``server_disconnected`` - a flag toggled through ``set_disconnected`` /
      ``set_ready`` and cleared again after a successful reconnect,
    * ``next_prompt`` - the prompt text forwarded to the UI whenever the
      connection state changes.
    """

    def __init__(self, is_connected_state: gr.State, rpc_client: RPCClient):
        """Store the shared state object and the RPC client.

        Args:
            is_connected_state: Gradio state holding the connection boolean.
            rpc_client: Client used to reconnect and to wait for the next prompt.
        """
        self.state = is_connected_state
        self.rpc_client = rpc_client
        self.server_disconnected = False
        self.next_prompt = ""

    def setup(self, *, main_interface: gr.Column, loading_animation: gr.Column, next_prompt_state: gr.State):
        """Register the Gradio event handlers that drive the connection UI.

        Args:
            main_interface: Container made visible while connected.
            loading_animation: Container made visible while disconnected.
            next_prompt_state: State that receives ``self.next_prompt`` on every
                connection-state change.
        """
        # Whenever the connection state changes, toggle the two containers and
        # push the pending prompt to the UI.
        self.state.change(
            fn=self._on_state_change,
            inputs=[self.state],
            outputs=[main_interface, loading_animation, next_prompt_state],
        )

        # Timer constructed with a value of 1 (the tick interval, in seconds).
        # Each tick first re-evaluates the connection state, then attempts a
        # reconnect; both steps write their result back into the state.
        poll_timer = gr.Timer(1)
        status_checked = poll_timer.tick(
            fn=self._check_connection_status,
            inputs=[self.state],
            outputs=[self.state],
        )
        status_checked.then(fn=self._reconnect_if_needed, outputs=[self.state])

    def set_ready(self):
        """Clear the server-disconnected flag."""
        self.server_disconnected = False

    def set_disconnected(self):
        """Raise the server-disconnected flag; the next timer tick acts on it."""
        self.server_disconnected = True

    def set_next_prompt(self, next_prompt: str):
        """Remember the prompt to hand to the UI on the next state change."""
        self.next_prompt = next_prompt

    def _on_state_change(self, is_connected: bool):
        """Build the UI updates for a connection-state change.

        Returns:
            A three-item list: an update for the main interface, an update for
            the loading animation, and the pending prompt text. Exactly one of
            the two containers is marked visible.
        """
        print("Connection status changed to: ", is_connected, " - ", self.next_prompt)
        connected = bool(is_connected)
        return [
            gr.Column(visible=connected),
            gr.Row(visible=not connected),
            self.next_prompt,
        ]

    def _check_connection_status(self, is_connected: bool):
        """Return ``True`` only if the state is truthy and no disconnect is flagged."""
        if not self.server_disconnected and is_connected:
            return True
        print("Gradio disconnected")
        return False

    def _reconnect_if_needed(self):
        """Reconnect to the RPC server when a disconnect has been flagged.

        After reconnecting, waits for the next prompt, stores its converted
        value as a string, and clears the disconnect flag.

        Returns:
            Always ``True``. Exceptions raised by the RPC client are not caught
            here.
        """
        if not self.server_disconnected:
            return True

        print("Attempting to reconnect")
        self.rpc_client.reconnect()
        received_prompt = self.rpc_client.wait_for_prompt()
        self.next_prompt = str(received_prompt.converted_value)
        self.server_disconnected = False
        return True