in vision/src/autogluon/vision/_gluoncv/image_classification.py [0:0]
def _select_dispatcher(args):
    """Decide which backend ('torch' or 'mxnet') should train this trial.

    Parameters
    ----------
    args: dict-like
        Trial configuration; only ``custom_net`` and ``model`` are read.

    Returns
    -------
    tuple
        ``(dispatcher, custom_net)``. ``dispatcher`` is ``'torch'``, ``'mxnet'`` or
        None when nothing matched; ``custom_net`` is the user-provided network, or
        None when the backend was chosen by model name.

    Raises
    ------
    ValueError
        If ``model`` is set but is not found in any available model zoo.
    """
    custom_net = args.get('custom_net', None)
    if custom_net:
        # Dispatch on the type of the user-provided network. The mxnet check runs
        # second, so it takes precedence if an object matches both.
        dispatcher = None
        if torch and timm and isinstance(custom_net, torch.nn.Module):
            dispatcher = 'torch'
        if mx and isinstance(custom_net, mx.gluon.Block):
            dispatcher = 'mxnet'
        return dispatcher, custom_net

    model = args.get('model', None)
    if not model:
        return None, None
    # Zoo listings are only needed when a model name is given.
    torch_model_list = timm.list_models() if torch and timm else None
    mxnet_model_list = list(get_model_list()) if mx else None
    # timm model has higher priority
    if torch_model_list and model in torch_model_list:
        return 'torch', None
    if mxnet_model_list and model in mxnet_model_list:
        return 'mxnet', None
    if not torch_model_list:
        raise ValueError('Model not found in gluoncv model zoo. Install torch and timm if it supports the model.')
    if not mxnet_model_list:
        raise ValueError('Model not found in timm model zoo. Install mxnet if it supports the model.')
    raise ValueError('Model not supported because it does not exist in both timm and gluoncv model zoo.')


def _find_best_trial(log_dir, summary_file, checkpoint_file):
    """Locate the previous trial with the highest ``valid_acc``.

    Scans the ``.trial_*`` sub-directories of ``log_dir``, reading the JSON summary
    ``summary_file`` and checking for ``checkpoint_file`` in each. Trials whose
    summary is missing or unreadable, or that have no checkpoint, are skipped.

    Returns
    -------
    tuple
        ``(checkpoint_path, summary)`` of the best trial, or ``('', {})`` when no
        trial qualifies (including when ``log_dir`` is not a directory).
    """
    best_checkpoint, best_summary, best_acc = '', {}, -1
    if not os.path.isdir(log_dir):
        return best_checkpoint, best_summary
    for name in os.listdir(log_dir):
        trial_dir = os.path.join(log_dir, name)
        if not name.startswith('.trial_') or not os.path.isdir(trial_dir):
            continue
        checkpoint = os.path.join(trial_dir, checkpoint_file)
        try:
            with open(os.path.join(trial_dir, summary_file), 'r') as f:
                summary = json.load(f)
            acc = summary.get('valid_acc', -1)
            # Strict '>' keeps the first-listed trial on ties.
            if acc > best_acc and os.path.isfile(checkpoint):
                best_checkpoint, best_summary, best_acc = checkpoint, summary, acc
        except (OSError, ValueError, TypeError, AttributeError):
            # Best effort: a missing/corrupt summary (JSON errors are ValueError),
            # a non-dict summary or a non-numeric accuracy just disqualifies the trial.
            continue
    return best_checkpoint, best_summary


def _save_trial_log(args, result, task, dataset, num_trials):
    """Write the trial config merged with its result to a timestamped JSON file.

    The file name is relative, so it is created in the current working directory.
    """
    trial_log = dict(args)
    trial_log.update(result)
    # Serialize before opening the file so a failure leaves no partial file behind.
    json_str = json.dumps(trial_log)
    time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # format() (rather than '+') also tolerates non-str task/dataset values.
    json_file_name = '{}_dataset-{}_trials-{}_{}.json'.format(task, dataset, num_trials, time_str)
    with open(json_file_name, 'w') as json_file:
        json_file.write(json_str)
    logging.info('Config and result in this trial have been saved to %s.', json_file_name)


def _train_image_classification(args, reporter):
    """Train one image-classification trial, or reload the best one for the final fit.

    Parameters
    ----------
    args: <class 'autogluon.utils.edict.EasyDict'>
        Trial configuration. It is shallow-copied, so the caller's object is not
        mutated. ``task_id`` is read (defaults to 0 if missing or not int-convertible).
        Popped here before the rest is handed to the estimator:

        - ``problem_type`` (default ``MULTICLASS``);
        - ``final_fit`` (default False);
        - ``train_data`` and ``val_data``;
        - ``wall_clock_tick``: an absolute deadline compared against ``time.time()``;
        - ``log_dir`` (default: the current working directory);
        - ``exp_batch_size``: when truthy, ``batch_size`` is used as an exponent of 2;
        - ``custom_optimizer``;
        - optionally ``task``, ``dataset`` and ``num_trials``.

        ``custom_net`` or ``model`` selects the backend.
    reporter:
        Passed through to the estimator constructor.

    Returns
    -------
    dict
        On success: the ``estimator.fit`` result (or, in final-fit mode, the saved
        summary of the best earlier trial). It is augmented with
        ``'model_checkpoint'`` (the estimator object) and ``'estimator'`` (its class).
        On timeout or on any exception raised while fitting: a dict with
        ``'traceback'``, ``'args'``, ``'time'``, ``'train_acc': -1`` and
        ``'valid_acc': -1``.

    Raises
    ------
    ValueError
        If ``model`` is not found in any installed model zoo.
    AssertionError
        If no backend matches ``custom_net``/``model``.
    """
    tic = time.time()
    args = args.copy()
    try:
        task_id = int(args['task_id'])
    except (KeyError, TypeError, ValueError, OverflowError):
        task_id = 0
    problem_type = args.pop('problem_type', MULTICLASS)
    final_fit = args.pop('final_fit', False)
    # train, val data
    train_data = args.pop('train_data')
    val_data = args.pop('val_data')
    # wall clock tick limit
    wall_clock_tick = args.pop('wall_clock_tick')
    log_dir = args.pop('log_dir', os.getcwd())
    # exponential batch size for Int() space batch sizes
    exp_batch_size = args.pop('exp_batch_size', False)
    if exp_batch_size and 'batch_size' in args:
        args['batch_size'] = 2 ** args['batch_size']
    dataset = num_trials = None
    try:
        task = args.pop('task')
        dataset = args.pop('dataset')
        num_trials = args.pop('num_trials')
    except KeyError:
        task = None

    # mxnet and torch dispatcher
    dispatcher, custom_net = _select_dispatcher(args)
    assert dispatcher in ('torch', 'mxnet'), 'custom net needs to be of type either torch.nn.Module or mx.gluon.Block'
    # The estimator class travels through config_to_nested and is popped back out below.
    args['estimator'] = TorchImageClassificationEstimator if dispatcher == 'torch' else ImageClassificationEstimator
    # convert user defined config to nested form
    args = config_to_nested(args)
    if wall_clock_tick < tic and not final_fit:
        return {'traceback': 'timeout', 'args': str(args),
                'time': 0, 'train_acc': -1, 'valid_acc': -1}
    try:
        valid_summary_file = 'fit_summary_img_cls.ag'
        estimator_cls = args.pop('estimator', None)
        assert estimator_cls in (ImageClassificationEstimator, TorchImageClassificationEstimator)
        estimator = None
        result = {}
        if final_fit:
            # load from previous dumps; keep the best trial's own summary with its checkpoint
            best_checkpoint, best_result = _find_best_trial(log_dir, valid_summary_file, _BEST_CHECKPOINT_FILE)
            if best_checkpoint:
                estimator = estimator_cls.load(best_checkpoint)
                result = dict(best_result)
            if estimator is None:
                if wall_clock_tick < tic:
                    # Same shape as the early-timeout return above.
                    result = {'traceback': 'timeout', 'args': str(args),
                              'time': 0, 'train_acc': -1, 'valid_acc': -1}
                else:
                    # unknown error yet, try reproduce it
                    final_fit = False
        if not final_fit:
            # create independent log_dir for each trial
            trial_log_dir = os.path.join(log_dir, '.trial_{}'.format(task_id))
            args['log_dir'] = trial_log_dir
            custom_optimizer = args.pop('custom_optimizer', None)
            estimator = estimator_cls(args, problem_type=problem_type, reporter=reporter,
                                      net=custom_net, optimizer=custom_optimizer)
            # training
            result = estimator.fit(train_data=train_data, val_data=val_data, time_limit=wall_clock_tick - tic)
            # Make sure the directory exists before writing the summary read by the final fit.
            os.makedirs(trial_log_dir, exist_ok=True)
            with open(os.path.join(trial_log_dir, valid_summary_file), 'w') as f:
                json.dump(result, f)
        # save config and result
        if task is not None:
            _save_trial_log(args, result, task, dataset, num_trials)
    except Exception:
        # Report the failure to the scheduler instead of crashing the worker;
        # KeyboardInterrupt/SystemExit still propagate.
        import traceback
        return {'traceback': traceback.format_exc(), 'args': str(args),
                'time': time.time() - tic, 'train_acc': -1, 'valid_acc': -1}
    if estimator is not None:
        result.update({'model_checkpoint': estimator})
    result.update({'estimator': estimator_cls})
    return result