in keras/engine/training_generator.py
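# Module-level names this excerpt relies on (they are imported at the top of
# the full file; these import paths assume the Keras 2.x source layout and may
# differ slightly):
#
#     import warnings
#     import numpy as np
#     from .training_utils import iter_sequence_infinite
#     from ..utils.data_utils import Sequence, OrderedEnqueuer, GeneratorEnqueuer
#     from ..utils.generic_utils import Progbar, to_list, unpack_singleton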
def evaluate_generator(model, generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
"""See docstring for `Model.evaluate_generator`."""
model._make_test_function()
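    # Reset stateful metrics (e.g. running totals) so evaluation starts from a
    # clean state, and record their indices so they are not batch-averaged below.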
    if hasattr(model, 'metrics'):
        for m in model.stateful_metric_functions:
            m.reset_states()
        stateful_metric_indices = [
            i for i, name in enumerate(model.metrics_names)
            if str(name) in model.stateful_metric_names]
    else:
        stateful_metric_indices = []
    steps_done = 0
    wait_time = 0.01
    outs_per_batch = []
    batch_sizes = []
    is_sequence = isinstance(generator, Sequence)
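    # A plain Python generator cannot be split safely across worker processes,
    # so warn that multiprocessing with multiple workers may duplicate data.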
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the `keras.utils.Sequence`'
                        ' class.'))
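    # A Sequence knows its own length; a plain generator does not, so `steps`
    # must be given explicitly in that case.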
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
    enqueuer = None
    try:
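        # With workers > 0, batches are prefetched on background threads or
        # processes: an OrderedEnqueuer preserves the order of a Sequence,
        # while a GeneratorEnqueuer wraps a plain generator.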
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter_sequence_infinite(generator)
            else:
                output_generator = generator
        if verbose == 1:
            progbar = Progbar(target=steps)
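        # Main evaluation loop: draw one batch per step and evaluate it.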
        while steps_done < steps:
            generator_output = next(output_generator)
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' +
                                 str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
                sample_weight = None
            elif len(generator_output) == 3:
                x, y, sample_weight = generator_output
            else:
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' +
                                 str(generator_output))
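            # Evaluate this batch; `outs` holds one scalar per entry in
            # `model.metrics_names`.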
            outs = model.test_on_batch(x, y, sample_weight=sample_weight)
            outs = to_list(outs)
            outs_per_batch.append(outs)
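            # Infer the batch size from the first input array so the final
            # average can be weighted per batch.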
            if x is None or len(x) == 0:
                # Handle data tensors support when no input given
                # step-size = 1 for data tensors
                batch_size = 1
            elif isinstance(x, list):
                batch_size = x[0].shape[0]
            elif isinstance(x, dict):
                batch_size = list(x.values())[0].shape[0]
            else:
                batch_size = x.shape[0]
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should contain '
                                 'at least one item.')
            steps_done += 1
            batch_sizes.append(batch_size)
            if verbose == 1:
                progbar.update(steps_done)
    finally:
        if enqueuer is not None:
            enqueuer.stop()
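    # Average the per-batch results weighted by batch size; stateful metrics
    # already accumulate across batches, so their last value is reported as-is.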
    averages = []
    for i in range(len(outs)):
        if i not in stateful_metric_indices:
            averages.append(np.average([out[i] for out in outs_per_batch],
                                       weights=batch_sizes))
        else:
            averages.append(np.float64(outs_per_batch[-1][i]))
    return unpack_singleton(averages)
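

# Usage sketch (illustrative, not part of this file): this helper is normally
# reached through `Model.evaluate_generator`. Assuming a compiled `model` and a
# `keras.utils.Sequence` named `val_seq` (both hypothetical), a direct call
# would look like:
#
#     scores = evaluate_generator(model, val_seq,
#                                 workers=2,
#                                 use_multiprocessing=False,
#                                 verbose=1)
#     # `scores` follows `model.metrics_names`, e.g. [loss, acc] for a model
#     # compiled with metrics=['acc']; a single-output, metric-free model
#     # returns just the scalar loss (see `unpack_singleton` above).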