import io
import logging
import requests
from requests.adapters import HTTPAdapter
import onnx
from google.protobuf.internal.encoder import _VarintBytes
from urllib3.util.retry import Retry
from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format="%(message)s")

MIN_CHUNK_SIZE = 8 * 1024           # minimum bytes per Range request, to avoid many tiny requests
WANTED_TAGS    = {1, 2, 11, 12, 13} # GraphProto: node(1), name(2), input(11), output(12), value_info(13)

class LoggingRetry(Retry):
    def increment(self, method, url, response=None, error=None, _pool=None, _stacktrace=None):
        new_retry = super().increment(method, url, response, error, _pool, _stacktrace)
        attempt = len(new_retry.history)  # number of attempts recorded so far
        delay = new_retry.get_backoff_time()
        err_str = f" Error: {error}" if error else ""
        sleep_str = f" Sleeping for {delay:.2f} seconds." if delay > 0 else ""
        logging.info(f"Retrying {method} request to {url}. Attempt {attempt}.{err_str}{sleep_str}")
        return new_retry

class RangeFetcher:
    def __init__(self, url, retries=10, backoff_factor=2, timeout=10, progress_bar=None):
        self.url = url
        self.session = requests.Session()
        self.buffer = bytearray()
        self.loaded_ranges = []   # list of (start, end), inclusive
        self.total_downloaded = 0
        self.timeout = timeout

        retry_strategy = LoggingRetry(
            total=retries,
            backoff_factor=backoff_factor,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"]
        )
        # Surface urllib3's own retry log output alongside LoggingRetry's messages
        logging.getLogger("urllib3.util.retry").setLevel(logging.INFO)
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        self.progress_bar = progress_bar if progress_bar is not None else tqdm(desc="Downloading model", unit="B", unit_scale=True)

    def _add_range(self, start, data):
        """Insert `data` at offset `start` into self.buffer, merge loaded_ranges."""
        end = start + len(data) - 1
        if end >= len(self.buffer):
            self.buffer.extend(b'\x00' * (end + 1 - len(self.buffer)))
        self.buffer[start:end+1] = data

        # merge intervals
        new = (start, end)
        merged = []
        i = 0
        while i < len(self.loaded_ranges) and self.loaded_ranges[i][1] < new[0] - 1:
            merged.append(self.loaded_ranges[i]); i += 1
        while i < len(self.loaded_ranges) and self.loaded_ranges[i][0] <= new[1] + 1:
            new = (min(new[0], self.loaded_ranges[i][0]),
                   max(new[1], self.loaded_ranges[i][1]))
            i += 1
        merged.append(new)
        merged.extend(self.loaded_ranges[i:])
        self.loaded_ranges = merged
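        # For example, inserting 200 bytes at start=100 when loaded_ranges is
        # [(0, 99), (300, 399)] touches both neighbouring intervals, so they
        # all collapse into the single range [(0, 399)].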

    def fetch(self, start, end=None):
        """
        Ensure buffer[start..end] is loaded (inclusive).
        If end=None, do one fetch from `start` to EOF (no min-size).
        Otherwise, fetch only the missing sub-ranges, each >= MIN_CHUNK_SIZE.
        """
        if end is None:
            # Open-ended final fetch: read from `start` to EOF in one request.
            headers = {"Range": f"bytes={start}-"}
            self.progress_bar.set_description(f"Fetching bytes {start}-EOF")
            resp = self.session.get(self.url, headers=headers, timeout=self.timeout)
            if resp.status_code == 416:
                # `start` is at or past EOF, so there is nothing left to fetch.
                return b""
            resp.raise_for_status()
            chunk = resp.content
            self.total_downloaded += len(chunk)
            self.progress_bar.update(len(chunk))
            self._add_range(start, chunk)
            return self.buffer[start:start + len(chunk)]

        # find holes in [start..end]
        to_fetch = []
        cursor = start
        for (a, b) in self.loaded_ranges:
            if b < cursor: continue
            if a > end:   break
            if a > cursor:
                to_fetch.append((cursor, a-1))
            cursor = max(cursor, b+1)
        if cursor <= end:
            to_fetch.append((cursor, end))

        # fetch each hole (enforcing MIN_CHUNK_SIZE)
        for (s, e) in to_fetch:
            if (e - s + 1) < MIN_CHUNK_SIZE:
                e = s + MIN_CHUNK_SIZE - 1
            headers = {"Range": f"bytes={s}-{e}"}
            self.progress_bar.set_description(f"Fetching bytes {s}-{e}")
            resp = self.session.get(self.url, headers=headers, timeout=self.timeout)
            resp.raise_for_status()
            chunk = resp.content
            self.total_downloaded += len(chunk)
            self.progress_bar.update(len(chunk))
            self._add_range(s, chunk)

        return self.buffer[start:end+1]
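        # For example, with loaded_ranges == [(0, 8191)], fetch(8000, 8500)
        # finds the single hole (8192, 8500), widens it to (8192, 16383) to
        # satisfy MIN_CHUNK_SIZE, downloads that range, and returns
        # buffer[8000:8501].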

    def close(self):
        """Clean up resources to prevent memory leaks."""
        self.session.close()
        self.progress_bar.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

# Based on https://raw.githubusercontent.com/onnx/onnx/refs/heads/main/onnx/onnx.proto
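#
# Field numbers relied on below (abridged from onnx.proto):
#   ModelProto:  graph = 7 (length-delimited GraphProto)
#   GraphProto:  node = 1, name = 2, initializer = 5, doc_string = 10,
#                input = 11, output = 12, value_info = 13
# Each serialized field starts with a varint key (field_number << 3 | wire_type);
# wire type 2 (length-delimited) is followed by a varint payload length.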

def read_varint(stream):
    """Read a Base-128 varint from `stream`. Returns (value, bytes_read)."""
    result = 0; shift = 0; count = 0
    while True:
        b = stream.read(1)
        if not b: raise EOFError
        b = b[0]
        result |= (b & 0x7F) << shift
        count += 1
        if not (b & 0x80): break
        shift += 7
    return result, count
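
# Worked example: the varint b"\xac\x02" decodes to 300.
#   0xAC -> continuation bit set, low 7 bits = 0b0101100 (44)
#   0x02 -> continuation bit clear, next 7 bits = 0b0000010 (2)
#   value = 44 | (2 << 7) = 300, so
#   read_varint(io.BytesIO(b"\xac\x02")) == (300, 2)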

def skip_field(stream, wire_type):
    """Advance `stream` past the next field of the given `wire_type`. Returns skipped bytes."""
    if wire_type == 0:       # varint
        _, n = read_varint(stream); return n
    elif wire_type == 1:     # fixed64
        stream.seek(8, io.SEEK_CUR); return 8
    elif wire_type == 2:     # length-delimited
        length, n = read_varint(stream)
        stream.seek(length, io.SEEK_CUR)
        return n + length
    elif wire_type == 5:     # fixed32
        stream.seek(4, io.SEEK_CUR); return 4
    else:
        raise ValueError(f"Unsupported wire type {wire_type}")

def locate_graph_header(fetcher, probe_size=1*1024*1024):
    """
    Download the first `probe_size` bytes, scan for ModelProto.graph (field 7, wire 2).
    Returns (before_bytes, graph_payload_offset, graph_length).
    """
    data = fetcher.fetch(0, probe_size-1)
    stream = io.BytesIO(data)

    while True:
        key_pos = stream.tell()
        try:
            key, key_len = read_varint(stream)
        except EOFError:
            raise RuntimeError("Couldn't find graph header in initial bytes")
        field_num, wire_type = key >> 3, key & 0x7

        if field_num == 7 and wire_type == 2:
            length, length_len = read_varint(stream)
            header_len = key_len + length_len
            payload_off = key_pos + header_len
            return data[:key_pos], payload_off, length

        # not the graph field: skip it and keep scanning
        skip_field(stream, wire_type)

def extract_graph_structure(fetcher, graph_off, graph_len):
    """
    Walk the GraphProto payload [graph_off..graph_off+graph_len) and
    collect only fields whose tag is in WANTED_TAGS, skipping all others.
    Returns the concatenated bytes of just those fields, which is itself
    a valid (smaller) GraphProto serialization.
    """
    out = bytearray()
    pos = graph_off
    end = graph_off + graph_len

    while pos < end:
        # fetch minimal header (varint key + possible length prefix)
        hdr = fetcher.fetch(pos, pos + 19)
        hdr_stream = io.BytesIO(hdr)
        key, key_len = read_varint(hdr_stream)
        field_num, wire_type = key >> 3, key & 0x7

        length = 0; length_len = 0
        if wire_type == 2:
            length, length_len = read_varint(hdr_stream)

        # compute total bytes for this field
        if wire_type == 2:
            total = key_len + length_len + length
        elif wire_type == 0:
            total = key_len + skip_field(io.BytesIO(hdr[key_len:]), wire_type)
        elif wire_type == 1:
            total = key_len + 8
        elif wire_type == 5:
            total = key_len + 4
        else:
            raise ValueError(f"Bad wire type {wire_type}")

        # if it’s not one of the wanted tags, skip it entirely
        if field_num not in WANTED_TAGS:
            pos += total
            continue

        # otherwise fetch & append it in one go
        chunk = fetcher.fetch(pos, pos + total - 1)
        out.extend(chunk)
        pos += total

    return bytes(out)

def strip_data(graph, size_limit=1 * 1024 * 1024):
    # Remove initializers in the current graph.
    del graph.initializer[:]
    # Iterate over nodes to process any subgraphs.
    for node in graph.node:
        for attr in node.attribute:
            # If attribute holds a single subgraph.
            if attr.type == onnx.AttributeProto.GRAPH:
                strip_data(attr.g, size_limit)
            # If attribute holds multiple subgraphs.
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for subgraph in attr.graphs:
                    strip_data(subgraph, size_limit)

        if node.op_type == "Constant":
            # ONNX Constant nodes store their tensor under attribute 'value'
            for attr in node.attribute:
                if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                    tp = attr.t
                    # ByteSize() accounts for raw_data as well as the typed *_data fields
                    data_size = tp.ByteSize()
                    if data_size > size_limit:
                        # Remove all data fields from TensorProto
                        tp.ClearField("raw_data")
                        tp.ClearField("float_data")
                        tp.ClearField("int32_data")
                        tp.ClearField("int64_data")
                        tp.ClearField("double_data")
                        tp.ClearField("uint64_data")
                        tp.ClearField("string_data")
                        # dims, data_type, and name are kept intact

def stream_parse_model_header(url):
    # Create a progress bar at position=1 so it can nest under an outer bar
    progress = tqdm(desc="Downloading model", unit="B", unit_scale=True, position=1)
    with RangeFetcher(url, progress_bar=progress) as fetcher:
        # 1) find graph payload
        before, graph_off, graph_len = locate_graph_header(fetcher)
        # 2) stream-parse only the wanted GraphProto fields (node, name, input, output, value_info)
        graph_struct = extract_graph_structure(fetcher, graph_off, graph_len)
        # 3) fetch any ModelProto fields after graph (small headers, metadata)
        after = fetcher.fetch(graph_off + graph_len, None)
    # 4) rebuild minimal ModelProto
    tag = _VarintBytes((7 << 3) | 2)
    length = _VarintBytes(len(graph_struct))
    model_bytes = before + tag + length + graph_struct + after
    # 5) parse with ONNX
    model = onnx.ModelProto()
    model.ParseFromString(bytes(model_bytes))
    strip_data(model.graph)
    return model
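
# Minimal usage sketch: the URL below is a placeholder; it should point at an
# .onnx file served over HTTP(S) by a server that honors Range requests.
if __name__ == "__main__":
    example_url = "https://example.com/path/to/model.onnx"
    model = stream_parse_model_header(example_url)
    g = model.graph
    print(f"Graph '{g.name}': {len(g.node)} nodes, "
          f"{len(g.input)} inputs, {len(g.output)} outputs")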
