def call()

in tensorflow_compression/python/layers/gdn.py [0:0]


  def call(self, inputs) -> tf.Tensor:
    """Applies (inverse) generalized divisive normalization to `inputs`.

    With `x` the input (passed through a ReLU first when `self.rectify` is
    set) and `c` indexing the channel axis, this computes

      pool[c] = beta[c] + sum_k gamma[k, c] * x[k] ** alpha
      y[c] = x[c] / pool[c] ** epsilon

    or `x[c] * pool[c] ** epsilon` when `self.inverse` is set. For the
    `alpha == 1` fast path without rectification, `|x|` is used in the pool.

    Args:
      inputs: Tensor-like of rank >= 2; converted to `self.dtype`.

    Returns:
      The elementwise quotient (or product, if `self.inverse`) of the
      possibly-rectified inputs and the normalization pool.

    Raises:
      ValueError: If the rank of `inputs` is unknown or smaller than 2.
    """
    x = tf.convert_to_tensor(inputs, dtype=self.dtype)
    ndim = x.shape.rank
    if ndim is None or ndim < 2:
      raise ValueError(f"Input tensor must have at least rank 2, received "
                       f"shape {x.shape}.")
    if self.rectify:
      x = tf.nn.relu(x)

    # Raise activations to the power alpha. Shortcuts apply only when
    # alpha_parameter is not callable and alpha is exactly 1 or 2.
    if callable(self.alpha_parameter):
      pool = x ** self.alpha
    elif self.alpha == 1:
      # After the ReLU every value is non-negative, so abs() would be a no-op.
      pool = x if self.rectify else abs(x)
    elif self.alpha == 2:
      pool = tf.math.square(x)
    else:
      pool = x ** self.alpha

    # Mix the powered activations across channels with gamma, then add beta.
    if ndim == 2:
      # Dense case: (batch, channels) times gamma.
      pool = tf.nn.bias_add(tf.linalg.matmul(pool, self.gamma), self.beta)
    elif self.data_format == "channels_last" and ndim <= 5:
      # Express the channel mixing as a 1x1 convolution. tf.nn.convolution
      # handles 1 to 3 spatial dimensions, hence the rank <= 5 limit.
      kernel_shape = [1] * (ndim - 2) + self.gamma.shape
      kernel = tf.reshape(self.gamma, kernel_shape)
      pool = tf.nn.bias_add(
          tf.nn.convolution(pool, kernel, padding="VALID"), self.beta)
    else:
      # Generic path: tensordot leaves the contracted channel axis last,
      # whatever the input layout.
      pool = tf.linalg.tensordot(
          pool, self.gamma, [[self._channel_axis], [0]])
      pool = pool + self.beta
      if self.data_format == "channels_first":
        # Move channels from the last axis back to axis 1.
        pool = tf.transpose(pool, [0, ndim - 1] + list(range(1, ndim - 1)))

    # Raise the pool to the power epsilon. Shortcuts apply only when
    # epsilon_parameter is not callable and epsilon is exactly 1 or 0.5.
    if callable(self.epsilon_parameter):
      pool = pool ** self.epsilon
    elif self.epsilon == 1:
      pass  # pool ** 1 is the identity.
    elif self.epsilon == .5:
      pool = tf.math.sqrt(pool)
    else:
      pool = pool ** self.epsilon

    return x * pool if self.inverse else x / pool