vision/m4/sourcing/pmd/helpers.py (190 lines of code) (raw):

import json import socket import ssl import urllib.error import urllib.parse import urllib.request import urllib.robotparser from dataclasses import fields from datetime import datetime from functools import lru_cache from multiprocessing.pool import ThreadPool from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image from datasets import Dataset from datasets.utils.file_utils import DownloadConfig from m4.sourcing.pmd.cache_path import cached_path def json_serializer(o): if isinstance(o, BaseException): return repr(o) if isinstance(o, datetime): return str(o) raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") def collapse_columns_to_meta( batch: Dict, columns_to_collapse: List[str], meta_column_name: str, ) -> Dict: # Order matters assert isinstance(columns_to_collapse, list) # {meta_column_name} needs to :be either inserted in # - not be a column_name of the dataset in question # - part of the columns we're going to collapse. assert meta_column_name not in batch or meta_column_name in columns_to_collapse # Aggregate all values into a single dict metas = [ json.dumps( {column_name: value for column_name, value in zip(columns_to_collapse, values)}, default=json_serializer, indent=2, ) for values in zip(*[batch[column_name] for column_name in columns_to_collapse]) ] # Remove columns from batch for column_name in columns_to_collapse: del batch[column_name] batch[meta_column_name] = metas return batch # ---- Download medias helper ---- class RobotsDisallow(BaseException): """Exception class when robots.txt prevents us from downloading the urls""" pass def compute_cache_path(cache_dir, hash): cached_path = Path(cache_dir) / hash[:3] / hash[3:6] / hash cached_path.parent.mkdir(parents=True, exist_ok=True) return str(cached_path.absolute()) # LRU cache that caches exception as well ... let's see if this works. def lru_cache_with_exception(maxsize=128, typed=False): _lru_cache = lru_cache(maxsize, typed) def wrap_exception(unwrapped_function): def func(*args, **kwargs): try: return unwrapped_function(*args, **kwargs), None except BaseException as err: return None, err return func def unwrap_exception(wrapped_function): def func(*args, **kwargs): result, exception = wrapped_function(*args, **kwargs) if exception is not None: raise exception else: return result return func def decorating_function(user_function): return unwrap_exception(_lru_cache(wrap_exception(user_function))) return decorating_function # https://stackoverflow.com/a/28052583 ssl._create_default_https_context = ssl._create_unverified_context class M4HTTPClient: def __init__(self, cache_dir: Path, retries: int, offline_mode: bool, user_agent: Optional[str] = None): super(M4HTTPClient, self).__init__() # Hack found: https://github.com/igorbrigadir/DownloadConceptualCaptions/blob/efb16f028936e6c628b6ee435765d6e1771b0f2d/download_data.py#L13 assert user_agent in ["Googlebot-Image/1.0", "Googlebot-Video/1.0", None] self.user_agent = user_agent self.datasets_download_config = DownloadConfig( cache_dir=cache_dir, user_agent=self.user_agent, num_proc=1, # We handle this via `.map` max_retries=retries, # TODO @thomasw21 Not really sure we care about versioning ... use_etag=False, ) self.offline_mode = offline_mode def check_robots_txt(self, url): parsed_url = urllib.parse.urlparse(url) robots_txt_url = f"{parsed_url.scheme}://{parsed_url.netloc}/robots.txt" robots_parser = self.__get_robots_txt__(robots_txt_url) return robots_parser.can_fetch(self.user_agent, url) # TODO @thomasw21: Maybe lru if we're scared of it being too big at some point. @lru_cache_with_exception(maxsize=None) def __get_robots_txt__(self, robots_txt_url): robots_parser = urllib.robotparser.RobotFileParser(robots_txt_url) # equivalent to `robots_parser.read()` but with a timeout try: f = urllib.request.urlopen(robots_parser.url, timeout=10) except urllib.error.HTTPError as err: if err.code in (401, 403): # robots.txt could not be queried to check robots_parser.allow_all = True elif err.code >= 400 and err.code < 500: robots_parser.allow_all = True except urllib.error.URLError as err: if isinstance(err.reason, socket.timeout): # We couldn't find robots.txt, we assume we can query the media. robots_parser.allow_all = True else: # unknown exception # print(robots_txt_url, err) raise err else: raw = f.read() robots_parser.parse(raw.decode("utf-8").splitlines()) return robots_parser def cache_path(self, url: str) -> Union[str, BaseException]: try: # Try querying the file locally first try: return cached_path( url, compute_cache_path=compute_cache_path, **{ field.name: getattr(self.datasets_download_config, field.name) for field in fields(self.datasets_download_config) }, local_files_only=True, ) except FileNotFoundError as e: # We ignore the exception when the file could not be found in the cache. In offline mode, we return the exception if self.offline_mode: return e except BaseException as e: # We ignore this exception as this will be caught down the line. However this should never happen ... if self.offline_mode: # In offline mode, we raise exception as it's something that shouldn't happen raise e # check if robots.txt allows us to download the the url if not self.check_robots_txt(url): return RobotsDisallow("Unable to query the url due to `robots.txt` restrictions") # Return file path or exception return cached_path( url, compute_cache_path=compute_cache_path, download_config=self.datasets_download_config ) except BaseException as e: return e def fetch_single_image( image_url: str, http_client: M4HTTPClient ) -> Union[Tuple[str, None], Tuple[None, BaseException]]: path_or_exception = http_client.cache_path(image_url) if isinstance(path_or_exception, str): path = path_or_exception try: # Check that it's an image with PIL.Image.open(path) as image: image.verify() return path, None except BaseException as exception: return None, exception else: exception = path_or_exception return None, exception def batch_iter(dset: Dataset, transform: Callable[[Dict], List[Any]], batch_size: int = 1000): num_rows = len(dset) for start in range(0, num_rows, batch_size): batch = dset[start : start + batch_size] for elt in transform(batch): yield elt class PickableMediaDownloadGenerator: def __init__( self, download_media_url: Callable[[str], Union[Tuple[str, None], Tuple[None, BaseException]]], get_media_urls: Callable[[Dict], List[str]], dset: Dataset, batch_size: int, num_threads_per_proc: int, ): self.download_media_url = download_media_url self.get_media_urls = get_media_urls self.dset = dset self.batch_size = batch_size self.num_threads_per_proc = num_threads_per_proc # This is used to trick the pickle algorithm as we load thread_pool and _media_iterator AFTER fingerprinting self._has_started = False self._thread_pool = None self._media_iterator = None def load_media_iterator(self): self._thread_pool = ThreadPool(self.num_threads_per_proc) self._media_iterator = self._thread_pool.imap( self.download_media_url, iterable=batch_iter(dset=self.dset, transform=self.get_media_urls, batch_size=self.batch_size), ) def __iter__(self): return self def __next__(self): if self._has_started is False: self.load_media_iterator() self._has_started = True return next(self._media_iterator) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if self._thread_pool is not None: self._has_started = False self._thread_pool.terminate() self._thread_pool.join() del self._thread_pool