example/image-classification/train_imagenet.py (29 lines of code) (raw):

"""Train an image-classification network on ImageNet-1k.

Builds a command-line parser from the shared ``fit`` and ``data`` helpers,
overrides their defaults with ImageNet-specific values, loads the network
symbol named by ``--network`` from the ``symbols`` package, and hands
everything to ``fit.fit``.
"""
import os
import argparse
import logging
from importlib import import_module

# Configure logging before the project imports below, as the original script
# did, so that anything they log while importing is not dropped.
logging.basicConfig(level=logging.DEBUG)

# NOTE(review): `find_mxnet` is not referenced by name; it is presumably
# imported for its side effects and must stay ahead of `import mxnet` --
# confirm against common/find_mxnet.py before reordering or removing.
from common import find_mxnet, data, fit
from common.util import download_file
import mxnet as mx

if __name__ == '__main__':
    # parse args
    parser = argparse.ArgumentParser(
        # Was "train cifar10" (copy-paste slip); this script's defaults are
        # for ImageNet, so make --help say so.
        description="train imagenet-1k",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    fit.add_fit_args(parser)
    data.add_data_args(parser)
    data.add_data_aug_args(parser)
    # use a large aug level
    data.set_data_aug_level(parser, 3)
    # These overrides come after set_data_aug_level, so they win over any
    # defaults it installed for the same option names.
    parser.set_defaults(
        # network
        network='resnet',
        num_layers=50,
        # data
        num_classes=1000,
        num_examples=1281167,
        image_shape='3,224,224',
        # if the input images have min size k, suggest using 256.0/k,
        # e.g. 0.533 for 480
        min_random_scale=1,
        # train
        num_epochs=80,
        lr_step_epochs='30,60',
    )
    args = parser.parse_args()

    # load network: the module is chosen at run time from --network, so it
    # has to be imported dynamically by name.
    net = import_module('symbols.' + args.network)
    # Every parsed option is forwarded as a keyword argument to get_symbol.
    sym = net.get_symbol(**vars(args))

    # train
    fit.fit(args, sym, data.get_rec_iter)