scripts/imagenet/utils.py (267 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. import time from functools import partial import torch import torch.distributed as dist import torch.optim as optim import torch.optim.lr_scheduler as lr_scheduler import torchvision.transforms as transforms from inplace_abn import ABN, InPlaceABN, InPlaceABNSync from .transforms import ColorJitter, Lighting def _get_norm_act(network_config): if network_config["bn_mode"] == "standard": assert network_config["activation"] in ( "relu", "leaky_relu", "elu", "identity", ), "Standard batch normalization is only compatible with relu, leaky_relu, elu and identity" activation_fn = partial( ABN, activation=network_config["activation"], activation_param=network_config["activation_param"], ) elif network_config["bn_mode"] == "inplace": assert network_config["activation"] in ( "leaky_relu", "elu", "identity", ), "Inplace batch normalization is only compatible with leaky_relu, elu and identity" activation_fn = partial( InPlaceABN, activation=network_config["activation"], activation_param=network_config["activation_param"], ) elif network_config["bn_mode"] == "sync": assert network_config["activation"] in ( "leaky_relu", "elu", "identity", ), "Sync batch normalization is only compatible with leaky_relu, elu and identity" activation_fn = partial( InPlaceABNSync, activation=network_config["activation"], activation_param=network_config["activation_param"], ) else: print("Unrecognized batch normalization mode", network_config["bn_mode"]) exit(1) return activation_fn def get_model_params(network_config): """Convert a configuration to actual model parameters Parameters ---------- network_config : dict Dictionary containing the configuration options for the network. Returns ------- model_params : dict Dictionary containing the actual parameters to be passed to the `net_*` functions in `models`. """ model_params = {} if network_config["input_3x3"] and not network_config["arch"].startswith("wider"): model_params["input_3x3"] = True model_params["norm_act"] = _get_norm_act(network_config) model_params["classes"] = network_config["classes"] if not network_config["arch"].startswith("wider"): model_params["dilation"] = network_config["dilation"] return model_params def create_optimizer(optimizer_config, model): """Creates optimizer and schedule from configuration Parameters ---------- optimizer_config : dict Dictionary containing the configuration options for the optimizer. model : Model The network model. Returns ------- optimizer : Optimizer The optimizer. scheduler : LRScheduler The learning rate scheduler. """ if optimizer_config["classifier_lr"] != -1: # Separate classifier parameters from all others net_params = [] classifier_params = [] for k, v in model.named_parameters(): if k.find("fc") != -1: classifier_params.append(v) else: net_params.append(v) params = [ {"params": net_params}, {"params": classifier_params, "lr": optimizer_config["classifier_lr"]}, ] else: params = model.parameters() if optimizer_config["type"] == "SGD": optimizer = optim.SGD( params, lr=optimizer_config["learning_rate"], momentum=optimizer_config["momentum"], weight_decay=optimizer_config["weight_decay"], nesterov=optimizer_config["nesterov"], ) elif optimizer_config["type"] == "Adam": optimizer = optim.Adam( params, lr=optimizer_config["learning_rate"], weight_decay=optimizer_config["weight_decay"], ) else: raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"])) if optimizer_config["schedule"]["type"] == "step": scheduler = lr_scheduler.StepLR( optimizer, **optimizer_config["schedule"]["params"] ) elif optimizer_config["schedule"]["type"] == "multistep": scheduler = lr_scheduler.MultiStepLR( optimizer, **optimizer_config["schedule"]["params"] ) elif optimizer_config["schedule"]["type"] == "exponential": scheduler = lr_scheduler.ExponentialLR( optimizer, **optimizer_config["schedule"]["params"] ) elif optimizer_config["schedule"]["type"] == "constant": scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0) elif optimizer_config["schedule"]["type"] == "linear": def linear_lr(it): return ( it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"] ) scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr) return optimizer, scheduler def create_transforms(input_config): """Create transforms from configuration Parameters ---------- input_config : dict Dictionary containing the configuration options for input pre-processing. Returns ------- train_transforms : list List of transforms to be applied to the input during training. val_transforms : list List of transforms to be applied to the input during validation. """ normalize = transforms.Normalize(mean=input_config["mean"], std=input_config["std"]) train_transforms = [] if input_config["scale_train"] != -1: train_transforms.append(transforms.Scale(input_config["scale_train"])) train_transforms += [ transforms.RandomResizedCrop(input_config["crop_train"]), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ] if input_config["color_jitter_train"]: train_transforms.append(ColorJitter()) if input_config["lighting_train"]: train_transforms.append(Lighting()) train_transforms.append(normalize) val_transforms = [] if input_config["scale_val"] != -1: val_transforms.append(transforms.Resize(input_config["scale_val"])) val_transforms += [ transforms.CenterCrop(input_config["crop_val"]), transforms.ToTensor(), normalize, ] return train_transforms, val_transforms def create_test_transforms(config, crop, scale, ten_crops): normalize = transforms.Normalize(mean=config["mean"], std=config["std"]) val_transforms = [] if scale != -1: val_transforms.append(transforms.Resize(scale)) if ten_crops: val_transforms += [ transforms.TenCrop(crop), transforms.Lambda( lambda crops: [transforms.ToTensor()(crop) for crop in crops] ), transforms.Lambda(lambda crops: [normalize(crop) for crop in crops]), transforms.Lambda(lambda crops: torch.stack(crops)), ] else: val_transforms += [ transforms.CenterCrop(crop), transforms.ToTensor(), normalize, ] return val_transforms class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def accuracy_sum(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" maxk = max(topk) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0)) return res def validate( val_loader, model, criterion, ten_crops=False, print_freq=1, it=None, tb=None, logger=print, ): batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to evaluate mode model.eval() end = time.time() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 do_print = rank == 0 def process(input, target, all_reduce=None): with torch.no_grad(): if ten_crops: bs, ncrops, c, h, w = input.size() input = input.view(-1, c, h, w) target = target.cuda(non_blocking=True) # compute output if ten_crops: output = model(input).view(bs, ncrops, -1).mean(1) else: output = model(input) loss = criterion(output, target) # measure accuracy and record loss prec1, prec5 = accuracy_sum(output.data, target, topk=(1, 5)) loss *= target.shape[0] count = target.new_tensor([target.shape[0]], dtype=torch.long) if all_reduce: all_reduce(count) for meter, val in (losses, loss), (top1, prec1), (top5, prec5): if all_reduce: all_reduce(val) val /= count.item() meter.update(val.item(), count.item()) # deal with remainder all_reduce = ( partial(dist.all_reduce, op=dist.ReduceOp.SUM) if dist.is_initialized() else None ) last_group_size = len(val_loader.dataset) % world_size for i, (input, target) in enumerate(val_loader): if input.shape[0] > 1 or last_group_size == 0: process(input, target, all_reduce) else: process( input, target, partial( dist.all_reduce, op=dist.ReduceOp.SUM, group=dist.new_group(range(last_group_size)), ), ) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if do_print and i % print_freq == 0: logger( "Test: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t" "Loss {loss.val:.4f} ({loss.avg:.4f}) \t" "Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t" "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format( i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5, ) ) if input.shape[0] == 1 and rank > last_group_size > 0: dist.new_group(range(last_group_size)) if do_print: logger( " * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}".format( top1=top1, top5=top5 ) ) if it is not None and (not dist.is_initialized() or dist.get_rank() == 0): tb.add_scalar("val/loss", losses.avg, it) tb.add_scalar("val/top1", top1.avg, it) tb.add_scalar("val/top5", top5.avg, it) return top1.avg