in examples/mxnet_imagenet_resnet50.py [0:0]
def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size,
                 data_nthreads):
    """Create the training and validation ``mx.io.ImageRecordIter`` objects.

    All four paths are passed through ``os.path.expanduser`` first, so they
    may start with ``~``.

    Training iterator:
        * shuffled;
        * random mirror and random-resized-crop augmentation, with aspect
          ratio in [3/4, 4/3] and area fraction in [0.08, 1];
        * brightness/saturation/contrast jitter of 0.4 and PCA noise of 0.1;
        * ``verbose=False``;
        * ``num_parts=num_workers`` and ``part_index=rank``, so each worker
          requests its own partition of the record file.

    Validation iterator:
        * not shuffled;
        * no random mirror;
        * ``resize=256``;
        * not partitioned, so every worker reads the full validation set.
          This keeps results easy to monitor from any node.

    Both iterators:
        * use ``data_shape=(3, 224, 224)`` and ``label_width=1``;
        * subtract per-channel RGB means (123.68, 116.779, 103.939);
        * use ``device_id=local_rank``.

    Note: ``mx``, ``num_workers``, ``rank`` and ``local_rank`` are read from
    module scope, not passed in. They must be defined before this is called.

    Args:
        rec_train: path to the training ``.rec`` file.
        rec_train_idx: path to the training ``.idx`` file.
        rec_val: path to the validation ``.rec`` file.
        rec_val_idx: path to the validation ``.idx`` file.
        batch_size: batch size used by both iterators.
        data_nthreads: ``preprocess_threads`` value for both iterators.

    Returns:
        Tuple ``(train_iter, val_iter)``.
    """
    train_rec, train_idx, val_rec, val_idx = (
        os.path.expanduser(path)
        for path in (rec_train, rec_train_idx, rec_val, rec_val_idx)
    )

    # Strength of the colour jitter (shared by brightness, saturation and
    # contrast) and of the PCA lighting noise.
    jitter = 0.4
    lighting = 0.1
    red_mean, green_mean, blue_mean = 123.68, 116.779, 103.939

    # Keyword arguments that are identical for the train and val iterators.
    shared = dict(
        preprocess_threads=data_nthreads,
        batch_size=batch_size,
        label_width=1,
        data_shape=(3, 224, 224),
        mean_r=red_mean,
        mean_g=green_mean,
        mean_b=blue_mean,
        rand_crop=False,
        device_id=local_rank,
    )

    train_iter = mx.io.ImageRecordIter(
        path_imgrec=train_rec,
        path_imgidx=train_idx,
        shuffle=True,
        rand_mirror=True,
        random_resized_crop=True,
        min_aspect_ratio=3. / 4.,
        max_aspect_ratio=4. / 3.,
        min_random_area=0.08,
        max_random_area=1,
        brightness=jitter,
        saturation=jitter,
        contrast=jitter,
        pca_noise=lighting,
        verbose=False,
        # Each worker asks for partition `rank` out of `num_workers`.
        num_parts=num_workers,
        part_index=rank,
        **shared
    )

    # Validation is intentionally not partitioned: every node reads the
    # whole set so results are easy to monitor from any of them.
    val_iter = mx.io.ImageRecordIter(
        path_imgrec=val_rec,
        path_imgidx=val_idx,
        shuffle=False,
        rand_mirror=False,
        resize=256,
        **shared
    )

    return train_iter, val_iter