def _train_image_classification()

in vision/src/autogluon/vision/_gluoncv/image_classification.py [0:0]


def _find_best_trial(log_dir, summary_file):
    """Return the best finished trial found under *log_dir*.

    Scans the ``.trial_*`` sub-directories of *log_dir*. A trial qualifies only
    when its *summary_file* (JSON written by a previous non-final run) can be
    read and a ``_BEST_CHECKPOINT_FILE`` exists in the same directory. Trials
    are ranked by their ``valid_acc`` entry; ties keep the first one listed.

    Parameters
    ----------
    log_dir : str
        Directory holding the per-trial ``.trial_<id>`` sub-directories.
    summary_file : str
        File name of the per-trial JSON summary.

    Returns
    -------
    tuple of (str, dict)
        Path of the best checkpoint ('' if no trial qualifies) and the summary
        dict of that same trial ({} if no trial qualifies).
    """
    best_checkpoint = ''
    best_result = {}
    best_acc = -1
    if not os.path.isdir(log_dir):
        return best_checkpoint, best_result
    for dd in os.listdir(log_dir):
        trial_dir = os.path.join(log_dir, dd)
        if not (dd.startswith('.trial_') and os.path.isdir(trial_dir)):
            continue
        checkpoint = os.path.join(trial_dir, _BEST_CHECKPOINT_FILE)
        try:
            with open(os.path.join(trial_dir, summary_file), 'r') as f:
                trial_result = json.load(f)
            acc = trial_result.get('valid_acc', -1)
            if acc > best_acc and os.path.isfile(checkpoint):
                best_checkpoint = checkpoint
                best_result = trial_result
                best_acc = acc
        except (OSError, ValueError, TypeError, AttributeError):
            # Best effort: skip trials whose summary is missing, is not valid
            # JSON (JSONDecodeError is a ValueError), is not a dict, or holds a
            # non-comparable accuracy value.
            continue
    return best_checkpoint, best_result


def _train_image_classification(args, reporter):
    """Train one image-classification trial, or reload the best previous trial.

    The backend (timm/torch or gluoncv/mxnet) is chosen from ``custom_net`` or
    ``model``. In a final fit, the best checkpoint among earlier trials in
    ``log_dir`` is loaded instead of training again. If no such checkpoint
    exists and time remains, a normal training run is done.

    Parameters
    ----------
    args: <class 'autogluon.utils.edict.EasyDict'>
        Trial configuration; a copy is taken, so the caller's object is not
        mutated. Keys consumed here include ``task_id``, ``problem_type``,
        ``final_fit``, ``train_data``, ``val_data``, ``wall_clock_tick`` (an
        absolute deadline compared against ``time.time()``), ``log_dir``,
        ``exp_batch_size``, ``batch_size``, ``task``/``dataset``/``num_trials``
        (optional, used for the trial log file), ``custom_net``, ``model`` and
        ``custom_optimizer``.
    reporter
        Forwarded unchanged to the estimator constructor.

    Returns
    -------
    dict
        On success, the estimator's ``fit`` result (or the best trial's saved
        summary in a final fit), updated with ``'model_checkpoint'`` (the
        estimator instance) and ``'estimator'`` (its class). On timeout or any
        error during training, a dict with ``'traceback'``, ``'args'``,
        ``'time'``, ``'train_acc'`` and ``'valid_acc'``.

    Raises
    ------
    ValueError
        If ``model`` is not found in any available model zoo.
    AssertionError
        If no backend could be resolved for ``custom_net``/``model``.
    KeyError
        If ``train_data``, ``val_data`` or ``wall_clock_tick`` is missing.
    """
    tic = time.time()
    args = args.copy()
    try:
        task_id = int(args['task_id'])
    except (KeyError, TypeError, ValueError):
        task_id = 0
    problem_type = args.pop('problem_type', MULTICLASS)
    final_fit = args.pop('final_fit', False)
    # train, val data
    train_data = args.pop('train_data')
    val_data = args.pop('val_data')
    # wall clock tick limit (absolute deadline, same clock as time.time())
    wall_clock_tick = args.pop('wall_clock_tick')
    log_dir = args.pop('log_dir', os.getcwd())
    # exponential batch size for Int() space batch sizes
    exp_batch_size = args.pop('exp_batch_size', False)
    if exp_batch_size and 'batch_size' in args:
        args['batch_size'] = 2 ** args['batch_size']
    try:
        task = args.pop('task')
        dataset = args.pop('dataset')
        num_trials = args.pop('num_trials')
    except KeyError:
        # Without all three, no per-trial json log is written.
        task = None

    # mxnet and torch dispatcher
    dispatcher = None
    torch_model_list = None
    mxnet_model_list = None
    custom_net = None
    if args.get('custom_net', None):
        custom_net = args.get('custom_net')
        if torch and timm:
            if isinstance(custom_net, torch.nn.Module):
                dispatcher = 'torch'
        if mx:
            if isinstance(custom_net, mx.gluon.Block):
                dispatcher = 'mxnet'
    else:
        if torch and timm:
            torch_model_list = timm.list_models()
        if mx:
            mxnet_model_list = list(get_model_list())
        model = args.get('model', None)
        if model:
            # timm model has higher priority
            if torch_model_list and model in torch_model_list:
                dispatcher = 'torch'
            elif mxnet_model_list and model in mxnet_model_list:
                dispatcher = 'mxnet'
            else:
                if not torch_model_list:
                    raise ValueError('Model not found in gluoncv model zoo. Install torch and timm if it supports the model.')
                elif not mxnet_model_list:
                    raise ValueError('Model not found in timm model zoo. Install mxnet if it supports the model.')
                else:
                    raise ValueError('Model not supported because it does not exist in both timm and gluoncv model zoo.')
    assert dispatcher in ('torch', 'mxnet'), 'custom net needs to be of type either torch.nn.Module or mx.gluon.Block'
    args['estimator'] = TorchImageClassificationEstimator if dispatcher=='torch' else ImageClassificationEstimator
    # convert user defined config to nested form
    args = config_to_nested(args)

    if wall_clock_tick < tic and not final_fit:
        return {'traceback': 'timeout', 'args': str(args),
                'time': 0, 'train_acc': -1, 'valid_acc': -1}

    estimator = None
    estimator_cls = None
    result = {}
    try:
        valid_summary_file = 'fit_summary_img_cls.ag'
        estimator_cls = args.pop('estimator', None)
        assert estimator_cls in (ImageClassificationEstimator, TorchImageClassificationEstimator)
        if final_fit:
            # Load from previous dumps; keep the summary of the SAME trial
            # whose checkpoint is loaded.
            best_checkpoint, result = _find_best_trial(log_dir, valid_summary_file)
            if best_checkpoint:
                estimator = estimator_cls.load(best_checkpoint)
            if estimator is None:
                if wall_clock_tick < tic:
                    # No reusable checkpoint and no time left: report a
                    # timeout in the same shape as the early-timeout return.
                    result = {'traceback': 'timeout', 'args': str(args),
                              'time': time.time() - tic, 'train_acc': -1, 'valid_acc': -1}
                else:
                    # unknown error yet, try reproduce it
                    final_fit = False
        if not final_fit:
            # create independent log_dir for each trial
            trial_log_dir = os.path.join(log_dir, '.trial_{}'.format(task_id))
            args['log_dir'] = trial_log_dir
            custom_optimizer = args.pop('custom_optimizer', None)
            estimator = estimator_cls(args, problem_type=problem_type, reporter=reporter,
                                      net=custom_net, optimizer=custom_optimizer)
            # training
            result = estimator.fit(train_data=train_data, val_data=val_data, time_limit=wall_clock_tick-tic)
            # Make sure the summary can be written even if the estimator did
            # not create the trial directory itself.
            os.makedirs(trial_log_dir, exist_ok=True)
            with open(os.path.join(trial_log_dir, valid_summary_file), 'w') as f:
                json.dump(result, f)
            # save config and result
            if task is not None:
                trial_log = {}
                trial_log.update(args)
                trial_log.update(result)
                json_str = json.dumps(trial_log)
                time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                # format() accepts non-str task/dataset values where '+' would raise
                json_file_name = '{}_dataset-{}_trials-{}_{}.json'.format(task, dataset, num_trials, time_str)
                with open(json_file_name, 'w') as json_file:
                    json_file.write(json_str)
                logging.info('Config and result in this trial have been saved to %s.', json_file_name)
    except Exception:
        import traceback
        return {'traceback': traceback.format_exc(), 'args': str(args),
                'time': time.time() - tic, 'train_acc': -1, 'valid_acc': -1}

    if estimator is not None:
        result.update({'model_checkpoint': estimator})
        result.update({'estimator': estimator_cls})
    return result