in validate.py [0:0]
def validate(args):
    """Evaluate a single model on a validation dataset and summarize accuracy.

    Builds the model, dataset and loader described by ``args``, runs one pass
    over the loader without gradients, and returns top-1 / top-5 metrics.

    Args:
        args: Namespace-like object holding the run configuration. Note that
            this function writes back to it: ``args.pretrained`` and
            ``args.prefetcher`` are always reassigned, and ``args.num_classes``
            is filled in from the model when it was ``None``.

    Returns:
        OrderedDict with keys ``model``, ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (parameter count divided by 1e6),
        ``img_size`` (last element of the resolved input size), ``crop_pct``
        and ``interpolation``.

    Raises:
        AssertionError: If the dtype / AMP / compile options are inconsistent
            or a requested optional backend is not available.
        ValueError: If a non-blank line of the ``args.valid_labels`` file is
            not an integer.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    model_dtype = None
    if args.model_dtype:
        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
        model_dtype = getattr(torch, args.model_dtype)

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_autocast = suppress  # no-op context manager unless native AMP is selected below
    if args.amp:
        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            assert args.amp_dtype == 'float16'
            use_amp = 'apex'
            _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
        else:
            assert args.amp_dtype in ('float16', 'bfloat16')
            use_amp = 'native'
            amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
            _logger.info('Validating in mixed precision with native PyTorch AMP.')
    else:
        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    if args.fast_norm:
        set_fast_norm()

    # create model; an explicit in_chans wins over the channel dim of input_size
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=in_chans,
        global_pool=args.gp,
        scriptable=args.torchscript,
        **args.model_kwargs,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    if args.reparam:
        model = reparameterize_model(model)

    # counted here, before any test-time-pool / script / compile / DataParallel wrapping
    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d', args.model, param_count)

    data_config = resolve_data_config(
        vars(args),
        model=model,
        use_test_size=not args.use_train_size,
        verbose=True,
    )
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # the three graph-transform options are mutually exclusive; first match wins
    if args.torchscript:
        assert use_amp != 'apex', 'Cannot use APEX AMP with torchscripted model'
        model = torch.jit.script(model)
    elif args.torchcompile:
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    if use_amp == 'apex':
        model = amp.initialize(model, opt_level='O1')

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().to(device)

    root_dir = args.data or args.data_dir
    if args.input_img_mode is None:
        # infer image mode from the resolved channel count
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode
    dataset = create_dataset(
        root=root_dir,
        name=args.dataset,
        split=args.split,
        download=args.dataset_download,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
        num_samples=args.num_samples,
        input_key=args.input_key,
        input_img_mode=input_img_mode,
        target_key=args.target_key,
        trust_remote_code=args.dataset_trust_remote_code,
    )

    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            # One integer class index per line. Blank lines (e.g. trailing
            # newlines) are skipped instead of crashing on int('').
            valid_labels = [int(line) for line in f if line.strip()]

    real_labels = None
    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)

    # test-time pooling replaces cropping, so the full image is used
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    if args.naflex_loader:
        from timm.data import create_naflex_loader
        loader = create_naflex_loader(
            dataset,
            batch_size=args.batch_size,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            crop_pct=crop_pct,
            crop_mode=data_config['crop_mode'],
            crop_border_pixels=args.crop_border_pixels,
            pin_memory=args.pin_mem,
            device=device,
            img_dtype=model_dtype or torch.float32,
            patch_size=16,  # Could be derived from model config
            max_seq_len=args.naflex_max_seq_len,
        )
    else:
        loader = create_loader(
            dataset,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            crop_pct=crop_pct,
            crop_mode=data_config['crop_mode'],
            crop_border_pixels=args.crop_border_pixels,
            pin_memory=args.pin_mem,
            device=device,
            img_dtype=model_dtype or torch.float32,
            tf_preprocessing=args.tf_preprocessing,
        )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        if not args.naflex_loader:
            warmup_shape = (args.batch_size,) + tuple(data_config['input_size'])
            warmup_input = torch.randn(warmup_shape).to(device=device, dtype=model_dtype)
            if args.channels_last:
                warmup_input = warmup_input.contiguous(memory_format=torch.channels_last)
            with amp_autocast():
                model(warmup_input)

        # perf_counter is monotonic, so measured intervals cannot be skewed
        # by wall-clock adjustments the way time.time() differences can.
        end = time.perf_counter()
        for batch_idx, (inputs, target) in enumerate(loader):
            if args.no_prefetcher:
                # NOTE(review): presumably the prefetching loader already
                # delivers batches on `device` with the right dtype — confirm.
                target = target.to(device=device)
                inputs = inputs.to(device=device, dtype=model_dtype)
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(inputs)

                if valid_labels is not None:
                    # keep only the logit columns listed in the valid-labels file
                    output = output[:, valid_labels]
                loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            batch_size = output.shape[0]
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # measure elapsed time
            now = time.perf_counter()
            batch_time.update(now - end)
            end = now

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=batch_size / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5
                    )
                )

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        model=args.model,
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results