in bulk_runner.py
def main():
    args = parser.parse_args()
    cmd, cmd_args = cmd_from_args(args)
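
    # Resolve args.model_list into a list of (model_name, extra_args) configs: a keyword
    # ('all', 'all_in1k', 'all_res'), a wildcard filter, or a path to a text file of names.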
    model_cfgs = []
    if args.model_list == 'all':
        model_names = list_models(
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
        )
        model_cfgs = [(n, None) for n in model_names]
    elif args.model_list == 'all_in1k':
        model_names = list_models(pretrained=True)
        model_cfgs = _get_model_cfgs(model_names, num_classes=1000, expand_train_test=True)
    elif args.model_list == 'all_res':
        model_names = list_models()
        model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False, expand_arch=True)
    elif not is_model(args.model_list):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model_list)
        model_cfgs = [(n, None) for n in model_names]
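
    # If nothing matched above, fall back to treating args.model_list as a path to a
    # text file containing one model name per line.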
    if not model_cfgs and os.path.exists(args.model_list):
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = _get_model_cfgs(
            model_names,
            #num_classes=1000,
            expand_train_test=True,
            #include_crop=False,
        )
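
    # With at least one model config resolved, run the target script per model and collect results.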
    if len(model_cfgs):
        results_file = args.results_file or './results.csv'
        results = []
        errors = []
        model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
        print(f"Running script on these models:\n {model_strings}")
        if not args.sort_key:
            if 'benchmark' in args.script:
                if any(['train' in a for a in args.script_args]):
                    sort_key = 'train_samples_per_sec'
                else:
                    sort_key = 'infer_samples_per_sec'
            else:
                sort_key = 'top1'
        else:
            sort_key = args.sort_key
        print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
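
        # Launch the script in a separate process per model; each run is expected to print
        # its metrics as JSON following a '--result' delimiter on stdout.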
        try:
            for m, ax in model_cfgs:
                if not m:
                    continue
                args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
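                # append any per-model extra args, flattened into ('--key', 'value') CLI pairs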
                if ax is not None:
                    extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
                    extra_args = [i for t in extra_args for i in t]
                    args_str += tuple(extra_args)
                try:
                    o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                    r = json.loads(o)
                    results.append(r)
                except Exception as e:
                    # FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
                    # for further robustness (but more overhead), we may want to manage that by looping here...
                    errors.append(dict(model=m, error=str(e)))
                if args.delay:
                    time.sleep(args.delay)
        except KeyboardInterrupt as e:
            pass
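
        # Collate failures: subprocess exceptions plus any results that self-reported an 'error' field.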
        errors.extend(list(filter(lambda x: 'error' in x, results)))
        if errors:
            print(f'{len(errors)} models had errors during run.')
            for e in errors:
                if 'model' in e:
                    print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
                else:
                    print(e)
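
        # Drop failed runs, sort by the chosen key when every result has it, then write the CSV.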
        results = list(filter(lambda x: 'error' not in x, results))
        no_sortkey = list(filter(lambda x: sort_key not in x, results))
        if no_sortkey:
            print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
        else:
            results = sorted(results, key=lambda x: x[sort_key], reverse=True)

        if len(results):
            print(f'{len(results)} models run successfully. Saving results to {results_file}.')
            write_results(results_file, results)
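
# Example invocation (a sketch, not taken from this file; flag spellings are inferred from the
# argparse attributes referenced above, e.g. args.model_list / args.results_file, and the
# trailing benchmark.py options are purely illustrative):
#   python bulk_runner.py --model-list 'resnet*' --results-file ./results.csv benchmark.py --amp -b 256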