Dassl.pytorch/dassl/utils/tools.py (105 lines of code) (raw):

""" Modified from https://github.com/KaiyangZhou/deep-person-reid """ import os import sys import json import time import errno import numpy as np import random import os.path as osp import warnings from difflib import SequenceMatcher import PIL import torch from PIL import Image __all__ = [ "mkdir_if_missing", "check_isfile", "read_json", "write_json", "set_random_seed", "download_url", "read_image", "collect_env_info", "listdir_nohidden", "get_most_similar_str_to_a_from_b", "check_availability", "tolist_if_not", ] def mkdir_if_missing(dirname): """Create dirname if it is missing.""" if not osp.exists(dirname): try: os.makedirs(dirname) except OSError as e: if e.errno != errno.EEXIST: raise def check_isfile(fpath): """Check if the given path is a file. Args: fpath (str): file path. Returns: bool """ isfile = osp.isfile(fpath) if not isfile: warnings.warn('No file found at "{}"'.format(fpath)) return isfile def read_json(fpath): """Read json file from a path.""" with open(fpath, "r") as f: obj = json.load(f) return obj def write_json(obj, fpath): """Writes to a json file.""" mkdir_if_missing(osp.dirname(fpath)) with open(fpath, "w") as f: json.dump(obj, f, indent=4, separators=(",", ": ")) def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def download_url(url, dst): """Download file from a url to a destination. Args: url (str): url to download file. dst (str): destination path. """ from six.moves import urllib print('* url="{}"'.format(url)) print('* destination="{}"'.format(dst)) def _reporthook(count, block_size, total_size): global start_time if count == 0: start_time = time.time() return duration = time.time() - start_time progress_size = int(count * block_size) speed = int(progress_size / (1024*duration)) percent = int(count * block_size * 100 / total_size) sys.stdout.write( "\r...%d%%, %d MB, %d KB/s, %d seconds passed" % (percent, progress_size / (1024*1024), speed, duration) ) sys.stdout.flush() urllib.request.urlretrieve(url, dst, _reporthook) sys.stdout.write("\n") def read_image(path): """Read image from path using ``PIL.Image``. Args: path (str): path to an image. Returns: PIL image """ return Image.open(path).convert("RGB") def collect_env_info(): """Return env info as a string. Code source: github.com/facebookresearch/maskrcnn-benchmark """ from torch.utils.collect_env import get_pretty_env_info env_str = get_pretty_env_info() env_str += "\n Pillow ({})".format(PIL.__version__) return env_str def listdir_nohidden(path, sort=False): """List non-hidden items in a directory. Args: path (str): directory path. sort (bool): sort the items. """ items = [f for f in os.listdir(path) if not f.startswith(".")] if sort: items.sort() return items def get_most_similar_str_to_a_from_b(a, b): """Return the most similar string to a in b. Args: a (str): probe string. b (list): a list of candidate strings. """ highest_sim = 0 chosen = None for candidate in b: sim = SequenceMatcher(None, a, candidate).ratio() if sim >= highest_sim: highest_sim = sim chosen = candidate return chosen def check_availability(requested, available): """Check if an element is available in a list. Args: requested (str): probe string. available (list): a list of available strings. """ if requested not in available: psb_ans = get_most_similar_str_to_a_from_b(requested, available) raise ValueError( "The requested one is expected " "to belong to {}, but got [{}] " "(do you mean [{}]?)".format(available, requested, psb_ans) ) def tolist_if_not(x): """Convert to a list.""" if not isinstance(x, list): x = [x] return x