evals/elsuite/multistep_web_tasks/webarena/browser_env/processors.py (495 lines of code) (raw):

import logging import re from collections import defaultdict from typing import Any, Optional, TypedDict import numpy as np import numpy.typing as npt from beartype import beartype from playwright.sync_api import CDPSession, Page, ViewportSize from evals.elsuite.multistep_web_tasks.webarena.browser_env.browser_utils import ( AccessibilityTree, AccTreeBrowserObservation, BrowserObservation, BrowserState, BrowserWindowConfig, Observation, ) from evals.elsuite.multistep_web_tasks.webarena.browser_env.constants import ( IGNORED_ACTREE_PROPERTIES, ) from evals.elsuite.multistep_web_tasks.webarena.core.playwright_api import ( ClientForwarder, PageForwarder, ) logger = logging.getLogger(__name__) class ObservationProcessor: def process(self, page: Page, client: CDPSession) -> Observation: raise NotImplementedError class ObservationMetadata(TypedDict): obs_nodes_info: dict[str, Any] def create_empty_metadata() -> ObservationMetadata: return { "obs_nodes_info": {}, } class TextObervationProcessor(ObservationProcessor): def __init__( self, observation_type: str, current_viewport_only: bool, viewport_size: ViewportSize, ): self.observation_type = observation_type self.current_viewport_only = current_viewport_only self.viewport_size = viewport_size self.observation_tag = "text" self.meta_data = create_empty_metadata() # use the store meta data of this observation type @beartype def fetch_browser_info( self, page: Page, client: CDPSession, ) -> BrowserState: # extract domtree tree = client.send( "DOMSnapshot.captureSnapshot", { "computedStyles": [], "includeDOMRects": True, "includePaintOrder": True, }, ) # calibrate the bounds, in some cases, the bounds are scaled somehow bounds = tree["documents"][0]["layout"]["bounds"] b = bounds[0] n = b[2] / self.viewport_size["width"] bounds = [[x / n for x in bound] for bound in bounds] tree["documents"][0]["layout"]["bounds"] = bounds # add union bound placeholder tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds] # extract browser info win_upper_bound = page.evaluate("window.pageYOffset") win_left_bound = page.evaluate("window.pageXOffset") win_width = page.evaluate("window.screen.width") win_height = page.evaluate("window.screen.height") win_right_bound = win_left_bound + win_width win_lower_bound = win_upper_bound + win_height device_pixel_ratio = page.evaluate("window.devicePixelRatio") assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0" config: BrowserWindowConfig = { "win_upper_bound": win_upper_bound, "win_left_bound": win_left_bound, "win_width": win_width, "win_height": win_height, "win_right_bound": win_right_bound, "win_lower_bound": win_lower_bound, "device_pixel_ratio": device_pixel_ratio, } # assert len(tree['documents']) == 1, "More than one document in the DOM tree" info: BrowserState = BrowserState({"DOMTree": tree, "config": config}) return info @beartype @staticmethod def partially_in_viewport(bound: list[float], config: BrowserWindowConfig) -> bool: [x, y, width, height] = bound elem_left_bound = x elem_top_bound = y elem_right_bound = x + width elem_lower_bound = y + height ok = ( elem_left_bound < config["win_right_bound"] and elem_right_bound >= config["win_left_bound"] and elem_top_bound < config["win_lower_bound"] and elem_lower_bound >= config["win_upper_bound"] ) return ok @beartype def retrieve_viewport_info(self, info: BrowserState) -> None: """Add viewport related information to the DOMTree 1. add union bound, which is a union of all the bounds of the nodes in the subtree This is only used when current_viewport_only is enabled since it is quite slow TODO[robert1003]: improve """ tree = info["DOMTree"] document = tree["documents"][0] nodes = document["nodes"] parent = nodes["parentIndex"] node_names = nodes["nodeName"] layout = document["layout"] layout_node_cursor = layout["nodeIndex"] bounds = layout["bounds"] graph = defaultdict(lambda: []) assert len(node_names) == len(parent) for node_idx in range(len(node_names)): parent_idx = parent[node_idx] if parent_idx != -1: graph[parent_idx].append(node_idx) union_bounds: list[Optional[list[float]]] = [None for _ in bounds] def valid_bbox(bound: Optional[list[float]]) -> bool: if bound is None: return False # no width or height if np.isclose(bound[2], 0): return False if np.isclose(bound[3], 0): return False return True def add_union_bound(idx: int) -> Optional[list[float]]: if idx in layout_node_cursor: cursor = layout_node_cursor.index(idx) node_bound = bounds[cursor].copy() tree_bounds: list[Any] = [node_bound] for child_idx in graph[idx]: child_bound = add_union_bound(child_idx) tree_bounds.append(child_bound.copy() if child_bound else None) tree_bounds = [b for b in tree_bounds if valid_bbox(b)] # convert to absolute coordinates for i in range(len(tree_bounds)): tree_bounds[i][2] = tree_bounds[i][0] + tree_bounds[i][2] tree_bounds[i][3] = tree_bounds[i][1] + tree_bounds[i][3] if len(tree_bounds) == 0: assert not valid_bbox(node_bound) node_union_bound = [0.0, 0.0, 0.0, 0.0] else: left_bound = min([b[0] for b in tree_bounds]) top_bound = min([b[1] for b in tree_bounds]) right_bound = max([b[2] for b in tree_bounds]) bottom_bound = max([b[3] for b in tree_bounds]) node_union_bound = [ left_bound, top_bound, right_bound - left_bound, bottom_bound - top_bound, ] # update the list union_bounds[cursor] = node_union_bound else: node_union_bound = None return node_union_bound add_union_bound(0) info["DOMTree"]["documents"][0]["layout"]["unionBounds"] = union_bounds @beartype def current_viewport_html(self, info: BrowserState) -> str: # adopted from [natbot](https://github.com/nat/natbot) tree = info["DOMTree"] strings = tree["strings"] document = tree["documents"][0] nodes = document["nodes"] attributes = nodes["attributes"] node_value = nodes["nodeValue"] parent = nodes["parentIndex"] node_names = nodes["nodeName"] layout = document["layout"] layout_node_cursor = layout["nodeIndex"] union_bounds = layout["unionBounds"] graph = defaultdict(lambda: []) for node_idx in range(len(node_names)): parent_idx = parent[node_idx] if parent_idx != -1: graph[parent_idx].append(node_idx) def dfs(idx: int) -> str: node_name = strings[node_names[idx]].lower().strip() can_skip = "#" in node_name or "::" in node_name inner_text = "" node_value_idx = node_value[idx] if node_value_idx >= 0 and node_value_idx < len(strings): inner_text = " ".join(strings[node_value_idx].split()) node_attributes = [strings[i] for i in attributes[idx]] node_attributes_str = "" for i in range(0, len(node_attributes), 2): a = node_attributes[i] b = node_attributes[i + 1] b = " ".join(b.split()) node_attributes_str += f'{a}="{b}" ' node_attributes_str = node_attributes_str.strip() html = "" if not can_skip: html += f"<{node_name}" if {node_attributes_str}: html += f" {node_attributes_str}" html += f">{inner_text}" else: html += f"{inner_text}" for child_idx in graph[idx]: if child_idx in layout_node_cursor: cursor = layout_node_cursor.index(child_idx) union_bound = union_bounds[cursor] if not self.partially_in_viewport(union_bound, info["config"]): continue html += dfs(child_idx) if not can_skip: html += f"</{node_name}>" return html html = dfs(0) return html @beartype def fetch_page_accessibility_tree( self, info: BrowserState, client: ClientForwarder ) -> AccessibilityTree: accessibility_tree: AccessibilityTree = client.send("Accessibility.getFullAXTree", {})[ "nodes" ] # a few nodes are repeated in the accessibility tree seen_ids = set() _accessibility_tree = [] for node in accessibility_tree: if node["nodeId"] not in seen_ids: _accessibility_tree.append(node) seen_ids.add(node["nodeId"]) accessibility_tree = _accessibility_tree # add the bounding box of each node tree = info["DOMTree"] document = tree["documents"][0] nodes = document["nodes"] backend_node_id = nodes["backendNodeId"] node_names = nodes["nodeName"] layout = document["layout"] layout_node_cursor = layout["nodeIndex"] bounds = layout["bounds"] union_bounds = layout["unionBounds"] offsetrect_bounds = layout["offsetRects"] backend_id_to_bound = {} # get the mapping between backend node id and bounding box for idx in range(len(node_names)): if idx not in layout_node_cursor: continue cursor = layout_node_cursor.index(idx) node_bound = bounds[cursor] node_union_bound = union_bounds[cursor] node_offsetrect_bound = offsetrect_bounds[cursor] node_backend_id = backend_node_id[idx] backend_id_to_bound[node_backend_id] = [ node_bound, node_union_bound, node_offsetrect_bound, ] parent_graph: dict[str, str] = {} refine_node_ids: list[str] = [] for node in accessibility_tree: if "parentId" in node: parent_graph[node["nodeId"]] = node["parentId"] if "backendDOMNodeId" not in node: node["bound"] = None node["union_bound"] = None node["offsetrect_bound"] = None elif node["backendDOMNodeId"] not in backend_id_to_bound: refine_node_ids.append(node["nodeId"]) else: node["bound"] = backend_id_to_bound[node["backendDOMNodeId"]][0] node["union_bound"] = backend_id_to_bound[node["backendDOMNodeId"]][1] node["offsetrect_bound"] = backend_id_to_bound[node["backendDOMNodeId"]][2] # refine the bounding box for nodes which only appear in the accessibility tree node_ids = [node["nodeId"] for node in accessibility_tree] for refine_node_id in refine_node_ids: child_id = refine_node_id parent_idx: Optional[int] = None while child_id in parent_graph: parent_id = parent_graph[child_id] parent_idx = node_ids.index(parent_id) child_id = parent_id if accessibility_tree[parent_idx]["union_bound"] is not None: break refine_node_idx = node_ids.index(refine_node_id) if parent_idx is not None: accessibility_tree[refine_node_idx]["bound"] = accessibility_tree[parent_idx][ "bound" ] accessibility_tree[refine_node_idx]["union_bound"] = accessibility_tree[parent_idx][ "union_bound" ] accessibility_tree[refine_node_idx]["offsetrect_bound"] = accessibility_tree[ parent_idx ]["offsetrect_bound"] else: accessibility_tree[refine_node_idx]["bound"] = None accessibility_tree[refine_node_idx]["union_bound"] = None accessibility_tree[refine_node_idx]["offsetrect_bound"] = None return accessibility_tree @beartype def current_viewport_accessibility_tree( self, info: BrowserState, accessibility_tree: AccessibilityTree, ) -> AccessibilityTree: config = info["config"] subtree = [] for node in accessibility_tree: if not node["union_bound"]: continue [x, y, width, height] = node["union_bound"] elem_left_bound = x elem_top_bound = y elem_right_bound = x + width elem_lower_bound = y + height ok = ( elem_left_bound < config["win_right_bound"] and elem_right_bound >= config["win_left_bound"] and elem_top_bound < config["win_lower_bound"] and elem_lower_bound >= config["win_upper_bound"] ) if ok: subtree.append(node) return subtree @beartype @staticmethod def parse_accessibility_tree( accessibility_tree: AccessibilityTree, ) -> tuple[str, dict[str, Any]]: """Parse the accessibility tree into a string text""" node_id_to_idx = {} for idx, node in enumerate(accessibility_tree): node_id_to_idx[node["nodeId"]] = idx obs_nodes_info = {} def dfs(idx: int, obs_node_id: str, depth: int) -> str: tree_str = "" node = accessibility_tree[idx] indent = "\t" * depth valid_node = True try: role = node["role"]["value"] name = node["name"]["value"] node_str = f"[{obs_node_id}] {role} {repr(name)}" properties = [] for property in node.get("properties", []): try: if property["name"] in IGNORED_ACTREE_PROPERTIES: continue properties.append(f'{property["name"]}: {property["value"]["value"]}') except KeyError: pass if properties: node_str += " " + " ".join(properties) # check valid if not node_str.strip(): valid_node = False # empty generic node if not name.strip(): if not properties: if role in [ "generic", "img", "list", "strong", "paragraph", "banner", "navigation", "Section", "LabelText", "Legend", "listitem", ]: valid_node = False elif role in ["listitem"]: valid_node = False if valid_node: tree_str += f"{indent}{node_str}" obs_nodes_info[obs_node_id] = { "backend_id": node["backendDOMNodeId"], "bound": node["bound"], "union_bound": node["union_bound"], "offsetrect_bound": node["offsetrect_bound"], "text": node_str, } except Exception: valid_node = False for _, child_node_id in enumerate(node["childIds"]): if child_node_id not in node_id_to_idx: continue # mark this to save some tokens child_depth = depth + 1 if valid_node else depth child_str = dfs(node_id_to_idx[child_node_id], child_node_id, child_depth) if child_str.strip(): if tree_str.strip(): tree_str += "\n" tree_str += child_str return tree_str if len(accessibility_tree) == 0: logger.warning("Empty accessibility tree") return "", obs_nodes_info else: tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0) return tree_str, obs_nodes_info @beartype @staticmethod def clean_accesibility_tree(tree_str: str) -> str: """further clean accesibility tree""" clean_lines: list[str] = [] for line in tree_str.split("\n"): if "statictext" in line.lower(): prev_lines = clean_lines[-3:] pattern = r"\[\d+\] StaticText '([^']+)'" match = re.search(pattern, line) if match: static_text = match.group(1) if all(static_text not in prev_line for prev_line in prev_lines): clean_lines.append(line) else: clean_lines.append(line) return "\n".join(clean_lines) @beartype def process(self, page: PageForwarder, client: ClientForwarder) -> dict[str, str]: # get the tab info tab_title_str = page.title() # TODO: support multiple tabs, e.g. something like: # open_tabs = page.context.pages # try: # tab_titles = [tab.title() for tab in open_tabs] # current_tab_idx = open_tabs.index(page) # for idx in range(len(open_tabs)): # if idx == current_tab_idx: # tab_titles[idx] = f"Tab {idx} (current): {open_tabs[idx].title()}" # else: # tab_titles[idx] = f"Tab {idx}: {open_tabs[idx].title()}" # tab_title_str = " | ".join(tab_titles) # except Exception: # tab_title_str = " | ".join(["Tab {idx}" for idx in range(len(open_tabs))]) try: browser_info = page.fetch_browser_info() except Exception: page.wait_for_load_state("load", timeout=500) browser_info = page.fetch_browser_info() if self.current_viewport_only: self.retrieve_viewport_info(browser_info) # get html content if self.current_viewport_only: html = self.current_viewport_html(browser_info) html_content = html else: html_content = page.content() # get acctree content accessibility_tree = self.fetch_page_accessibility_tree(browser_info, client) if self.current_viewport_only: accessibility_tree = self.current_viewport_accessibility_tree( browser_info, accessibility_tree ) acctree_content, obs_nodes_info = self.parse_accessibility_tree(accessibility_tree) acctree_content = self.clean_accesibility_tree(acctree_content) self.obs_nodes_info = obs_nodes_info self.meta_data["obs_nodes_info"] = obs_nodes_info self.browser_config = browser_info["config"] html_content = f"{tab_title_str}\n\n{html_content}" acctree_content = f"{tab_title_str}\n\n{acctree_content}" return {"html": html_content, "acctree": acctree_content} @beartype def get_element_center(self, element_id: str) -> tuple[float, float]: node_info = self.obs_nodes_info[element_id] node_bound = node_info["bound"] x, y, width, height = node_bound browser_config = self.browser_config b_x, b_y = ( browser_config["win_left_bound"], browser_config["win_upper_bound"], ) center_x = (x - b_x) + width / 2 center_y = (y - b_y) + height / 2 return ( center_x / self.viewport_size["width"], center_y / self.viewport_size["height"], ) class ImageObservationProcessor(ObservationProcessor): def __init__(self, observation_type: str): self.observation_type = observation_type self.observation_tag = "image" self.meta_data = create_empty_metadata() def process(self, page: PageForwarder, client: ClientForwarder) -> npt.NDArray[np.uint8]: raise NotImplementedError("TODO: Images with flask-playwright api") class ObservationHandler: """Main entry point to access all observation processor""" def __init__( self, main_observation_type: str, text_observation_type: str, image_observation_type: str, current_viewport_only: bool, viewport_size: ViewportSize, ) -> None: self.main_observation_type = main_observation_type self.text_processor = TextObervationProcessor( text_observation_type, current_viewport_only, viewport_size ) self.image_processor = ImageObservationProcessor(image_observation_type) self.viewport_size = viewport_size @beartype def get_observation_space(self) -> type[BrowserObservation]: return BrowserObservation @beartype def get_observation(self, page: PageForwarder, client: ClientForwarder) -> BrowserObservation: obs_dict = self.text_processor.process(page, client) # NOTE: no image obs with PageForwarder yet # image_obs = self.image_processor.process(page, client) image_obs = None # TODO: stop hardcoding AccTree here obs = AccTreeBrowserObservation( html=obs_dict["html"], acctree=obs_dict["acctree"], image=image_obs ) return obs @beartype def get_observation_metadata(self) -> dict[str, ObservationMetadata]: return { "text": self.text_processor.meta_data, "image": self.image_processor.meta_data, } @property def action_processor(self) -> ObservationProcessor: """Return the main processor that is associated with the action space""" if self.main_observation_type == "text": return self.text_processor elif self.main_observation_type == "image": return self.image_processor else: raise ValueError("Invalid main observation type")