in horovod/spark/keras/util.py [0:0]
def _batch_generator_fn(feature_columns, label_columns, sample_weight_col,
                        input_shapes, label_shapes, batch_size, metadata):
    # Relies on `numpy as np` and `BareKerasUtil`, imported at the top of the module.
    prepare_data_bare_keras = BareKerasUtil._prepare_data_fn(metadata)

    cols = feature_columns + label_columns
    if sample_weight_col:
        cols.append(sample_weight_col)
    def batch_generator(reader, shuffle_buffer_size, shuffle=False):
        while True:
            num_rows_read_sofar = 0
            data = None
            while num_rows_read_sofar < shuffle_buffer_size:
                # Each call to next() reads one row group at a time; the reader is
                # an infinite generator and never ends.
                row_group_data = next(reader)
                if not data:
                    data = {col: getattr(row_group_data, col) for col in cols}
                else:
                    for col in cols:
                        data[col] = np.concatenate((data[col],
                                                    getattr(row_group_data, col)))
                # Indexing the row group at 0 returns its first column's array,
                # whose length is the number of rows just read.
                num_rows_read_sofar += row_group_data[0].shape[0]
            # Create a single permutation over all buffered rows and use it to
            # shuffle every column's numpy array consistently.
            perm = np.random.permutation(num_rows_read_sofar) \
                if shuffle else list(range(num_rows_read_sofar))

            inputs = [prepare_data_bare_keras(data[col][perm], col, shape)
                      for col, shape in zip(feature_columns, input_shapes)]
            labels = [prepare_data_bare_keras(data[col][perm], col, shape)
                      for col, shape in zip(label_columns, label_shapes)]

            num_outputs = len(label_columns)
            sample_weights = None
            if sample_weight_col:
                sample_weights = data[sample_weight_col][perm]
            # Only full batches are emitted; any trailing partial batch is dropped.
            batch_count = len(inputs[0]) // batch_size
            for i in range(batch_count):
                if sample_weight_col:
                    # Use the same sample weight for all outputs of the sample.
                    sample_weight = sample_weights[i * batch_size:(i + 1) * batch_size]
                    sample_weight_for_batch = [sample_weight] * num_outputs

                    yield (
                        [inp[i * batch_size:(i + 1) * batch_size] for inp in inputs],
                        [label[i * batch_size:(i + 1) * batch_size] for label in labels],
                        sample_weight_for_batch)
                else:
                    yield (
                        [inp[i * batch_size:(i + 1) * batch_size] for inp in inputs],
                        [label[i * batch_size:(i + 1) * batch_size] for label in labels])
    return batch_generator
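
For reference, a minimal sketch (not part of the Horovod source) of the contract batch_generator expects from reader: an infinite iterator of namedtuple-like row groups whose fields are numpy arrays, and which is indexable so that row_group[0] returns the first field. The RowGroup type and make_fake_reader helper below are hypothetical illustrations; the commented-out call assumes a metadata dict accepted by BareKerasUtil._prepare_data_fn.

    import itertools
    from collections import namedtuple

    import numpy as np

    # Hypothetical row-group type: both getattr(row_group, col) and
    # row_group[0] work, matching how batch_generator reads fields and counts rows.
    RowGroup = namedtuple('RowGroup', ['x', 'y'])

    def make_fake_reader(rows_per_group=4):
        # Infinite iterator of toy row groups, mirroring the never-ending
        # dataset reader the generator is written against.
        group = RowGroup(x=np.random.rand(rows_per_group, 3).astype(np.float32),
                         y=np.random.rand(rows_per_group, 1).astype(np.float32))
        return itertools.cycle([group])

    # Sketch of usage (assumes a valid `metadata` dict for the two columns):
    #   gen_fn = _batch_generator_fn(['x'], ['y'], None, [[3]], [[1]],
    #                                batch_size=2, metadata=metadata)
    #   batches = gen_fn(make_fake_reader(), shuffle_buffer_size=8, shuffle=True)
    #   inputs, labels = next(batches)  # lists of arrays sliced to batch_size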