def __init__()

in tensorflow_lattice/python/pwl_calibration_sonnet_module.py [0:0]


  def __init__(self,
               input_keypoints,
               units=1,
               output_min=None,
               output_max=None,
               clamp_min=False,
               clamp_max=False,
               monotonicity="none",
               convexity="none",
               is_cyclic=False,
               kernel_init="equal_heights",
               impute_missing=False,
               missing_input_value=None,
               missing_output_value=None,
               num_projection_iterations=8,
               **kwargs):
    # pyformat: disable
    """Initializes an instance of `PWLCalibration`.

    Validates the hyperparameters, stores them as attributes and resolves
    `kernel_init` into the initializer kept in `self.kernel_init`.

    Args:
      input_keypoints: Ordered list of keypoints of piecewise linear function.
        Can be anything accepted by tf.convert_to_tensor().
      units: Output dimension of the layer. See class comments for details.
      output_min: Minimum output of calibrator.
      output_max: Maximum output of calibrator.
      clamp_min: For monotonic calibrators ensures that output_min is reached.
      clamp_max: For monotonic calibrators ensures that output_max is reached.
      monotonicity: Constrains piecewise linear function to be monotonic using
        'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or
        -1 to indicate decreasing monotonicity and 'none' or 0 to indicate no
        monotonicity constraints.
      convexity: Constrains piecewise linear function to be convex or concave.
        Convexity is indicated by 'convex' or 1, concavity is indicated by
        'concave' or -1, 'none' or 0 indicates no convexity/concavity
        constraints.
        Concavity together with increasing monotonicity as well as convexity
        together with decreasing monotonicity results in diminishing return
        constraints.
        Consider increasing the value of `num_projection_iterations` if
        convexity is specified, especially with larger number of keypoints.
      is_cyclic: Whether the output for last keypoint should be identical to
        output for first keypoint. This is useful for features such as
        "time of day" or "degree of turn". If inputs are discrete and exactly
        match keypoints then is_cyclic will have an effect only if TFL
        regularizers are being used.
      kernel_init: None or one of:
        - String `"equal_heights"`: For pieces of pwl function to have equal
          heights.
        - String `"equal_slopes"`: For pieces of pwl function to have equal
          slopes.
        - Any Sonnet initializer object. If you are passing such object make
          sure that you know how this module uses the variables.
        Any value other than the two strings above is stored unchanged in
        `self.kernel_init`.
      impute_missing: Whether to learn an output for cases where input data is
        missing. If set to True, either `missing_input_value` should be
        initialized, or the `call()` method should get pair of tensors. See
        class input shape description for more details.
      missing_input_value: If set, all inputs which are equal to this value will
        be considered as missing. Cannot be set if `impute_missing` is False.
      missing_output_value: If set, instead of learning output for missing
        inputs, simply maps them into this value. Cannot be set if
        `impute_missing` is False.
      num_projection_iterations: Number of iterations of the Dykstra's
        projection algorithm. Constraints are strictly satisfied at the end of
        each update, but the update will be closer to a true L2 projection with
        higher number of iterations. See
        `tfl.pwl_calibration_lib.project_all_constraints` for more details.
      **kwargs: Other args passed to `snt.Module` initializer.

    Raises:
      ValueError: If layer hyperparameters are invalid, if
        `missing_input_value` or `missing_output_value` is given while
        `impute_missing` is False, or if `input_keypoints` or `monotonicity`
        is None.
    """
    # pyformat: enable
    super(PWLCalibration, self).__init__(**kwargs)

    # Shared hyperparameter validation lives in the library helper.
    pwl_calibration_lib.verify_hyperparameters(
        input_keypoints=input_keypoints,
        output_min=output_min,
        output_max=output_max,
        monotonicity=monotonicity,
        convexity=convexity,
        is_cyclic=is_cyclic)
    # Missing-value settings are only meaningful when imputation is enabled.
    if missing_input_value is not None and not impute_missing:
      raise ValueError("'missing_input_value' is specified, but "
                       "'impute_missing' is set to False. "
                       "'missing_input_value': " + str(missing_input_value))
    if missing_output_value is not None and not impute_missing:
      raise ValueError("'missing_output_value' is specified, but "
                       "'impute_missing' is set to False. "
                       "'missing_output_value': " + str(missing_output_value))
    if input_keypoints is None:
      raise ValueError("'input_keypoints' can't be None")
    if monotonicity is None:
      raise ValueError("'monotonicity' can't be None. Did you mean '0'?")

    self.input_keypoints = input_keypoints
    self.units = units
    self.output_min = output_min
    self.output_max = output_max
    self.clamp_min = clamp_min
    self.clamp_max = clamp_max
    # Derive the initialization bounds and the output bound constraints from
    # the user-facing min/max/clamp settings.
    (self._output_init_min, self._output_init_max, self._output_min_constraints,
     self._output_max_constraints
    ) = pwl_calibration_lib.convert_all_constraints(self.output_min,
                                                    self.output_max,
                                                    self.clamp_min,
                                                    self.clamp_max)

    self.monotonicity = monotonicity
    self.convexity = convexity
    self.is_cyclic = is_cyclic

    if kernel_init == "equal_heights":
      self.kernel_init = _UniformOutputInitializer(
          output_min=self._output_init_min,
          output_max=self._output_init_max,
          monotonicity=self.monotonicity)
    elif kernel_init == "equal_slopes":
      # Passing the keypoints is what distinguishes this from "equal_heights".
      self.kernel_init = _UniformOutputInitializer(
          output_min=self._output_init_min,
          output_max=self._output_init_max,
          monotonicity=self.monotonicity,
          keypoints=self.input_keypoints)
    else:
      # The docstring also allows None or an initializer object. Keep whatever
      # the caller supplied so that `self.kernel_init` is always defined
      # instead of being silently dropped.
      self.kernel_init = kernel_init

    self.impute_missing = impute_missing
    self.missing_input_value = missing_input_value
    self.missing_output_value = missing_output_value
    self.num_projection_iterations = num_projection_iterations