in keras/engine/training_generator.py [0:0]
import warnings

from .training_utils import iter_sequence_infinite
from .. import backend as K
from .. import callbacks as cbks
from ..utils.data_utils import Sequence
from ..utils.data_utils import GeneratorEnqueuer
from ..utils.data_utils import OrderedEnqueuer
from ..utils.generic_utils import to_list


def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
"""See docstring for `Model.fit_generator`."""
wait_time = 0.01 # in seconds
epoch = initial_epoch
do_validation = bool(validation_data)
model._make_train_function()
if do_validation:
model._make_test_function()
is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the `keras.utils.Sequence`'
                        ' class.'))
    if steps_per_epoch is None:
        if is_sequence:
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the '
                             '`keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` '
                             'or use the `keras.utils.Sequence` class.')

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next') or
               hasattr(validation_data, '__next__') or
               isinstance(validation_data, Sequence))
    if (val_gen and not isinstance(validation_data, Sequence) and
            not validation_steps):
        raise ValueError('`validation_steps=None` is only valid for a'
                         ' generator based on the `keras.utils.Sequence`'
                         ' class. Please specify `validation_steps` or use'
                         ' the `keras.utils.Sequence` class.')
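
    # Note: a plain validation generator has no notion of its length, so
    # `validation_steps` must be supplied up front; a `Sequence` reports its
    # own length and falls back to `len(val_data)` below.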

    # Prepare display labels.
    out_labels = model.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    model.history = cbks.History()
    _callbacks = [cbks.BaseLogger(
        stateful_metrics=model.stateful_metric_names)]
    if verbose:
        _callbacks.append(
            cbks.ProgbarLogger(
                count_mode='steps',
                stateful_metrics=model.stateful_metric_names))
    _callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(_callbacks)

    # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    enqueuer = None
    val_enqueuer = None

    try:
        if do_validation:
            if val_gen and workers > 0:
                # Create an Enqueuer that can be reused
                val_data = validation_data
                if isinstance(val_data, Sequence):
                    val_enqueuer = OrderedEnqueuer(
                        val_data,
                        use_multiprocessing=use_multiprocessing)
                    validation_steps = validation_steps or len(val_data)
                else:
                    val_enqueuer = GeneratorEnqueuer(
                        val_data,
                        use_multiprocessing=use_multiprocessing)
                val_enqueuer.start(workers=workers,
                                   max_queue_size=max_queue_size)
                val_enqueuer_gen = val_enqueuer.get()
            elif val_gen:
                val_data = validation_data
                if isinstance(val_data, Sequence):
                    val_enqueuer_gen = iter_sequence_infinite(val_data)
                    validation_steps = validation_steps or len(val_data)
                else:
                    val_enqueuer_gen = val_data
            else:
                # Prepare data for validation
                if len(validation_data) == 2:
                    val_x, val_y = validation_data
                    val_sample_weight = None
                elif len(validation_data) == 3:
                    val_x, val_y, val_sample_weight = validation_data
                else:
                    raise ValueError('`validation_data` should be a tuple '
                                     '`(val_x, val_y, val_sample_weight)` '
                                     'or `(val_x, val_y)`. Found: ' +
                                     str(validation_data))
                val_x, val_y, val_sample_weights = model._standardize_user_data(
                    val_x, val_y, val_sample_weight)
                val_data = val_x + val_y + val_sample_weights
                if model.uses_learning_phase and not isinstance(
                        K.learning_phase(), int):
                    val_data += [0.]
                for cbk in callbacks:
                    cbk.validation_data = val_data
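
        # Validation is now prepared: either `val_enqueuer_gen` yields
        # batches (generator/Sequence path) or `val_x`/`val_y`/
        # `val_sample_weights` hold arrays (tuple path).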
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter_sequence_infinite(generator)
            else:
                output_generator = generator
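
        # `output_generator` now yields training batches, either through
        # background workers (workers > 0) or directly in the calling thread.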
        callback_model.stop_training = False
        # Construct epoch logs.
        epoch_logs = {}
        while epoch < epochs:
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                # build batch logs
                batch_logs = {}
                if x is None or len(x) == 0:
                    # Handle data tensors support when no input given
                    # step-size = 1 for data tensors
                    batch_size = 1
                elif isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
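                # `batch_size` is inferred from the inputs (1 for data
                # tensors); it feeds the batch logs and, in the
                # array-validation branch below, the evaluation batch size.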
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                outs = model.train_on_batch(x, y,
                                            sample_weight=sample_weight,
                                            class_weight=class_weight)

                outs = to_list(outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = model.evaluate_generator(
                            val_enqueuer_gen,
                            validation_steps,
                            workers=0)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        val_outs = model.evaluate(
                            val_x, val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    val_outs = to_list(val_outs)
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break
    finally:
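        # Shut down enqueuers whether training finished or raised; the nested
        # try/finally ensures the validation enqueuer is stopped even if
        # stopping the training enqueuer fails.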
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

    callbacks.on_train_end()
    return model.history
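

# Illustrative usage (a sketch, not part of this module): `fit_generator` is
# normally reached through the public `Model.fit_generator` method, e.g.
#
#     model.fit_generator(train_gen,
#                         steps_per_epoch=500,
#                         epochs=10,
#                         validation_data=val_gen,
#                         validation_steps=50,
#                         workers=2)
#
# where `train_gen` and `val_gen` are hypothetical `keras.utils.Sequence`
# instances or generators yielding `(x, y)` or `(x, y, sample_weight)` tuples.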