# torchbenchmark/util/framework/timm/instantiate.py
""" Hacked from https://github.com/rwightman/pytorch-image-models/blob/f7d210d759beb00a3d0834a3ce2d93f6e17f3d38/train.py
ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
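Adapted for torchbenchmark: the two entry points below build a timm model, synthetic
ImageNet data loaders, and (for training) the optimizer, LR scheduler, AMP helpers,
mixup/cutmix, and loss functions consumed by the benchmark train and eval loops.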
"""
from contextlib import suppress
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.models import create_model, convert_splitbn_model
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import NativeScaler
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, LabelSmoothingCrossEntropy
from timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from .loader import create_fake_imagenet_dataset


def timm_instantiate_eval(args):
    # create eval model
    eval_model = create_model(
        args.model_name,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    data_config = resolve_data_config(vars(args), model=eval_model, use_test_size=True, verbose=True)
    eval_model = eval_model.to(args.device)
    # enable channels last layout if set
    if args.channels_last:
        eval_model = eval_model.to(memory_format=torch.channels_last)
    if args.num_gpu > 1:
        eval_model = torch.nn.DataParallel(eval_model, device_ids=list(range(args.num_gpu)))
    crop_pct = data_config['crop_pct']
    # create dataset
    dataset_eval = create_fake_imagenet_dataset(size=args.eval_num_batch * args.eval_batch_size)
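    # sizing note: eval_num_batch * eval_batch_size samples means the eval loader
    # below yields exactly args.eval_num_batch batches per pass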
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.eval_batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        persistent_workers=False,
    )
    return eval_model, loader_eval


def timm_instantiate_train(args):
    # create train model
    model = create_model(
        args.model_name,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))
    model = model.to(args.device)
    # enable channels last layout if set
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)
    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            print(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
    # setup optimizer
    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if args.use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
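    # NOTE: the benchmark train loop is expected to run the forward pass under
    # `with amp_autocast():` and, when loss_scaler is not None, step the optimizer via
    # `loss_scaler(loss, optimizer)`; timm's NativeScaler wraps torch.cuda.amp.GradScaler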
    # setup distributed training
    if args.distributed:
        model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP
    # setup learning rate schedule and starting epoch
    lr_scheduler, _ = create_scheduler(args, optimizer)
    # create fake imagenet dataset
    fake_dataset = create_fake_imagenet_dataset(size=args.batch_size * args.train_num_batch)
    dataset_train = fake_dataset
    dataset_eval = fake_dataset
    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)
    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        # Not supported by timm 0.4.12
        # num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        # Not supported by timm 0.4.12
        # worker_seeding=args.worker_seeding,
        persistent_workers=False,
    )
    loader_validate = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        persistent_workers=False,
    )
    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # NOTE: the latest timm package (0.4.12) doesn't support BinaryCrossEntropy
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        # if args.bce_loss:
        #     train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        # else:
        train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        # if args.bce_loss:
        #     train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        # else:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(args.device)
    validate_loss_fn = nn.CrossEntropyLoss().to(args.device)
    # return all the inputs needed by train and eval loop
    return model, loader_train, loader_validate, optimizer, \
        train_loss_fn, lr_scheduler, amp_autocast, \
        loss_scaler, mixup_fn, validate_loss_fn
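

# Usage sketch: drive the eval path end-to-end for a quick local smoke test. The
# Namespace fields below are assumptions mirroring the attributes read above; the
# real torchbenchmark TIMM models populate `args` from their own configs instead.
if __name__ == "__main__":
    from argparse import Namespace

    device = "cuda" if torch.cuda.is_available() else "cpu"
    eval_args = Namespace(
        model_name="resnet50", pretrained=False, num_classes=1000, gp=None,
        torchscript=False, device=device, channels_last=False, num_gpu=1,
        eval_num_batch=1, eval_batch_size=8, prefetcher=False, workers=0,
        pin_mem=False, tf_preprocessing=False,
    )
    eval_model, eval_loader = timm_instantiate_eval(eval_args)
    eval_model.eval()
    with torch.no_grad():
        for images, _ in eval_loader:
            # one forward pass; expected output shape: [eval_batch_size, num_classes]
            print(eval_model(images.to(device)).shape)
            break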