def __init__()

in tensorflow_lattice/python/estimators.py [0:0]


  def __init__(self,
               model_config,
               feature_columns,
               feature_analysis_input_fn=None,
               feature_analysis_weight_column=None,
               feature_analysis_weight_reduction='mean',
               prefitting_input_fn=None,
               model_dir=None,
               n_classes=2,
               weight_column=None,
               label_vocabulary=None,
               optimizer='Adagrad',
               prefitting_optimizer='Adagrad',
               prefitting_steps=None,
               config=None,
               warm_start_from=None,
               loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
               loss_fn=None,
               dtype=tf.float32):
    """Initializes a `CannedClassifier` instance.

    Args:
      model_config: Model configuration object describing model architecture.
        Should be one of the model configs in `tfl.configs`. The object is
        deep-copied, so the caller's instance is not modified.
      feature_columns: An iterable containing all the feature columns used by
        the model.
      feature_analysis_input_fn: An input_fn used to calculate statistics about
        features and labels in order to setup calibration keypoint and values.
      feature_analysis_weight_column: A string or a `_NumericColumn` created by
        `tf.feature_column.numeric_column` defining feature column representing
        weights used for calculating weighted feature statistics (quantiles).
        Can be the same as `weight_column`.
      feature_analysis_weight_reduction: Reduction used on weights when
        aggregating repeated values during feature analysis. Can be either 'sum'
        or 'mean'.
      prefitting_input_fn: An input_fn used in the pre fitting stage to estimate
        non-linear feature interactions. Required for crystals models.
        Prefitting typically uses the same dataset as the main training, but
        with fewer epochs.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      n_classes: Number of label classes. Defaults to 2, namely binary
        classification. Must be > 1.
      weight_column: A string or a `_NumericColumn` created by
        `tf.feature_column.numeric_column` defining feature column representing
        weights. It is only used by the estimator head to down weight or boost
        examples during training. It will be multiplied by the loss of the
        example. If it is a string, it is used as a key to fetch the weight
        tensor from the `features` dictionary output of the input function. If
        it is a `_NumericColumn`, a raw tensor is fetched by key
        `weight_column.key`, then weight_column.normalizer_fn is applied on it
        to get the weight tensor. Note that in both cases 'weight_column' should
        *not* be a member of the 'feature_columns' parameter to the constructor
        since these will be used for both serving and training.
      label_vocabulary: A list of strings representing possible label values. If
        given, labels must be string type and have any value in
        `label_vocabulary`. If it is not given, that means labels are already
        encoded as integer or float within [0, 1] for `n_classes=2` and encoded
        as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also
        there will be errors if vocabulary is not provided and labels are
        string.
      optimizer: An instance of `tf.Optimizer` used to train the model. Can also
        be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
        callable. Defaults to Adagrad optimizer.
      prefitting_optimizer: An instance of `tf.Optimizer` used to train the
        model during the pre-fitting stage. Can also be a string (one of
        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to
        Adagrad optimizer.
      prefitting_steps: Number of steps for which to pre-train the model
        during the prefitting stage. If None, train forever or train until
        prefitting_input_fn generates the tf.errors.OutOfRangeError error or
        StopIteration exception.
      config: `RunConfig` object to configure the runtime settings.
      warm_start_from: A string filepath to a checkpoint to warm-start from, or
        a `WarmStartSettings` object to fully configure warm-starting. If the
        string filepath is provided instead of a `WarmStartSettings`, then all
        weights are warm-started, and it is assumed that vocabularies and Tensor
        names are unchanged.
      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
      loss_fn: Optional loss function.
      dtype: dtype of layers used in the model.

    Raises:
      ValueError: If `n_classes` is None or less than 2.
    """
    # Enforce the documented contract before doing any work. Otherwise a bad
    # value falls through to MultiClassHead, whose error message ("must be
    # > 2") contradicts this class's documented lower bound.
    if n_classes is None or n_classes < 2:
      raise ValueError(
          'n_classes must be > 1 for CannedClassifier: {}.'.format(n_classes))

    config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
        config, model_dir)
    model_dir = config.model_dir

    # Binary problems use a single-logit head; otherwise a multi-class head
    # with one logit per class.
    if n_classes == 2:
      head = binary_class_head.BinaryClassHead(
          weight_column=weight_column,
          label_vocabulary=label_vocabulary,
          loss_reduction=loss_reduction,
          loss_fn=loss_fn)
    else:
      head = multi_class_head.MultiClassHead(
          n_classes,
          weight_column=weight_column,
          label_vocabulary=label_vocabulary,
          loss_reduction=loss_reduction,
          loss_fn=loss_fn)

    # Number of model outputs (logits); must agree with the head chosen above.
    label_dimension = 1 if n_classes == 2 else n_classes

    # Work on a private copy: the helpers below update the config in place.
    model_config = copy.deepcopy(model_config)
    _update_by_feature_columns(model_config, feature_columns)

    # logits_output=True because a classifier head consumes raw logits.
    _finalize_keypoints(
        model_config=model_config,
        config=config,
        feature_columns=feature_columns,
        feature_analysis_input_fn=feature_analysis_input_fn,
        feature_analysis_weight_column=feature_analysis_weight_column,
        feature_analysis_weight_reduction=feature_analysis_weight_reduction,
        logits_output=True)

    _verify_config(model_config, feature_columns)

    # Finalizes the model structure; this step receives the prefitting
    # input_fn, optimizer and step count.
    _finalize_model_structure(
        label_dimension=label_dimension,
        feature_columns=feature_columns,
        head=head,
        model_config=model_config,
        prefitting_input_fn=prefitting_input_fn,
        prefitting_optimizer=prefitting_optimizer,
        prefitting_steps=prefitting_steps,
        model_dir=model_dir,
        config=config,
        warm_start_from=warm_start_from,
        dtype=dtype)

    model_fn = _get_model_fn(
        label_dimension=label_dimension,
        feature_columns=feature_columns,
        head=head,
        model_config=model_config,
        optimizer=optimizer,
        dtype=dtype)

    super(CannedClassifier, self).__init__(
        model_fn=model_fn,
        model_dir=model_dir,
        config=config,
        warm_start_from=warm_start_from)