in smdebug/tensorflow/keras.py [0:0]
def _on_any_batch_begin(self, batch, mode, logs=None):
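    """Shared logic run at the beginning of every train/eval/predict batch:
    sets the mode, closes writers from the previous step, increments the step,
    drives profiling, and prepares tensors and writers as needed."""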
    self.start = time.time()
    if self._is_not_supported():
        return
    # Set the mode for each batch: when users run model.fit() and pass validation
    # data through the optional argument, mode_begin is not called again for the
    # training steps that follow the first evaluation during training.
    self.set_mode(mode)
    # Write the gradients of the past step if the writer is still available.
    if self.writer is not None or len(self.writer_map):
        self._close_writers()
    # Addresses a callback ordering bug in TF 2.3.0: skip the increment if the
    # step was already incremented in on_train_begin.
    if self.step_incremented_in_on_train_begin is False:
        self._increment_step()
    else:
        self.step_incremented_in_on_train_begin = False
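
    # Refresh the profiler configuration and manage python and dataloader
    # profiling for the step that is about to start.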
    self.profiler_config_parser.load_config()
    self.profiler_config_parser.handle_step_start_python_profiling(mode, self.mode_steps[mode])
    self._start_or_stop_dataloader_profiling(self.mode_steps[mode])
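
    # Collections only need to be prepared once; this happens lazily on the
    # first batch that reaches this point.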
    if self.prepared_tf2_collections is False:
        # sets prepared_collections to True here
        self._prepare_collections_for_tf2()
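
    # Tensors are prepared once per mode, and only when eager execution is
    # active or a valid execution function exists for this mode.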
    if self._prepared_tensors[mode] is False:
        if (is_tf_version_2x() and tf.executing_eagerly()) or self._validate_exec_function(
            self._get_exec_function(mode)
        ):
            self._prepare_layers(mode)
            self._prepare_non_layer_tensors()
            self._prepare_tensors_available_post_step()
            self._prepared_tensors[mode] = True
            # The call below must come after tensors are processed,
            # so that the device map is populated.
            self._set_chief_worker()
        # else: tensor preparation is delayed because the full graph has not
        # been built yet; gradients, for example, are not available at this stage.
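
    # Once tensors have been prepared for this mode, select the tensors to save
    # for this step and make sure writers exist.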
    if self._prepared_tensors[mode]:
        self._prepare_tensors_for_step(mode)
        if self.tensor_refs_to_save_this_step:
            # When only metrics are being saved, the writer may not have been
            # initialized yet, so initialize it here.
            self._initialize_writers()
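
        # _add_callbacks is only needed in graph mode (TF 1.x, or TF 2.x with
        # eager execution disabled).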
        if not is_tf_version_2x() or (is_tf_version_2x() and not tf.executing_eagerly()):
            self._add_callbacks(mode)
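
# Hedged sketch (an assumption, not part of this excerpt): the public Keras
# hooks are expected to delegate to this helper with the matching mode,
# roughly like the following.
#
# def on_train_batch_begin(self, batch, logs=None):
#     self._on_any_batch_begin(batch, ModeKeys.TRAIN, logs=logs)
#
# def on_test_batch_begin(self, batch, logs=None):
#     self._on_any_batch_begin(batch, ModeKeys.EVAL, logs=logs)
#
# def on_predict_batch_begin(self, batch, logs=None):
#     self._on_any_batch_begin(batch, ModeKeys.PREDICT, logs=logs)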