in keras/engine/training_generator.py [0:0]
def predict_generator(model, generator,
                      steps=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
    """See docstring for `Model.predict_generator`.

    Draws `steps` batches from `generator`, calls
    `model.predict_on_batch` on the inputs of each batch and gathers
    the per-batch predictions.

    # Arguments
        model: object providing `_make_predict_function()` and
            `predict_on_batch(x)`.
        generator: a generator or a `Sequence` instance. Each item is
            either the inputs `x` alone, or a tuple `(x, y)` or
            `(x, y, sample_weight)` of which only `x` is used.
        steps: number of batches to draw. If `None`, `len(generator)`
            is used, which is only possible for a `Sequence`.
        max_queue_size: forwarded to the enqueuer's `start()`.
        workers: forwarded to the enqueuer's `start()`. If not greater
            than 0, no enqueuer is created and batches are pulled
            directly from `generator`.
        use_multiprocessing: forwarded to the enqueuer constructor.
        verbose: if equal to 1, a `Progbar` is updated after each batch.

    # Returns
        If the model produces a single output: that output for the
        only batch when exactly one batch was processed, otherwise the
        `np.concatenate` of the per-batch outputs. If the model
        produces several outputs: a list with one such entry per
        output. An empty list if no batch was processed.

    # Raises
        ValueError: if `steps` is `None` and `generator` is not a
            `Sequence`, or if the generator yields a tuple whose
            length is neither 2 nor 3.
    """
    model._make_predict_function()

    steps_done = 0
    wait_time = 0.01
    # One inner list per model output; each inner list collects the
    # predictions of that output, one entry per processed batch.
    all_outs = []
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the `keras.utils.Sequence`'
                        ' class.'))
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter_sequence_infinite(generator)
            else:
                output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)

        while steps_done < steps:
            generator_output = next(output_generator)
            if isinstance(generator_output, tuple):
                # Compatibility with the generators
                # used for training.
                if len(generator_output) == 2:
                    x, _ = generator_output
                elif len(generator_output) == 3:
                    x, _, _ = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
            else:
                # Assumes a generator that only
                # yields inputs (not targets and sample weights).
                x = generator_output

            outs = to_list(model.predict_on_batch(x))

            if not all_outs:
                # The number of model outputs is only known once a
                # prediction is available, so the buckets are created
                # lazily.
                all_outs = [[] for _ in outs]

            for i, out in enumerate(outs):
                all_outs[i].append(out)
            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)
    finally:
        # Always shut the enqueuer down, including when the loop raised.
        if enqueuer is not None:
            enqueuer.stop()

    if steps_done == 1:
        # A single batch is returned as produced, without concatenation.
        results = [batches[0] for batches in all_outs]
    else:
        results = [np.concatenate(batches) for batches in all_outs]

    # A single-output model returns its result directly, not in a list.
    if len(results) == 1:
        return results[0]
    return results