scripts/parser.py (202 lines of code) (raw):

import io import logging import requests from requests.adapters import HTTPAdapter import onnx from google.protobuf.internal.encoder import _VarintBytes from urllib3.util.retry import Retry from tqdm import tqdm logging.basicConfig(level=logging.INFO, format="%(message)s") MIN_CHUNK_SIZE = 8 * 1024 # 8 KB minimum per Range request to batch headers WANTED_TAGS = {1, 2, 11, 12, 13} # GraphProto: node(1), name(2), input(11), output(12), value metadata(13) class LoggingRetry(Retry): def increment(self, method, url, response=None, error=None, _pool=None, _stacktrace=None): new_retry = super().increment(method, url, response, error, _pool, _stacktrace) attempt = (self.total - new_retry.total + 1) if self.total is not None else 'unknown' delay = new_retry.get_backoff_time() err_str = f" Error: {error}" if error else "" sleep_str = f" Sleeping for {delay:.2f} seconds." if delay > 0 else "" logging.info(f"Retrying {method} request to {url}. Attempt {attempt}.{err_str}{sleep_str}") return new_retry class RangeFetcher: def __init__(self, url, retries=10, backoff_factor=2, timeout=10, progress_bar=None): self.url = url self.session = requests.Session() self.buffer = bytearray() self.loaded_ranges = [] # list of (start, end), inclusive self.total_downloaded = 0 self.timeout = timeout retry_strategy = LoggingRetry( total=retries, backoff_factor=backoff_factor, status_forcelist=[429, 500, 502, 503, 504], allowed_methods=["HEAD", "GET", "OPTIONS"] ) retry_strategy.logger = logging.getLogger("urllib3.retry") retry_strategy.logger.setLevel(logging.INFO) adapter = HTTPAdapter(max_retries=retry_strategy) self.session.mount("http://", adapter) self.session.mount("https://", adapter) self.progress_bar = progress_bar if progress_bar is not None else tqdm(desc="Downloading model", unit="B", unit_scale=True) def _add_range(self, start, data): """Insert `data` at offset `start` into self.buffer, merge loaded_ranges.""" end = start + len(data) - 1 if end >= len(self.buffer): self.buffer.extend(b'\x00' * (end + 1 - len(self.buffer))) self.buffer[start:end+1] = data # merge intervals new = (start, end) merged = [] i = 0 while i < len(self.loaded_ranges) and self.loaded_ranges[i][1] < new[0] - 1: merged.append(self.loaded_ranges[i]); i += 1 while i < len(self.loaded_ranges) and self.loaded_ranges[i][0] <= new[1] + 1: new = (min(new[0], self.loaded_ranges[i][0]), max(new[1], self.loaded_ranges[i][1])) i += 1 merged.append(new) merged.extend(self.loaded_ranges[i:]) self.loaded_ranges = merged def fetch(self, start, end=None): """ Ensure buffer[start..end] is loaded (inclusive). If end=None, do one fetch from `start` to EOF (no min-size). Otherwise, fetch only the missing sub-ranges, each >= MIN_CHUNK_SIZE. """ if end is None: # open‐ended final fetch headers = {"Range": f"bytes={start}-"} self.progress_bar.set_description(f"Fetching bytes {start}-EOF") resp = self.session.get(self.url, headers=headers); resp.raise_for_status() chunk = resp.content self.total_downloaded += len(chunk) self.progress_bar.update(len(chunk)) self._add_range(start, chunk) return self.buffer[start:start+len(chunk)] # find holes in [start..end] to_fetch = [] cursor = start for (a, b) in self.loaded_ranges: if b < cursor: continue if a > end: break if a > cursor: to_fetch.append((cursor, a-1)) cursor = max(cursor, b+1) if cursor <= end: to_fetch.append((cursor, end)) # fetch each hole (enforcing MIN_CHUNK_SIZE) for (s, e) in to_fetch: if (e - s + 1) < MIN_CHUNK_SIZE: e = s + MIN_CHUNK_SIZE - 1 headers = {"Range": f"bytes={s}-{e}"} self.progress_bar.set_description(f"Fetching bytes {s}-{e}") resp = self.session.get(self.url, headers=headers); resp.raise_for_status() chunk = resp.content self.total_downloaded += len(chunk) self.progress_bar.update(len(chunk)) self._add_range(s, chunk) return self.buffer[start:end+1] def close(self): """Clean up resources to prevent memory leaks.""" self.session.close() self.progress_bar.close() def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() # Based on https://raw.githubusercontent.com/onnx/onnx/refs/heads/main/onnx/onnx.proto def read_varint(stream): """Read a Base-128 varint from `stream`. Returns (value, bytes_read).""" result = 0; shift = 0; count = 0 while True: b = stream.read(1) if not b: raise EOFError b = b[0] result |= (b & 0x7F) << shift count += 1 if not (b & 0x80): break shift += 7 return result, count def skip_field(stream, wire_type): """Advance `stream` past the next field of the given `wire_type`. Returns skipped bytes.""" if wire_type == 0: # varint _, n = read_varint(stream); return n elif wire_type == 1: # fixed64 stream.seek(8, io.SEEK_CUR); return 8 elif wire_type == 2: # length-delimited length, n = read_varint(stream) stream.seek(length, io.SEEK_CUR) return n + length elif wire_type == 5: # fixed32 stream.seek(4, io.SEEK_CUR); return 4 else: raise ValueError(f"Unsupported wire type {wire_type}") def locate_graph_header(fetcher, probe_size=1*1024*1024): """ Download the first `probe_size` bytes, scan for ModelProto.graph (field 7, wire 2). Returns (before_bytes, graph_payload_offset, graph_length). """ data = fetcher.fetch(0, probe_size-1) stream = io.BytesIO(data) before = bytearray() while True: key_pos = stream.tell() try: key, key_len = read_varint(stream) except EOFError: raise RuntimeError("Couldn't find graph header in initial bytes") field_num, wire_type = key >> 3, key & 0x7 if field_num == 7 and wire_type == 2: length, length_len = read_varint(stream) header_len = key_len + length_len payload_off = key_pos + header_len return data[:key_pos], payload_off, length # skip and accumulate everything before skipped = skip_field(stream, wire_type) stream.seek(key_pos + key_len + skipped) before.extend(data[key_pos:key_pos + key_len + skipped]) def extract_graph_structure(fetcher, graph_off, graph_len): """ Walk the GraphProto payload [graph_off..graph_off+graph_len) and collect only fields whose tag in WANTED_TAGS, skipping all others. Returns the concatenated bytes of just those wanted fields. """ out = bytearray() pos = graph_off end = graph_off + graph_len while pos < end: # fetch minimal header (varint key + possible length prefix) hdr = fetcher.fetch(pos, pos + 19) hdr_stream = io.BytesIO(hdr) key, key_len = read_varint(hdr_stream) field_num, wire_type = key >> 3, key & 0x7 length = 0; length_len = 0 if wire_type == 2: length, length_len = read_varint(hdr_stream) # compute total bytes for this field if wire_type == 2: total = key_len + length_len + length elif wire_type == 0: total = key_len + skip_field(io.BytesIO(hdr[key_len:]), wire_type) elif wire_type == 1: total = key_len + 8 elif wire_type == 5: total = key_len + 4 else: raise ValueError(f"Bad wire type {wire_type}") # if it’s not one of the wanted tags, skip it entirely if field_num not in WANTED_TAGS: pos += total continue # otherwise fetch & append it in one go chunk = fetcher.fetch(pos, pos + total - 1) out.extend(chunk) pos += total return bytes(out) def strip_data(graph, size_limit=1 * 1024 * 1024): # Remove initializers in the current graph. del graph.initializer[:] # Iterate over nodes to process any subgraphs. for node in graph.node: for attr in node.attribute: # If attribute holds a single subgraph. if attr.type == onnx.AttributeProto.GRAPH: strip_data(attr.g) # If attribute holds multiple subgraphs. elif attr.type == onnx.AttributeProto.GRAPHS: for subgraph in attr.graphs: strip_data(subgraph) if node.op_type == "Constant": # ONNX Constant nodes store their tensor under attribute 'value' for attr in node.attribute: if attr.name == "value" and attr.t is not None: tp = attr.t data_size = len(tp.raw_data) if tp.raw_data else 0 if data_size > size_limit: # Remove all data fields from TensorProto tp.ClearField("raw_data") tp.ClearField("float_data") tp.ClearField("int32_data") tp.ClearField("int64_data") tp.ClearField("double_data") tp.ClearField("uint64_data") tp.ClearField("string_data") # dims, data_type, and name are kept intact def stream_parse_model_header(url): # Create a new nested progress bar (position=1 for enabling nested bars) progress = tqdm(desc="Downloading model", unit="B", unit_scale=True, position=1) with RangeFetcher(url, progress_bar=progress) as fetcher: # 1) find graph payload before, graph_off, graph_len = locate_graph_header(fetcher) # 2) stream‐parse only node/name/input/output fields graph_struct = extract_graph_structure(fetcher, graph_off, graph_len) # 3) fetch any ModelProto fields after graph (small headers, metadata) after = fetcher.fetch(graph_off + graph_len, None) # 4) rebuild minimal ModelProto tag = _VarintBytes((7 << 3) | 2) length = _VarintBytes(len(graph_struct)) model_bytes = before + tag + length + graph_struct + after # 5) parse with ONNX model = onnx.ModelProto() model.ParseFromString(model_bytes) strip_data(model.graph) return model