in tensorflow_compression/python/layers/gdn.py [0:0]
def call(self, inputs) -> tf.Tensor:
  inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
  rank = inputs.shape.rank
  if rank is None or rank < 2:
    raise ValueError(f"Input tensor must have at least rank 2, received "
                     f"shape {inputs.shape}.")
  if self.rectify:
    inputs = tf.nn.relu(inputs)

  # Optimize for fixed alphas.
  if not callable(self.alpha_parameter) and self.alpha == 1 and self.rectify:
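    # The rectified inputs are nonnegative, so they equal their own absolute
    # values; |x|^1 is then just x, and the pool can reuse `inputs` directly.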
    norm_pool = inputs
  elif not callable(self.alpha_parameter) and self.alpha == 1:
    norm_pool = abs(inputs)
  elif not callable(self.alpha_parameter) and self.alpha == 2:
    norm_pool = tf.math.square(inputs)
  else:
    norm_pool = inputs ** self.alpha
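
  # At this point norm_pool holds |x|^alpha per channel. The steps below
  # turn it into the GDN denominator (beta_i + sum_j gamma[j, i] |x_j|^alpha)
  # raised to epsilon; the classic GDN of Balle et al. corresponds to
  # alpha = 2, epsilon = 0.5.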
  # Compute normalization pool.
  if rank == 2:
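    # Rank-2 inputs are (batch, channels), so the pool across channels is a
    # plain matrix product with gamma followed by the beta bias.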
    norm_pool = tf.linalg.matmul(norm_pool, self.gamma)
    norm_pool = tf.nn.bias_add(norm_pool, self.beta)
  elif self.data_format == "channels_last" and rank <= 5:
    shape = self.gamma.shape.as_list()
    gamma = tf.reshape(self.gamma, (rank - 2) * [1] + shape)
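    # gamma is (channels, channels); reshaping it to (1, ..., 1, C, C) makes
    # tf.nn.convolution act as a pointwise (1x1) convolution over the channel
    # dimension, equivalent to the matmul above but supporting 3D-5D tensors.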
    norm_pool = tf.nn.convolution(norm_pool, gamma, padding="VALID")
    norm_pool = tf.nn.bias_add(norm_pool, self.beta)
  else:  # generic implementation
    # This puts channels in the last dimension regardless of input.
    norm_pool = tf.linalg.tensordot(
        norm_pool, self.gamma, [[self._channel_axis], [0]])
    norm_pool += self.beta
    if self.data_format == "channels_first":
      # Return to channels_first format if necessary.
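      # Builds the permutation (0, rank-1, 1, ..., rank-2), which moves the
      # channel dimension from the last axis back to axis 1.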
      axes = list(range(rank - 1))
      axes.insert(1, rank - 1)
      norm_pool = tf.transpose(norm_pool, axes)

  # Optimize for fixed epsilons.
  if not callable(self.epsilon_parameter) and self.epsilon == 1:
    pass
  elif not callable(self.epsilon_parameter) and self.epsilon == .5:
    norm_pool = tf.math.sqrt(norm_pool)
  else:
    norm_pool = norm_pool ** self.epsilon
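
  # Forward GDN divides by the normalization pool; the inverse transform
  # (IGDN, typically used on the synthesis/decoder side) multiplies by it.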
  if self.inverse:
    return inputs * norm_pool
  else:
    return inputs / norm_pool
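
# A minimal usage sketch, not part of this file: it assumes the public
# tensorflow_compression package exposes this layer as tfc.GDN with an
# `inverse` constructor argument, and that the installed version's defaults
# apply; the layer is then called like any other Keras layer.
#
#   import tensorflow as tf
#   import tensorflow_compression as tfc
#
#   x = tf.random.normal([1, 32, 32, 64])  # (batch, height, width, channels)
#   gdn = tfc.GDN()                        # forward (divisive) normalization
#   igdn = tfc.GDN(inverse=True)           # multiplicative counterpart
#   y = gdn(x)
#   x_hat = igdn(y)                        # approximately undoes gdn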