def get_data()

in containers/Shoot/CNN/train.py [0:0]


def get_data(batch_size, dtype, host_ctx):
    """Build the training and validation iterators from the training channel.

    Reads one JSON-lines file from the directory named by the
    ``SM_CHANNEL_TRAIN`` environment variable, turns every record into
    samples for both "TeamA" and "TeamB" via ``create_board``, shuffles the
    samples and splits them 70% / 30% into a training and a validation set.

    Args:
        batch_size: Batch size of the training iterator. The validation
            iterator uses twice this value.
        dtype: Forwarded unchanged to ``create_board``.
        host_ctx: Context manager entered while the per-split arrays are
            concatenated (presumably an ``mx.Context`` selecting where the
            arrays are allocated -- confirm against the caller).

    Returns:
        A ``(train_iter, validation_iter)`` tuple of ``mx.io.NDArrayIter``.
        In both, element 0 of each sample is the data and elements 1 and 2
        are the two labels. Only the training iterator shuffles.

    Raises:
        KeyError: If ``SM_CHANNEL_TRAIN`` is not set.
        IndexError: If the channel directory is empty.
    """
    data_dir = os.environ["SM_CHANNEL_TRAIN"]
    # os.listdir() returns entries in arbitrary order; sort so that the same
    # file is picked on every run when the channel holds more than one entry.
    data_file = sorted(os.listdir(data_dir))[0]

    boards = []
    with open(os.path.join(data_dir, data_file)) as f:
        for line in f:
            # Tolerate blank lines in the JSON-lines file instead of letting
            # json.loads() fail on them.
            if not line.strip():
                continue
            record = json.loads(line)
            # create_board returns several samples per call; collect them
            # into one flat list, TeamA before TeamB for each record.
            boards.extend(create_board(record, "TeamA", dtype))
            boards.extend(create_board(record, "TeamB", dtype))

    random.shuffle(boards)
    split = int(len(boards) * .7)
    train, validation = boards[:split], boards[split:]
    # NOTE(review): nothing guards against an empty split (too few samples);
    # the concatenation below is then called with no arrays -- confirm the
    # input always yields enough samples.

    def stack_field(samples, index):
        # Give every sample's array at `index` a leading axis of size 1 and
        # join them along that axis into a single batch array.
        return nd.concat(*[s[index].expand_dims(0) for s in samples], dim=0)

    with host_ctx:
        d_t, l_t, m_t = [stack_field(train, i) for i in range(3)]
        d_v, l_v, m_v = [stack_field(validation, i) for i in range(3)]

    train_iter = mx.io.NDArrayIter(
        data=d_t,
        label=[l_t, m_t],
        shuffle=True,
        batch_size=batch_size,
        last_batch_handle="pad",
    )
    validation_iter = mx.io.NDArrayIter(
        data=d_v,
        label=[l_v, m_v],
        shuffle=False,
        last_batch_handle="pad",
        batch_size=2 * batch_size,
    )
    return train_iter, validation_iter