in tensorflow_lattice/python/rtl_layer.py [0:0]
def build(self, input_shape):
  """Standard Keras build() method.

  Verifies the hyperparameters against `input_shape`, computes the RTL
  structure via `self._get_rtl_structure` and creates one lattice layer for
  every `(monotonicities, inputs_for_units)` entry of that structure. The
  created layers are stored in `self._lattice_layers`, keyed by
  `str(monotonicities)`.

  Args:
    input_shape: Shape of the layer input. Forwarded unchanged to
      `rtl_lib.verify_hyperparameters`, `self._get_rtl_structure` and the
      parent class `build()`.

  Raises:
    ValueError: If `self.parameterization` is neither `'all_vertices'` nor
      `'kronecker_factored'` (raised while building the first group of the
      RTL structure).
  """
  rtl_lib.verify_hyperparameters(
      lattice_size=self.lattice_size, input_shape=input_shape)

  # Convert kernel regularizers to proper form (tuples). A list whose first
  # element is a string is converted to one tuple as a whole (presumably a
  # single regularizer spec); any other list has each element converted to a
  # tuple. Non-list values are passed through untouched.
  kernel_regularizer = self.kernel_regularizer
  if isinstance(kernel_regularizer, list):
    # The truthiness check protects the `[0]` lookup: an empty list used to
    # raise IndexError here. It now takes the per-item branch and stays an
    # empty list.
    if kernel_regularizer and isinstance(kernel_regularizer[0],
                                         six.string_types):
      kernel_regularizer = tuple(kernel_regularizer)
    else:
      kernel_regularizer = [tuple(r) for r in kernel_regularizer]

  self._rtl_structure = self._get_rtl_structure(input_shape)
  # dict from monotonicities to the lattice layers with those monotonicities.
  self._lattice_layers = {}
  for monotonicities, inputs_for_units in self._rtl_structure:
    # Compact form of the monotonicities, used only to build op/layer names.
    # Note that the dict key below is `str(monotonicities)`, not this string.
    monotonicities_str = ''.join(
        str(monotonicity) for monotonicity in monotonicities)
    # Passthrough names for reconstructing model graph.
    inputs_for_units_name = '{}_{}'.format(INPUTS_FOR_UNITS_PREFIX,
                                           monotonicities_str)
    # Use control dependencies to save inputs_for_units as graph constant for
    # visualisation toolbox to be able to recover it from saved graph.
    # Wrap this constant into pure op since in TF 2.0 there are issues passing
    # tensors into control_dependencies.
    with tf.control_dependencies([
        tf.constant(
            inputs_for_units, dtype=tf.int32, name=inputs_for_units_name)
    ]):
      # One lattice unit per entry of inputs_for_units.
      units = len(inputs_for_units)
      if self.parameterization == 'all_vertices':
        layer_name = '{}_{}'.format(RTL_LATTICE_NAME, monotonicities_str)
        # Every lattice dimension uses the same size; rank sets the count.
        lattice_sizes = [self.lattice_size] * self.lattice_rank
        kernel_initializer = lattice_layer.create_kernel_initializer(
            kernel_initializer_id=self.kernel_initializer,
            lattice_sizes=lattice_sizes,
            monotonicities=monotonicities,
            output_min=self.output_min,
            output_max=self.output_max,
            unimodalities=None,
            joint_unimodalities=None,
            init_min=self.init_min,
            init_max=self.init_max)
        self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice(
            lattice_sizes=lattice_sizes,
            units=units,
            monotonicities=monotonicities,
            output_min=self.output_min,
            output_max=self.output_max,
            num_projection_iterations=self.num_projection_iterations,
            monotonic_at_every_step=self.monotonic_at_every_step,
            clip_inputs=self.clip_inputs,
            interpolation=self.interpolation,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            name=layer_name,
        )
      elif self.parameterization == 'kronecker_factored':
        layer_name = '{}_{}'.format(RTL_KFL_NAME, monotonicities_str)
        kernel_initializer = kfll.create_kernel_initializer(
            kernel_initializer_id=self.kernel_initializer,
            monotonicities=monotonicities,
            output_min=self.output_min,
            output_max=self.output_max,
            init_min=self.init_min,
            init_max=self.init_max)
        # NOTE(review): kernel_regularizer is not forwarded on this branch,
        # and scale_initializer is passed as the literal string id
        # 'scale_initializer' -- both kept exactly as in the original code.
        self._lattice_layers[str(
            monotonicities)] = kfll.KroneckerFactoredLattice(
                lattice_sizes=self.lattice_size,
                units=units,
                num_terms=self.num_terms,
                monotonicities=monotonicities,
                output_min=self.output_min,
                output_max=self.output_max,
                clip_inputs=self.clip_inputs,
                kernel_initializer=kernel_initializer,
                scale_initializer='scale_initializer',
                name=layer_name)
      else:
        raise ValueError('Unknown type of parameterization: {}'.format(
            self.parameterization))
  super(RTL, self).build(input_shape)