def get_train_data()

in notebooks/classify_mxnet.py [0:0]


def get_train_data(data_dir, batch_size, crop_size=224,
                   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                   num_workers=0):
    """Build a shuffled, augmented training DataLoader from a record file.

    Reads ``train_rec.rec`` from *data_dir* with
    ``gluon.data.vision.ImageRecordDataset``. It then applies a random
    resized crop, random left-right and top-bottom flips, ``ToTensor`` and
    per-channel normalization to each image. Only the image is transformed;
    the label passes through unchanged.

    Args:
        data_dir: Directory containing ``train_rec.rec``.
            NOTE(review): MXNet's record datasets also look for the companion
            index file (``train_rec.idx``) next to it — confirm it exists.
        batch_size: Number of samples per batch.
        crop_size: Output size passed to ``RandomResizedCrop``
            (default 224, the previously hard-coded value).
        mean: Per-channel means for ``Normalize`` (defaults are the commonly
            used ImageNet channel statistics, as previously hard-coded).
        std: Per-channel standard deviations for ``Normalize``.
        num_workers: Worker processes for the ``DataLoader``. The default 0
            loads in the main process, matching the previous behavior.

    Returns:
        A ``gluon.data.DataLoader`` over the augmented training images.
    """
    train_imgs = gluon.data.vision.ImageRecordDataset(
        os.path.join(data_dir, 'train_rec.rec'))

    normalize = gluon.data.vision.transforms.Normalize(mean, std)

    train_augs = gluon.data.vision.transforms.Compose([
        gluon.data.vision.transforms.RandomResizedCrop(crop_size),
        gluon.data.vision.transforms.RandomFlipLeftRight(),
        gluon.data.vision.transforms.RandomFlipTopBottom(),
        # ToTensor must come before Normalize: Normalize operates on the
        # float tensor produced here, not on the raw image.
        gluon.data.vision.transforms.ToTensor(),
        normalize,
    ])

    # transform_first applies the augmentations to the image only, so the
    # label element of each sample is not touched.
    # last_batch='rollover' carries an incomplete final batch over into the
    # next epoch instead of emitting a short batch or discarding it.
    train_iter = gluon.data.DataLoader(
        train_imgs.transform_first(train_augs), batch_size,
        shuffle=True, last_batch='rollover', num_workers=num_workers)

    return train_iter