# agent/agent.py (113 lines of code) (raw):
from computers import Computer
from utils import (
create_response,
show_image,
pp,
sanitize_message,
check_blocklisted_url,
)
import json
from typing import Callable
class Agent:
    """
    A sample agent class that can be used to interact with a computer.

    (See simple_cua_loop.py for a simple example without an agent.)

    The agent repeatedly calls ``create_response`` with the conversation so
    far, appends the returned output items, and reacts to each of them:
    messages are printed, ``function_call`` items are dispatched to a method
    of the same name on the computer (when one exists), and ``computer_call``
    items are executed on the computer and answered with a screenshot.
    """

    def __init__(
        self,
        model="computer-use-preview",
        computer: "Computer | None" = None,
        tools: "list[dict] | None" = None,
        acknowledge_safety_check_callback: Callable = lambda message: False,
    ):
        """Create an agent.

        Args:
            model: Model name forwarded to ``create_response``.
            computer: Object the actions are executed on. When given, a
                computer tool description built from its dimensions and
                environment is appended to the agent's tool list.
            tools: Extra tool definitions sent with every request. The list
                is copied, so neither the caller's list nor a shared default
                is ever mutated.
            acknowledge_safety_check_callback: Called with the message of
                each pending safety check; must return a truthy value to
                acknowledge it. The default rejects every check.
        """
        self.model = model
        self.computer = computer
        # Copy instead of aliasing: the computer tool is appended below and
        # must not leak into the caller's list or into other Agent instances.
        self.tools = list(tools) if tools else []
        self.print_steps = True
        self.debug = False
        self.show_images = False
        self.acknowledge_safety_check_callback = acknowledge_safety_check_callback

        if computer:
            dimensions = computer.get_dimensions()
            self.tools.append(
                {
                    "type": "computer-preview",
                    "display_width": dimensions[0],
                    "display_height": dimensions[1],
                    "environment": computer.get_environment(),
                }
            )

    def debug_print(self, *args):
        """Pretty-print ``args`` via ``pp`` when debug mode is enabled."""
        if self.debug:
            pp(*args)

    def handle_item(self, item):
        """Handle each item; may cause a computer action + screenshot.

        Args:
            item: One output item (a dict with at least a ``"type"`` key).

        Returns:
            A list of follow-up items to append to the conversation: one
            ``function_call_output`` or ``computer_call_output`` entry for
            the corresponding call types, otherwise an empty list.

        Raises:
            ValueError: If a pending safety check is not acknowledged.
        """
        item_type = item["type"]
        if item_type == "message":
            if self.print_steps:
                print(item["content"][0]["text"])
            return []
        if item_type == "function_call":
            return self._handle_function_call(item)
        if item_type == "computer_call":
            return self._handle_computer_call(item)
        return []

    def _handle_function_call(self, item):
        """Run a function call on the computer (if supported) and report success."""
        name, args = item["name"], json.loads(item["arguments"])
        if self.print_steps:
            print(f"{name}({args})")

        # Only dispatch when the computer actually exposes a matching method;
        # unknown function names are acknowledged without doing anything.
        if hasattr(self.computer, name):
            getattr(self.computer, name)(**args)

        return [
            {
                "type": "function_call_output",
                "call_id": item["call_id"],
                "output": "success",  # hard-coded output for demo
            }
        ]

    def _handle_computer_call(self, item):
        """Execute a computer action and build its screenshot-bearing output item."""
        action = item["action"]
        action_type = action["type"]
        # Every key except "type" is passed to the computer method as a kwarg.
        action_args = {key: value for key, value in action.items() if key != "type"}
        if self.print_steps:
            print(f"{action_type}({action_args})")

        getattr(self.computer, action_type)(**action_args)

        screenshot_base64 = self.computer.screenshot()
        if self.show_images:
            show_image(screenshot_base64)

        # If the user doesn't acknowledge all safety checks, exit with an error.
        # NOTE(review): the action above has already been executed by the time
        # the checks are evaluated; confirm whether acknowledgement should gate
        # the action instead.
        pending_checks = item.get("pending_safety_checks", [])
        for check in pending_checks:
            message = check["message"]
            if not self.acknowledge_safety_check_callback(message):
                raise ValueError(
                    f"Safety check failed: {message}. Cannot continue with unacknowledged safety checks."
                )

        call_output = {
            "type": "computer_call_output",
            "call_id": item["call_id"],
            "acknowledged_safety_checks": pending_checks,
            "output": {
                "type": "input_image",
                "image_url": f"data:image/png;base64,{screenshot_base64}",
            },
        }

        # Additional URL safety checks for browser environments.
        if self.computer.get_environment() == "browser":
            current_url = self.computer.get_current_url()
            # Return value is ignored; presumably this raises for a
            # blocklisted URL -- confirm against utils.check_blocklisted_url.
            check_blocklisted_url(current_url)
            call_output["output"]["current_url"] = current_url

        return [call_output]

    def run_full_turn(
        self, input_items, print_steps=True, debug=False, show_images=False
    ):
        """Run the model/action loop until the last item has the assistant role.

        Args:
            input_items: Conversation items sent ahead of the items generated
                during this turn. The list itself is not modified.
            print_steps: Print messages and actions as they are handled.
            debug: Pretty-print each request and response.
            show_images: Display each screenshot via ``show_image``.

        Returns:
            The list of items produced during this turn (model output items
            interleaved with the call outputs generated by ``handle_item``).

        Raises:
            ValueError: If a response carries no ``"output"`` key, or if a
                safety check is not acknowledged.
        """
        self.print_steps = print_steps
        self.debug = debug
        self.show_images = show_images
        new_items = []

        # Keep looping until we get a final response.
        while not new_items or new_items[-1].get("role") != "assistant":
            conversation = input_items + new_items
            self.debug_print([sanitize_message(msg) for msg in conversation])

            response = create_response(
                model=self.model,
                input=conversation,
                tools=self.tools,
                truncation="auto",
            )
            self.debug_print(response)

            # Fail with a clear error whether or not debug is on; debug only
            # controls whether the offending response is printed first.
            if "output" not in response:
                if self.debug:
                    print(response)
                raise ValueError("No output from model")

            new_items += response["output"]
            for item in response["output"]:
                new_items += self.handle_item(item)

        return new_items