def call()

in tensorflow_hub/keras_layer.py [0:0]


  def call(self, inputs, training=None):
    """Invokes the wrapped callable on `inputs` and post-processes the result.

    Args:
      inputs: Passed to the wrapped callable. When the layer has a signature
        and `inputs` is a dict, its items are merged into the stored extra
        arguments and passed as keyword arguments; otherwise `inputs` is
        passed as the single positional argument.
      training: Only consulted when the wrapped callable accepts a `training`
        argument. It is forced to False when the layer is not trainable. If it
        is None on a trainable layer, Keras' `learning_phase()` is used
        instead, and that value may be a tensor.

    Returns:
      The callable's output. If `output_key` is set, only that entry of the
      returned dict is kept. The value is then passed through
      `self._apply_output_shape_if_set(inputs, ...)`.

    Raises:
      ValueError: if `output_key` is set and the output is not a dict, or the
        dict lacks that key. May also propagate whatever
        `self._check_trainability()` raises (not visible here).
    """
    # Checked on every call instead of in __init__ because `self.trainable`
    # is a public attribute that can be changed after construction.
    self._check_trainability()

    # Build the call. The stored extra arguments are copied so that merging
    # dict inputs never mutates them.
    call_kwargs = self._arguments.copy()
    if self._signature and isinstance(inputs, dict):
      call_kwargs.update(inputs)
      call_args = ()
    else:
      call_args = (inputs,)
    run = functools.partial(self._callable, *call_args, **call_kwargs)

    if self._has_training_argument:
      # The Python boolean handed to the callable is the logical "and" of this
      # layer's trainability and the surrounding model's mode (analogous to
      # tf.keras.layers.BatchNormalization in TF2).
      if not self.trainable:
        # Behave like BatchNormalization. (Dropout is different, b/181839368.)
        training = False
      elif training is None:
        # No explicit argument: fall back to Keras' global learning phase,
        # which might actually be a tensor, hence smart_cond below.
        training = tf.keras.backend.learning_phase()
      result = smart_cond.smart_cond(
          training,
          lambda: run(training=True),
          lambda: run(training=False))
    else:
      result = run()

    # Signatures return dicts; optionally select a single entry.
    selected_key = self._output_key
    if selected_key:
      if not isinstance(result, dict):
        raise ValueError("Specifying `output_key` is forbidden if output "
                         "type %s is not a dict." % type(result))
      if selected_key not in result:
        raise ValueError(
            "KerasLayer output does not contain the output key %s "
            "(available: %s)." % (selected_key, result.keys()))
      result = result[selected_key]

    return self._apply_output_shape_if_set(inputs, result)