in containers/Shoot/CNN/train.py [0:0]
def get_data(batch_size, dtype, host_ctx):
    """Build the training and validation iterators from the train channel.

    Reads one file from the directory named by the ``SM_CHANNEL_TRAIN``
    environment variable. Each non-blank line of that file is parsed as a
    JSON record and passed to ``create_board`` once for ``"TeamA"`` and once
    for ``"TeamB"``. The resulting boards are shuffled and split 70/30 into
    a training set and a validation set.

    Args:
        batch_size: Batch size of the training iterator. The validation
            iterator uses ``2 * batch_size``.
        dtype: Forwarded unchanged to ``create_board``.
        host_ctx: Context manager entered while the per-board arrays are
            concatenated (presumably an MXNet context that selects where
            the arrays are allocated -- confirm against the caller).

    Returns:
        A ``(train_iter, validation_iter)`` tuple of ``mx.io.NDArrayIter``.
        Both iterators pad the last batch; only the training iterator
        shuffles.

    Raises:
        KeyError: If ``SM_CHANNEL_TRAIN`` is not set.
        IndexError: If the channel directory is empty.
        ValueError: If fewer than two boards are produced, which would leave
            the training split empty.
    """
    data_dir = os.environ["SM_CHANNEL_TRAIN"]
    # os.listdir() returns entries in arbitrary order; sort so that the file
    # picked is reproducible when the channel holds more than one entry.
    data_file = sorted(os.listdir(data_dir))[0]

    boards = []
    with open(os.path.join(data_dir, data_file)) as f:
        for line in f:
            # Tolerate blank lines in the JSON-lines file instead of
            # failing inside json.loads.
            if not line.strip():
                continue
            record = json.loads(line)
            for team in ("TeamA", "TeamB"):
                boards.extend(create_board(record, team, dtype))

    random.shuffle(boards)
    split = int(len(boards) * .7)
    if split == 0:
        # With zero or one board the training slice is empty and the
        # concatenation below would be called without any input arrays.
        raise ValueError(
            "not enough boards in {!r} to build a training split "
            "(got {})".format(data_file, len(boards))
        )

    def _stack(samples, field):
        # Add a leading axis to element `field` of every sample and join
        # them along that axis into one batch array.
        return nd.concat(*[s[field].expand_dims(0) for s in samples], dim=0)

    with host_ctx:
        # Each board is indexed at positions 0, 1 and 2: position 0 feeds
        # `data`, positions 1 and 2 feed the two `label` arrays.
        d_t, l_t, m_t = [_stack(boards[:split], i) for i in range(3)]
        d_v, l_v, m_v = [_stack(boards[split:], i) for i in range(3)]

    train_iter = mx.io.NDArrayIter(
        data=d_t,
        label=[l_t, m_t],
        shuffle=True,
        batch_size=batch_size,
        last_batch_handle="pad",
    )
    validation_iter = mx.io.NDArrayIter(
        data=d_v,
        label=[l_v, m_v],
        shuffle=False,
        last_batch_handle="pad",
        batch_size=2 * batch_size,
    )
    return (train_iter, validation_iter)