neuron_explainer/file_utils.py (63 lines of code) (raw):

"""Helpers for accessing local files and ``https://`` URLs through one interface."""

import io
import os
import shutil
import urllib.error
import urllib.request
from io import IOBase

import aiohttp


def file_exists(filepath: str) -> bool:
    """Return whether ``filepath`` exists.

    For ``https://`` paths the URL is opened and any ``HTTPError`` response is
    treated as "does not exist"; other errors raised by ``urlopen`` propagate
    to the caller. Every other path is treated as a local file and checked
    with ``os.path.exists``.
    """
    if filepath.startswith("https://"):
        try:
            # Only success/failure of the request matters, so close the
            # response immediately instead of leaking the connection.
            with urllib.request.urlopen(filepath):
                return True
        except urllib.error.HTTPError:
            return False
    # It's a local file.
    return os.path.exists(filepath)


class CustomFileHandler:
    """Context manager that opens either a local file or an ``https://`` URL.

    Remote files are downloaded fully into an in-memory buffer and may only be
    opened in ``"r"`` or ``"rb"`` mode. Local files are opened with the given
    mode, creating any missing parent directories first. ``az://`` paths are
    rejected.
    """

    def __init__(self, filepath: str, mode: str) -> None:
        """Remember the path and mode; nothing is opened until ``__enter__``."""
        self.filepath = filepath
        self.mode = mode
        # Set by __enter__; stays None until the context is entered.
        self.file = None

    def __enter__(self) -> IOBase:
        """Open the target and return a file-like object.

        Raises:
            AssertionError: if the path starts with ``az://``, or if a remote
                path is opened with a mode other than ``"r"`` / ``"rb"``.
        """
        assert not self.filepath.startswith("az://"), "Azure blob storage is not supported"
        if self.filepath.startswith("https://"):
            assert self.mode in ["r", "rb"], "Only read mode is supported for remote files"
            # Read everything up front and close the response, so the object
            # handed to the caller owns no network resources.
            with urllib.request.urlopen(self.filepath) as remote_data:
                payload = remote_data.read()
            if "b" in self.mode:
                # Binary mode: expose the raw bytes through a BytesIO object.
                self.file = io.BytesIO(payload)
            else:
                # Text mode: decode the content and expose it through StringIO.
                self.file = io.StringIO(payload.decode())
        else:
            # Create the subdirectories if they don't exist. A bare filename
            # has an empty dirname, and os.makedirs("") raises, so skip it.
            directory = os.path.dirname(self.filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)
            self.file = open(self.filepath, self.mode)
            if "b" in self.mode:
                # Ensure the file is seekable; if not, read into a BytesIO object.
                try:
                    self.file.seek(0)
                except io.UnsupportedOperation:
                    self.file.close()
                    with open(self.filepath, self.mode) as f:
                        self.file = io.BytesIO(f.read())
        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Close the opened file, if any, and let exceptions propagate."""
        # Close the file if it's open.
        if self.file is not None:
            self.file.close()
        # Returning False propagates any exception raised inside the block.
        return False


async def read_single_async(filepath: str) -> bytes:
    """Return the full contents of ``filepath`` as bytes.

    ``https://`` paths are fetched with an ``aiohttp`` client session; any
    other path is read from the local filesystem with a regular (blocking)
    ``open`` call.
    """
    if filepath.startswith("https://"):
        async with aiohttp.ClientSession() as session:
            async with session.get(filepath) as response:
                return await response.read()
    else:
        with open(filepath, "rb") as f:
            return f.read()


def copy_to_local_cache(src: str, dst: str) -> None:
    """Copy ``src`` (a local path or ``https://`` URL) to the local path ``dst``.

    Missing parent directories of ``dst`` are created. The data is streamed in
    chunks rather than loaded into memory in one piece, so large files can be
    copied without holding them entirely in RAM.
    """
    # A bare filename has an empty dirname; os.makedirs("") would raise.
    dst_dir = os.path.dirname(dst)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    if src.startswith("https://"):
        with urllib.request.urlopen(src) as response, open(dst, "wb") as out_file:
            shutil.copyfileobj(response, out_file)
    else:
        with open(src, "rb") as in_file, open(dst, "wb") as out_file:
            shutil.copyfileobj(in_file, out_file)