utils/utils.py (90 lines of code) (raw):

import os import os.path as osp import numpy as np import matplotlib.pyplot as plt import random import torch import torch.nn as nn import torch.nn.functional as F def create_dir(_path): if not os.path.exists(_path): os.makedirs(_path) def set_random_seed(seed=None): if seed is not None: # raise NotImplementedError('Fixing seed has not yet been implemented.') print('Set random seed as', seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe # TODO: this leads to performance degradation, while the result can not be repeated # torch.backends.cudnn.benchmark = False # torch.backends.cudnn.deterministic = True os.environ['PYTHONHASHSEED'] = str(seed) torch.backends.cudnn.benchmark = True def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res class AverageMeter: """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.values = [] self.counter = 0 def append(self, val): self.values.append(val) self.counter += 1 @property def val(self): return self.values[-1] @property def avg(self): return sum(self.values) / len(self.values) @property def sum(self): return sum(self.values) @property def last_avg(self): if self.counter == 0: return self.latest_avg else: self.latest_avg = sum(self.values[-self.counter:]) / self.counter self.counter = 0 return self.latest_avg class TwoCropTransform: """Create two crops of the same image""" def __init__(self, transform): self.transform = transform def __call__(self, x): return [self.transform(x), self.transform(x)] def save_curve(args, save_dir, training_losses, test_clean_losses, overall_accs, many_accs, median_accs, low_accs, f1s): plt.figure() plt.plot(training_losses, 'b', label='training_losses') plt.plot(test_clean_losses, 'g', label='test_clean_losses') plt.grid() plt.legend() plt.savefig(osp.join(save_dir, 'losses.png')) plt.close() plt.plot(overall_accs, 'm', label='overall_accs') if args.imbalance_ratio < 1: plt.plot(many_accs, 'r', label='many_accs') plt.plot(median_accs, 'g', label='median_accs') plt.plot(low_accs, 'b', label='low_accs') plt.grid() plt.legend() plt.savefig(osp.join(save_dir, 'test_accs.png')) plt.close() plt.plot(f1s, 'm', label='f1s') plt.grid() plt.legend() plt.savefig(osp.join(save_dir, 'test_f1s.png')) plt.close() def is_parallel(model): # Returns True if model is of type DP or DDP return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) def de_parallel(model): # De-parallelize a model: returns single-GPU model if model is of type DP or DDP return model.module if is_parallel(model) else model