in validate.py [0:0]
def main():
    """Validate a single model, or bulk-validate several, and report the results.

    Command-line arguments are read via the module-level ``parser``. Bulk mode
    is used whenever one of the following produces at least one
    ``(model, checkpoint)`` pair:

    * ``args.checkpoint`` is a directory: every ``*.pth.tar`` / ``*.pth`` file
      directly inside it is validated with ``args.model``.
    * ``args.model == 'all'``: every name returned by ``list_models`` with
      ``pretrained=True`` and ``exclude_filters=_NON_IN1K_FILTERS``.
    * ``args.model`` is not a known model name: it is passed to
      ``list_models`` as a filter (with ``pretrained=True``).
    * none of the above matched and ``args.model`` is a path to a file: the
      file is read as one model name per line (blank lines ignored).

    Otherwise a single run is performed, through ``_try_run`` when
    ``args.retry`` is set and through ``validate`` directly when it is not.

    Results are written to ``args.results_file`` when it is set, and are always
    printed to stdout as JSON after a ``--result`` delimiter line.
    """
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(
                pretrained=True,
                exclude_filters=_NON_IN1K_FILTERS,
            )
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(
                args.model,
                pretrained=True,
            )
            model_cfgs = [(n, '') for n in model_names]

        if not model_cfgs and os.path.isfile(args.model):
            # last resort: treat args.model as a text file of model names
            with open(args.model) as f:
                stripped = [line.rstrip() for line in f]
            # drop blank lines once, so the log message below and the configs
            # to run are built from the same set of names
            model_names = [n for n in stripped if n]
            model_cfgs = [(n, None) for n in model_names]

    if model_cfgs:
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            # every run starts from the batch size originally requested
            initial_batch_size = args.batch_size
            for m, c in model_cfgs:
                args.model = m
                args.checkpoint = c
                r = _try_run(args, initial_batch_size)
                if 'error' in r:
                    # failed runs are left out of the bulk results
                    continue
                if args.checkpoint:
                    r['checkpoint'] = args.checkpoint
                results.append(r)
        except KeyboardInterrupt:
            # deliberate best-effort: an interrupt stops the bulk run early but
            # the results gathered so far are still sorted and reported below
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
    elif args.retry:
        results = _try_run(args, args.batch_size)
    else:
        results = validate(args)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')