in tensorflow_lattice/python/estimators.py [0:0]
def __init__(self,
model_config,
feature_columns,
feature_analysis_input_fn=None,
feature_analysis_weight_column=None,
feature_analysis_weight_reduction='mean',
prefitting_input_fn=None,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
prefitting_optimizer='Adagrad',
prefitting_steps=None,
config=None,
warm_start_from=None,
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
loss_fn=None,
dtype=tf.float32):
"""Initializes a `CannedClassifier` instance.
Args:
model_config: Model configuration object describing the model architecture.
Should be one of the model configs in `tfl.configs`.
feature_columns: An iterable containing all the feature columns used by
the model.
feature_analysis_input_fn: An input_fn used to calculate statistics about
features and labels in order to set up calibration keypoints and values.
feature_analysis_weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining the feature column that
represents weights used to calculate weighted feature statistics
(quantiles). Can be the same as `weight_column`.
feature_analysis_weight_reduction: Reduction used on weights when
aggregating repeated values during feature analysis. Can be either 'sum'
or 'mean'.
prefitting_input_fn: An input_fn used in the pre-fitting stage to estimate
non-linear feature interactions. Required for crystals models.
Prefitting typically uses the same dataset as the main training, but
with fewer epochs.
model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator to
continue training a previously saved model.
n_classes: Number of label classes. Defaults to 2, namely binary
classification. Must be > 1.
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining the feature column that
represents weights. It is only used by the estimator head to down-weight
or boost examples during training, and is multiplied by the loss of the
example. If it is a string, it is used as a key to fetch the weight
tensor from the `features` dictionary output of the input function. If
it is a `_NumericColumn`, a raw tensor is fetched by key
`weight_column.key`, then weight_column.normalizer_fn is applied on it
to get the weight tensor. Note that in both cases `weight_column` should
*not* be a member of the `feature_columns` parameter to the constructor,
since feature columns are used for both serving and training.
label_vocabulary: A list of strings representing possible label values. If
given, labels must be of string type and take values from
`label_vocabulary`. If it is not given, labels are assumed to be already
encoded as integers or floats within [0, 1] for `n_classes=2`, or as
integer values in {0, 1, ..., n_classes-1} for `n_classes` > 2. An error
is raised if the vocabulary is not provided and the labels are strings.
optimizer: An instance of `tf.Optimizer` used to train the model. Can also
be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
callable. Defaults to Adagrad optimizer.
prefitting_optimizer: An instance of `tf.Optimizer` used to train the
model during the pre-fitting stage. Can also be a string (one of
'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to
Adagrad optimizer.
prefitting_steps: Number of steps for which to train the model during the
prefitting stage. If None, train forever or until prefitting_input_fn
raises a `tf.errors.OutOfRangeError` or a `StopIteration` exception.
config: `RunConfig` object to configure the runtime settings.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
a `WarmStartSettings` object to fully configure warm-starting. If the
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
loss_fn: Optional custom loss function, used by the head in place of its
default loss.
dtype: dtype of layers used in the model.
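
Example:

A minimal usage sketch; the feature columns, model config, and input
functions below are illustrative placeholders, not part of this API:

```python
import tensorflow as tf
import tensorflow_lattice as tfl

feature_columns = [
    tf.feature_column.numeric_column('age'),
    tf.feature_column.numeric_column('weight'),
]
model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=[
        tfl.configs.FeatureConfig(name='age', lattice_size=2),
        tfl.configs.FeatureConfig(name='weight', lattice_size=2),
    ])

def make_input_fn(num_epochs):
  # Toy in-memory dataset; replace with real data.
  def input_fn():
    features = {
        'age': tf.constant([25.0, 63.0, 41.0]),
        'weight': tf.constant([60.0, 82.0, 74.0]),
    }
    labels = tf.constant([0, 1, 1])
    return tf.data.Dataset.from_tensor_slices(
        (features, labels)).batch(3).repeat(num_epochs)
  return input_fn

estimator = tfl.estimators.CannedClassifier(
    model_config=model_config,
    feature_columns=feature_columns,
    # A single pass over the data to compute calibration keypoints.
    feature_analysis_input_fn=make_input_fn(num_epochs=1))
estimator.train(input_fn=make_input_fn(num_epochs=100))
```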
"""
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
config, model_dir)
model_dir = config.model_dir
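# Pick the estimator head for the task: a binary classification head for
# two classes, otherwise a multi-class head.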
if n_classes == 2:
head = binary_class_head.BinaryClassHead(
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
loss_fn=loss_fn)
else:
head = multi_class_head.MultiClassHead(
n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
loss_fn=loss_fn)
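# Binary classification produces a single logit; multi-class produces one
# logit per class.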
label_dimension = 1 if n_classes == 2 else n_classes
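# Work on a deep copy so the caller's model_config is not mutated, then
# fill in per-feature settings derived from the feature columns.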
model_config = copy.deepcopy(model_config)
_update_by_feature_columns(model_config, feature_columns)
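# Compute calibration input keypoints, running feature analysis (e.g.
# quantile estimation) over feature_analysis_input_fn when needed.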
_finalize_keypoints(
model_config=model_config,
config=config,
feature_columns=feature_columns,
feature_analysis_input_fn=feature_analysis_input_fn,
feature_analysis_weight_column=feature_analysis_weight_column,
feature_analysis_weight_reduction=feature_analysis_weight_reduction,
logits_output=True)
_verify_config(model_config, feature_columns)
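# Finalize the model structure; for crystals models this runs the
# prefitting stage using prefitting_input_fn to estimate feature
# interactions.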
_finalize_model_structure(
label_dimension=label_dimension,
feature_columns=feature_columns,
head=head,
model_config=model_config,
prefitting_input_fn=prefitting_input_fn,
prefitting_optimizer=prefitting_optimizer,
prefitting_steps=prefitting_steps,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from,
dtype=dtype)
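# Build the Estimator model_fn from the finalized model config.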
model_fn = _get_model_fn(
label_dimension=label_dimension,
feature_columns=feature_columns,
head=head,
model_config=model_config,
optimizer=optimizer,
dtype=dtype)
super(CannedClassifier, self).__init__(
model_fn=model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)