def _batch_generator_fn()

in horovod/spark/keras/util.py [0:0]


    def _batch_generator_fn(feature_columns, label_columns, sample_weight_col,
                            input_shapes, label_shapes, batch_size, metadata):
        """Build a factory for an infinite, optionally shuffled, batch generator.

        Args:
            feature_columns: Names of the feature columns. Their prepared arrays
                form the model inputs, in this order.
            label_columns: Names of the label columns. Their prepared arrays form
                the model targets, in this order.
            sample_weight_col: Optional name of a per-row sample weight column.
                When set, every yielded batch also carries that weight slice,
                repeated once per label column.
            input_shapes: Shapes handed to the prepare-data function, paired
                positionally with ``feature_columns``.
            label_shapes: Shapes handed to the prepare-data function, paired
                positionally with ``label_columns``.
            batch_size: Number of rows per yielded batch.
            metadata: Forwarded to ``BareKerasUtil._prepare_data_fn``.

        Returns:
            A function ``batch_generator(reader, shuffle_buffer_size, shuffle=False)``
            that returns a never-ending generator of ``(inputs, labels)`` tuples, or
            ``(inputs, labels, sample_weights)`` when ``sample_weight_col`` is set.
            Each element is a list with one array slice per column.
        """
        prepare_data_bare_keras = BareKerasUtil._prepare_data_fn(metadata)

        # Fresh list: never mutate the caller's column lists, and accept tuples.
        cols = list(feature_columns) + list(label_columns)
        if sample_weight_col:
            cols.append(sample_weight_col)

        num_outputs = len(label_columns)

        def _fill_buffer(reader, min_rows):
            """Read whole row groups until at least ``min_rows`` rows are buffered.

            Always reads at least one row group. Returns ``(data, num_rows)`` where
            ``data`` maps every column in ``cols`` to its buffered array.
            """
            chunks = {col: [] for col in cols}
            num_rows = 0
            while True:
                # Each call to next reads one row group at a time. reader is an
                # infinite generator and never ends.
                row_group_data = next(reader)
                for col in cols:
                    chunks[col].append(getattr(row_group_data, col))
                num_rows += row_group_data[0].shape[0]
                if num_rows >= min_rows:
                    break
            # Concatenate once per column rather than once per row group; the
            # latter re-copies the whole buffer on every read (quadratic).
            data = {col: parts[0] if len(parts) == 1 else np.concatenate(parts)
                    for col, parts in chunks.items()}
            return data, num_rows

        def batch_generator(reader, shuffle_buffer_size, shuffle=False):
            # Buffer at least one full batch: with fewer rows, no batch could be cut
            # from the buffer and the generator would keep reading without yielding.
            min_rows = max(shuffle_buffer_size, batch_size)
            while True:
                data, num_rows = _fill_buffer(reader, min_rows)

                # One permutation shared by all columns keeps rows aligned.
                perm = np.random.permutation(num_rows) if shuffle else np.arange(num_rows)

                inputs = [prepare_data_bare_keras(data[col][perm], col, shape)
                          for col, shape in zip(feature_columns, input_shapes)]
                labels = [prepare_data_bare_keras(data[col][perm], col, shape)
                          for col, shape in zip(label_columns, label_shapes)]
                sample_weights = data[sample_weight_col][perm] if sample_weight_col else None

                # Rows beyond the last full batch in this buffer are discarded.
                batch_count = len(inputs[0]) // batch_size
                for batch_idx in range(batch_count):
                    start = batch_idx * batch_size
                    end = start + batch_size
                    batch_inputs = [arr[start:end] for arr in inputs]
                    batch_labels = [arr[start:end] for arr in labels]
                    if sample_weight_col:
                        # The same sample weights are used for every model output.
                        weight = sample_weights[start:end]
                        yield batch_inputs, batch_labels, [weight] * num_outputs
                    else:
                        yield batch_inputs, batch_labels

        return batch_generator