scripts/imagenet/utils.py
# Copyright (c) Facebook, Inc. and its affiliates.

import time
from functools import partial

import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
from inplace_abn import ABN, InPlaceABN, InPlaceABNSync

from .transforms import ColorJitter, Lighting


def _get_norm_act(network_config):
if network_config["bn_mode"] == "standard":
assert network_config["activation"] in (
"relu",
"leaky_relu",
"elu",
"identity",
), "Standard batch normalization is only compatible with relu, leaky_relu, elu and identity"
activation_fn = partial(
ABN,
activation=network_config["activation"],
activation_param=network_config["activation_param"],
)
elif network_config["bn_mode"] == "inplace":
assert network_config["activation"] in (
"leaky_relu",
"elu",
"identity",
), "Inplace batch normalization is only compatible with leaky_relu, elu and identity"
activation_fn = partial(
InPlaceABN,
activation=network_config["activation"],
activation_param=network_config["activation_param"],
)
elif network_config["bn_mode"] == "sync":
assert network_config["activation"] in (
"leaky_relu",
"elu",
"identity",
), "Sync batch normalization is only compatible with leaky_relu, elu and identity"
activation_fn = partial(
InPlaceABNSync,
activation=network_config["activation"],
activation_param=network_config["activation_param"],
)
    else:
        raise ValueError(
            "Unrecognized batch normalization mode: {}".format(network_config["bn_mode"])
        )

    return activation_fn
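
# A minimal illustration of the `network_config` fields consumed by `_get_norm_act`
# above; the concrete values are assumptions, not taken from the shipped configs:
#
#     norm_act = _get_norm_act(
#         {"bn_mode": "inplace", "activation": "leaky_relu", "activation_param": 0.01}
#     )
#     layer = norm_act(256)  # fused BatchNorm + activation over 256 channels
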
def get_model_params(network_config):
"""Convert a configuration to actual model parameters
Parameters
----------
network_config : dict
Dictionary containing the configuration options for the network.
Returns
-------
model_params : dict
Dictionary containing the actual parameters to be passed to the `net_*` functions in `models`.
"""
model_params = {}
if network_config["input_3x3"] and not network_config["arch"].startswith("wider"):
model_params["input_3x3"] = True
model_params["norm_act"] = _get_norm_act(network_config)
model_params["classes"] = network_config["classes"]
if not network_config["arch"].startswith("wider"):
model_params["dilation"] = network_config["dilation"]
return model_params
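
# A sketch of how `get_model_params` above is typically combined with a network
# configuration; the architecture name and values below are illustrative assumptions:
#
#     network_config = {
#         "arch": "resnet50",           # hypothetical, any non-"wider" architecture
#         "bn_mode": "inplace",
#         "activation": "leaky_relu",
#         "activation_param": 0.01,
#         "input_3x3": False,
#         "classes": 1000,
#         "dilation": 1,
#     }
#     model_params = get_model_params(network_config)
#     # `model_params` now contains `norm_act`, `classes` and `dilation`, ready to be
#     # passed as keyword arguments to the matching `net_*` constructor in `models`.
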
def create_optimizer(optimizer_config, model):
"""Creates optimizer and schedule from configuration
Parameters
----------
optimizer_config : dict
Dictionary containing the configuration options for the optimizer.
model : Model
The network model.
Returns
-------
optimizer : Optimizer
The optimizer.
scheduler : LRScheduler
The learning rate scheduler.
"""
if optimizer_config["classifier_lr"] != -1:
# Separate classifier parameters from all others
net_params = []
classifier_params = []
for k, v in model.named_parameters():
if k.find("fc") != -1:
classifier_params.append(v)
else:
net_params.append(v)
params = [
{"params": net_params},
{"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
]
else:
params = model.parameters()
if optimizer_config["type"] == "SGD":
optimizer = optim.SGD(
params,
lr=optimizer_config["learning_rate"],
momentum=optimizer_config["momentum"],
weight_decay=optimizer_config["weight_decay"],
nesterov=optimizer_config["nesterov"],
)
elif optimizer_config["type"] == "Adam":
optimizer = optim.Adam(
params,
lr=optimizer_config["learning_rate"],
weight_decay=optimizer_config["weight_decay"],
)
else:
raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))
if optimizer_config["schedule"]["type"] == "step":
scheduler = lr_scheduler.StepLR(
optimizer, **optimizer_config["schedule"]["params"]
)
elif optimizer_config["schedule"]["type"] == "multistep":
scheduler = lr_scheduler.MultiStepLR(
optimizer, **optimizer_config["schedule"]["params"]
)
elif optimizer_config["schedule"]["type"] == "exponential":
scheduler = lr_scheduler.ExponentialLR(
optimizer, **optimizer_config["schedule"]["params"]
)
elif optimizer_config["schedule"]["type"] == "constant":
scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
elif optimizer_config["schedule"]["type"] == "linear":
def linear_lr(it):
return (
it * optimizer_config["schedule"]["params"]["alpha"]
+ optimizer_config["schedule"]["params"]["beta"]
)
scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
return optimizer, scheduler
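
# A hedged usage sketch for `create_optimizer` above; the hyper-parameter values and
# schedule parameters are illustrative assumptions, not the repository defaults:
#
#     optimizer_config = {
#         "type": "SGD",
#         "learning_rate": 0.1,
#         "momentum": 0.9,
#         "weight_decay": 1e-4,
#         "nesterov": True,
#         "classifier_lr": -1,  # -1 keeps a single parameter group
#         "schedule": {"type": "step", "params": {"step_size": 30, "gamma": 0.1}},
#     }
#     optimizer, scheduler = create_optimizer(optimizer_config, model)
#     # Typical loop: optimizer.step() after every batch, scheduler.step() once per epoch.
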
def create_transforms(input_config):
"""Create transforms from configuration
Parameters
----------
input_config : dict
Dictionary containing the configuration options for input pre-processing.
Returns
-------
train_transforms : list
List of transforms to be applied to the input during training.
val_transforms : list
List of transforms to be applied to the input during validation.
"""
normalize = transforms.Normalize(mean=input_config["mean"], std=input_config["std"])
train_transforms = []
if input_config["scale_train"] != -1:
train_transforms.append(transforms.Scale(input_config["scale_train"]))
train_transforms += [
transforms.RandomResizedCrop(input_config["crop_train"]),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
if input_config["color_jitter_train"]:
train_transforms.append(ColorJitter())
if input_config["lighting_train"]:
train_transforms.append(Lighting())
train_transforms.append(normalize)
val_transforms = []
if input_config["scale_val"] != -1:
val_transforms.append(transforms.Resize(input_config["scale_val"]))
val_transforms += [
transforms.CenterCrop(input_config["crop_val"]),
transforms.ToTensor(),
normalize,
]
return train_transforms, val_transforms
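
# Sketch of how the lists returned by `create_transforms` above are consumed; the
# statistics and sizes are the common ImageNet values, and `datasets`, `train_dir`
# and `val_dir` are assumed to be provided by the calling script:
#
#     input_config = {
#         "mean": [0.485, 0.456, 0.406],
#         "std": [0.229, 0.224, 0.225],
#         "scale_train": -1,
#         "crop_train": 224,
#         "color_jitter_train": True,
#         "lighting_train": True,
#         "scale_val": 256,
#         "crop_val": 224,
#     }
#     train_tf, val_tf = create_transforms(input_config)
#     train_set = datasets.ImageFolder(train_dir, transforms.Compose(train_tf))
#     val_set = datasets.ImageFolder(val_dir, transforms.Compose(val_tf))
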
def create_test_transforms(config, crop, scale, ten_crops):
    """Create the evaluation transforms, optionally producing ten crops per image"""
    normalize = transforms.Normalize(mean=config["mean"], std=config["std"])
val_transforms = []
if scale != -1:
val_transforms.append(transforms.Resize(scale))
if ten_crops:
val_transforms += [
transforms.TenCrop(crop),
transforms.Lambda(
lambda crops: [transforms.ToTensor()(crop) for crop in crops]
),
transforms.Lambda(lambda crops: [normalize(crop) for crop in crops]),
transforms.Lambda(lambda crops: torch.stack(crops)),
]
else:
val_transforms += [
transforms.CenterCrop(crop),
transforms.ToTensor(),
normalize,
]
return val_transforms
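
# Illustrative call to `create_test_transforms` above (the argument values and the
# `test_dir` dataset path are assumptions):
#
#     test_tf = create_test_transforms(input_config, crop=224, scale=256, ten_crops=True)
#     test_set = datasets.ImageFolder(test_dir, transforms.Compose(test_tf))
#     # With `ten_crops=True` every sample becomes a [10, C, H, W] tensor; `validate`
#     # below averages the model outputs over the ten crops before scoring.
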
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
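
# Minimal usage sketch for `AverageMeter` above:
#
#     meter = AverageMeter()
#     meter.update(0.75, n=32)  # e.g. a per-batch metric computed over 32 samples
#     meter.update(0.50, n=16)
#     print(meter.val, meter.avg)  # last value and sample-weighted running average
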
def accuracy_sum(output, target, topk=(1,)):
    """Computes the number of correct top-k predictions (scaled by 100) for the specified values of k

    Unlike a standard accuracy helper, the result is not divided by the batch size;
    the caller is expected to normalize by the total sample count (see `validate`).
    """
    maxk = max(topk)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: the sliced tensor is not contiguous for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0))
    return res
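
# Sketch of how `accuracy_sum` above relates to ordinary top-k accuracy (shapes are
# illustrative):
#
#     output = model(images)  # [batch_size, num_classes] logits
#     sum1, sum5 = accuracy_sum(output, target, topk=(1, 5))
#     top1_percent = sum1.item() / target.shape[0]  # already scaled by 100
#     top5_percent = sum5.item() / target.shape[0]
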
def validate(
val_loader,
model,
criterion,
ten_crops=False,
print_freq=1,
it=None,
tb=None,
logger=print,
):
    """Evaluate `model` on `val_loader` and return the top-1 accuracy in percent

    When `torch.distributed` is initialized, losses and accuracies are summed across
    ranks before being averaged, so every rank reports the same global metrics.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
# switch to evaluate mode
model.eval()
end = time.time()
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
do_print = rank == 0
def process(input, target, all_reduce=None):
with torch.no_grad():
if ten_crops:
bs, ncrops, c, h, w = input.size()
input = input.view(-1, c, h, w)
target = target.cuda(non_blocking=True)
# compute output
if ten_crops:
output = model(input).view(bs, ncrops, -1).mean(1)
else:
output = model(input)
loss = criterion(output, target)
# measure accuracy and record loss
prec1, prec5 = accuracy_sum(output.data, target, topk=(1, 5))
loss *= target.shape[0]
count = target.new_tensor([target.shape[0]], dtype=torch.long)
if all_reduce:
all_reduce(count)
for meter, val in (losses, loss), (top1, prec1), (top5, prec5):
if all_reduce:
all_reduce(val)
val /= count.item()
meter.update(val.item(), count.item())
    # Deal with the remainder: when the dataset size is not divisible by the world
    # size, the final samples are reduced over a smaller process group containing
    # only the ranks that actually hold them.
    all_reduce = (
        partial(dist.all_reduce, op=dist.ReduceOp.SUM)
        if dist.is_initialized()
        else None
    )
    last_group_size = len(val_loader.dataset) % world_size
for i, (input, target) in enumerate(val_loader):
if input.shape[0] > 1 or last_group_size == 0:
process(input, target, all_reduce)
else:
process(
input,
target,
partial(
dist.all_reduce,
op=dist.ReduceOp.SUM,
group=dist.new_group(range(last_group_size)),
),
)
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if do_print and i % print_freq == 0:
logger(
"Test: [{0}/{1}]\t"
"Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t"
"Loss {loss.val:.4f} ({loss.avg:.4f}) \t"
"Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t"
"Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
i,
len(val_loader),
batch_time=batch_time,
loss=losses,
top1=top1,
top5=top5,
)
)
        if input.shape[0] == 1 and rank > last_group_size > 0:
            # `dist.new_group` is a collective call: ranks outside the remainder
            # group still have to enter it so the group can be created.
            dist.new_group(range(last_group_size))
if do_print:
logger(
" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}".format(
top1=top1, top5=top5
)
)
    if it is not None and tb is not None and (
        not dist.is_initialized() or dist.get_rank() == 0
    ):
        tb.add_scalar("val/loss", losses.avg, it)
        tb.add_scalar("val/top1", top1.avg, it)
        tb.add_scalar("val/top5", top5.avg, it)
return top1.avg
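
# A hedged sketch of how `validate` above might be called from a training script;
# `val_loader`, `model`, `tb_writer` and `step` are assumed to be defined by the caller:
#
#     criterion = torch.nn.CrossEntropyLoss().cuda()
#     top1 = validate(
#         val_loader, model, criterion, ten_crops=False, print_freq=10, it=step, tb=tb_writer
#     )
#     # `top1` is the global top-1 accuracy in percent.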