def _on_any_batch_begin()

in smdebug/tensorflow/keras.py [0:0]


    def _on_any_batch_begin(self, batch, mode, logs=None):
        self.start = time.time()
        if self._is_not_supported():
            return

        # set mode for each batch as when users run model.fit() and pass validation data
        # through the optional argument, then mode_begin is not called for the training steps
        # after first evaluation during training
        self.set_mode(mode)

        # Write the gradients of the past step if the writer is still available.
        if self.writer is not None or len(self.writer_map):
            self._close_writers()

        # Addresses callback ordering bug in TF 2.3.0
        if self.step_incremented_in_on_train_begin is False:
            self._increment_step()
        else:
            self.step_incremented_in_on_train_begin = False

        self.profiler_config_parser.load_config()
        self.profiler_config_parser.handle_step_start_python_profiling(mode, self.mode_steps[mode])
        self._start_or_stop_dataloader_profiling(self.mode_steps[mode])

        if self.prepared_tf2_collections is False:
            # sets prepared_collections to True here
            self._prepare_collections_for_tf2()

        if self._prepared_tensors[mode] is False:
            if (is_tf_version_2x() and tf.executing_eagerly()) or self._validate_exec_function(
                self._get_exec_function(mode)
            ):
                self._prepare_layers(mode)
                self._prepare_non_layer_tensors()
                self._prepare_tensors_available_post_step()
                self._prepared_tensors[mode] = True
                # below should be after tensors are processed,
                # so we know that device map is populated
                self._set_chief_worker()
            # else:
            # this will delay the preparation of tensors as the
            # full graph is not built. Gradients are not available
            # at this stage for example

        if self._prepared_tensors[mode]:
            self._prepare_tensors_for_step(mode)
            if self.tensor_refs_to_save_this_step:
                # if saving metric, writer may not be initialized as a result
                self._initialize_writers()

            if not is_tf_version_2x() or (is_tf_version_2x() and not tf.executing_eagerly()):
                self._add_callbacks(mode)