def create_kernel_initializer()

in tensorflow_lattice/python/lattice_layer.py [0:0]


def create_kernel_initializer(kernel_initializer_id,
                              lattice_sizes,
                              monotonicities,
                              output_min,
                              output_max,
                              unimodalities,
                              joint_unimodalities,
                              init_min=None,
                              init_max=None):
  """Converts a kernel initializer id into a Keras initializer object.

  Used to translate the 'kernel_initializer' constructor argument of
  tfl.Lattice into the concrete initializer instance for the kernel variable.

  Args:
    kernel_initializer_id: See the documentation of the 'kernel_initializer'
      parameter in the constructor of tfl.Lattice.
    lattice_sizes: See the documentation of the same parameter in the
      constructor of tfl.Lattice.
    monotonicities: See the documentation of the same parameter in the
      constructor of tfl.Lattice.
    output_min: See the documentation of the same parameter in the constructor
      of tfl.Lattice.
    output_max: See the documentation of the same parameter in the constructor
      of tfl.Lattice.
    unimodalities: See the documentation of the same parameter in the
      constructor of tfl.Lattice.
    joint_unimodalities: See the documentation of the same parameter in the
      constructor of tfl.Lattice.
    init_min: None or lower bound of kernel initialization. If set, init_max
      must also be set.
    init_max: None or upper bound of kernel initialization. If set, init_min
      must also be set.

  Returns:
    The Keras initializer object for the tfl.Lattice kernel variable.

  Raises:
    ValueError: If only one of init_{min/max} is set.
  """
  # init_min and init_max must be provided together or not at all.
  if (init_min is None) != (init_max is None):
    raise ValueError("Both or neither of init_{min/max} must be set")

  def _bounds():
    # Explicit init bounds win; otherwise derive defaults from output range.
    if init_min is None and init_max is None:
      return lattice_lib.default_init_params(output_min, output_max)
    return init_min, init_max

  def _single_joint_group_spans_all_dims(joint):
    # True iff there is exactly one joint-unimodality group and its
    # dimensions cover every lattice dimension.
    if joint is None or len(joint) != 1:
      return False
    (sole_group,) = joint
    return set(sole_group[0]) == set(range(len(lattice_sizes)))

  # Fold per-feature and joint unimodalities into a single per-dimension list;
  # joint settings are applied last and thus take precedence.
  combined_unimodalities = [0] * len(lattice_sizes)
  if unimodalities:
    for dim, direction in enumerate(unimodalities):
      if direction:
        combined_unimodalities[dim] = direction
  if joint_unimodalities:
    for dims, direction in joint_unimodalities:
      for dim in dims:
        combined_unimodalities[dim] = direction

  if kernel_initializer_id in ["linear_initializer", "LinearInitializer"]:
    lower, upper = _bounds()
    return LinearInitializer(
        lattice_sizes=lattice_sizes,
        monotonicities=monotonicities,
        output_min=lower,
        output_max=upper,
        unimodalities=combined_unimodalities)
  elif kernel_initializer_id in [
      "random_monotonic_initializer", "RandomMonotonicInitializer"
  ]:
    lower, upper = _bounds()
    return RandomMonotonicInitializer(
        lattice_sizes=lattice_sizes,
        output_min=lower,
        output_max=upper,
        unimodalities=combined_unimodalities)
  elif kernel_initializer_id in [
      "random_uniform_or_linear_initializer", "RandomUniformOrLinearInitializer"
  ]:
    # Fall back to uniform init when a single joint-unimodality group covers
    # all dimensions; otherwise delegate to the linear initializer.
    if _single_joint_group_spans_all_dims(joint_unimodalities):
      delegate_id = "random_uniform"
    else:
      delegate_id = "linear_initializer"
    return create_kernel_initializer(delegate_id, lattice_sizes,
                                     monotonicities, output_min, output_max,
                                     unimodalities, joint_unimodalities,
                                     init_min, init_max)
  else:
    # This is needed for Keras deserialization logic to be aware of our custom
    # objects.
    with keras.utils.custom_object_scope({
        "LinearInitializer": LinearInitializer,
        "RandomMonotonicInitializer": RandomMonotonicInitializer,
    }):
      return keras.initializers.get(kernel_initializer_id)