in tensorflow_lattice/python/lattice_layer.py [0:0]
def create_kernel_initializer(kernel_initializer_id,
                              lattice_sizes,
                              monotonicities,
                              output_min,
                              output_max,
                              unimodalities,
                              joint_unimodalities,
                              init_min=None,
                              init_max=None):
  """Converts a kernel initializer id into a Keras initializer object.

  The tfl.Lattice constructor uses this to turn its 'kernel_initializer'
  argument into the initializer for the lattice kernel variable.

  Special-cased ids:
    * 'linear_initializer' / 'LinearInitializer' -> LinearInitializer.
    * 'random_monotonic_initializer' / 'RandomMonotonicInitializer' ->
      RandomMonotonicInitializer.
    * 'random_uniform_or_linear_initializer' /
      'RandomUniformOrLinearInitializer' -> Keras 'random_uniform' when
      `joint_unimodalities` holds exactly one group whose dimensions are
      all lattice dimensions, otherwise the linear initializer above.
  Any other value is passed to `keras.initializers.get` while
  LinearInitializer and RandomMonotonicInitializer are registered as custom
  objects.

  Args:
    kernel_initializer_id: Same as the 'kernel_initializer' parameter of the
      tfl.Lattice constructor.
    lattice_sizes: Same as the parameter of the same name of tfl.Lattice.
    monotonicities: Same as the parameter of the same name of tfl.Lattice.
      Only forwarded to LinearInitializer.
    output_min: Same as the parameter of the same name of tfl.Lattice.
    output_max: Same as the parameter of the same name of tfl.Lattice.
    unimodalities: Same as the parameter of the same name of tfl.Lattice.
    joint_unimodalities: Same as the parameter of the same name of
      tfl.Lattice. Iterated as (dimensions, direction) pairs.
    init_min: None or lower bound of kernel initialization. Must be set
      together with init_max.
    init_max: None or upper bound of kernel initialization. Must be set
      together with init_min.

  Returns:
    The Keras initializer object for the tfl.Lattice kernel variable.

  Raises:
    ValueError: If exactly one of init_min / init_max is set.
  """
  # Exactly one bound given is an error; both or neither is fine.
  if (init_min is None) != (init_max is None):
    raise ValueError("Both or neither of init_{min/max} must be set")

  num_dims = len(lattice_sizes)

  # Build one direction per lattice dimension. Dimensions with no constraint
  # keep 0; per-feature values that are falsy are skipped, and joint groups
  # are applied last so they override per-feature entries.
  merged_unimodalities = [0] * num_dims
  for dim, direction in enumerate(unimodalities or ()):
    if direction:
      merged_unimodalities[dim] = direction
  for group_dims, direction in (joint_unimodalities or ()):
    for dim in group_dims:
      merged_unimodalities[dim] = direction

  def _init_bounds():
    # Explicit bounds win; otherwise derive them from the output range.
    if init_min is None and init_max is None:
      return lattice_lib.default_init_params(output_min, output_max)
    return init_min, init_max

  def _single_joint_group_covers_all_dims():
    # True only for exactly one joint group spanning every lattice dimension.
    if joint_unimodalities is None or len(joint_unimodalities) != 1:
      return False
    (only_group,) = joint_unimodalities
    return set(only_group[0]) == set(range(num_dims))

  def _linear_initializer():
    lower, upper = _init_bounds()
    return LinearInitializer(
        lattice_sizes=lattice_sizes,
        monotonicities=monotonicities,
        output_min=lower,
        output_max=upper,
        unimodalities=merged_unimodalities)

  def _keras_initializer(identifier):
    # This is needed for Keras deserialization logic to be aware of our
    # custom objects.
    with keras.utils.custom_object_scope({
        "LinearInitializer": LinearInitializer,
        "RandomMonotonicInitializer": RandomMonotonicInitializer,
    }):
      return keras.initializers.get(identifier)

  # Tuples (not sets) keep membership tests working for unhashable ids such
  # as config dicts, which then fall through to Keras.
  if kernel_initializer_id in ("linear_initializer", "LinearInitializer"):
    return _linear_initializer()

  if kernel_initializer_id in ("random_monotonic_initializer",
                               "RandomMonotonicInitializer"):
    lower, upper = _init_bounds()
    return RandomMonotonicInitializer(
        lattice_sizes=lattice_sizes,
        output_min=lower,
        output_max=upper,
        unimodalities=merged_unimodalities)

  if kernel_initializer_id in ("random_uniform_or_linear_initializer",
                               "RandomUniformOrLinearInitializer"):
    if _single_joint_group_covers_all_dims():
      # NOTE: neither init_{min/max} nor output_{min/max} is forwarded to
      # the Keras 'random_uniform' initializer on this path.
      return _keras_initializer("random_uniform")
    return _linear_initializer()

  return _keras_initializer(kernel_initializer_id)