in tensorflow_lattice/python/premade_lib.py [0:0]
def build_lattice_layer(lattice_input, feature_configs, model_config,
                        layer_output_range, submodel_index, is_inside_ensemble,
                        dtype):
  """Creates a lattice layer and applies it to `lattice_input`.

  Depending on `model_config.parameterization`, either a `tfl.layers.Lattice`
  or a `tfl.layers.KroneckerFactoredLattice` layer is built from the
  per-feature configurations and immediately called on `lattice_input`.

  Args:
    lattice_input: Input to the lattice layer.
    feature_configs: A list of `tfl.configs.FeatureConfig` instances that
      specify configurations for each feature.
    model_config: Model configuration object describing model architecture.
      Should be one of the model configs in `tfl.configs`.
    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum.
    submodel_index: Corresponding index into submodels; used as the suffix of
      the created layer's name.
    is_inside_ensemble: If this layer is inside an ensemble. Only used to
      decide whether to warn about trapezoid trust constraints whose "main"
      feature is not part of this lattice.
    dtype: dtype forwarded to the constructed layer.

  Returns:
    The output of calling the constructed layer on `lattice_input`. The layer
    is a `tfl.layers.Lattice` if `model_config.parameterization` is
    `'all_vertices'`, or a `tfl.layers.KroneckerFactoredLattice` if it is
    `'kronecker_factored'`.

  Raises:
    ValueError: If a `reflects_trust_in` entry whose main feature is present
      in this lattice has a `trust_type` other than `'edgeworth'` or
      `'trapezoid'`.
    ValueError: If `model_config.parameterization` is not one of
      `'all_vertices'` or `'kronecker_factored'`.
  """
  (output_min, output_max, output_init_min,
   output_init_max) = _output_range(layer_output_range, model_config)
  feature_names = [feature_config.name for feature_config in feature_configs]
  # Name -> position in this lattice. Keep the first occurrence so lookups
  # agree with `list.index` semantics if a name were ever duplicated.
  feature_indices = {}
  for idx, name in enumerate(feature_names):
    feature_indices.setdefault(name, idx)
  lattice_sizes = [
      feature_config.lattice_size for feature_config in feature_configs
  ]
  lattice_monotonicities = _monotonicities_from_feature_configs(feature_configs)
  lattice_unimodalities = [
      feature_config.unimodality for feature_config in feature_configs
  ]
  lattice_regularizers = _lattice_regularizers(model_config,
                                               feature_configs) or None

  # Construct trust constraints within this lattice. Each constraint is a
  # (main_feature_index, conditional_feature_index, direction) tuple.
  edgeworth_trusts = []
  trapezoid_trusts = []
  for conditional_idx, conditional_feature_config in enumerate(feature_configs):
    for trust_config in conditional_feature_config.reflects_trust_in or []:
      main_idx = feature_indices.get(trust_config.feature_name)
      if main_idx is not None:
        trust = (main_idx, conditional_idx, trust_config.direction)
        if trust_config.trust_type == 'edgeworth':
          edgeworth_trusts.append(trust)
        elif trust_config.trust_type == 'trapezoid':
          trapezoid_trusts.append(trust)
        else:
          raise ValueError('Unrecognized trust type: {}'.format(
              trust_config.trust_type))
      elif is_inside_ensemble and trust_config.trust_type == 'trapezoid':
        # The main feature lives in a different submodel, so this lattice
        # cannot enforce the constraint on its own.
        logging.warning(
            'A "main" feature (%s) for a trapezoid trust constraint is not '
            'present in a lattice that includes the "conditional" feature '
            '(%s). In an ensemble model, this can result in constraint '
            'violations. Consider manually setting the ensemble structure if '
            'this constraint needs to be satisfied.', trust_config.feature_name,
            conditional_feature_config.name)
  monotonic_dominances = _dominance_constraints_from_feature_configs(
      feature_configs)

  if model_config.parameterization == 'all_vertices':
    layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index)
    kernel_initializer = lattice_layer.LinearInitializer(
        lattice_sizes=lattice_sizes,
        monotonicities=lattice_monotonicities,
        unimodalities=lattice_unimodalities,
        output_min=output_init_min,
        output_max=output_init_max)
    return lattice_layer.Lattice(
        lattice_sizes=lattice_sizes,
        monotonicities=lattice_monotonicities,
        unimodalities=lattice_unimodalities,
        edgeworth_trusts=edgeworth_trusts,
        trapezoid_trusts=trapezoid_trusts,
        monotonic_dominances=monotonic_dominances,
        output_min=output_min,
        output_max=output_max,
        clip_inputs=False,
        interpolation=model_config.interpolation,
        kernel_regularizer=lattice_regularizers,
        kernel_initializer=kernel_initializer,
        dtype=dtype,
        name=layer_name)(
            lattice_input)
  elif model_config.parameterization == 'kronecker_factored':
    layer_name = '{}_{}'.format(KFL_LAYER_NAME, submodel_index)
    kernel_initializer = kfll.KFLRandomMonotonicInitializer(
        monotonicities=lattice_monotonicities,
        init_min=output_init_min,
        init_max=output_init_max,
        seed=model_config.random_seed)
    scale_initializer = kfll.ScaleInitializer(
        output_min=output_min, output_max=output_max)
    # NOTE(review): only the first feature's lattice size is passed, and the
    # unimodality, trust, dominance and regularizer settings computed above
    # are not forwarded to this layer. Presumably uniform sizes and the
    # absence of those constraints are enforced by config validation
    # elsewhere; confirm.
    return kfll.KroneckerFactoredLattice(
        lattice_sizes=lattice_sizes[0],
        num_terms=model_config.num_terms,
        monotonicities=lattice_monotonicities,
        output_min=output_min,
        output_max=output_max,
        clip_inputs=False,
        kernel_initializer=kernel_initializer,
        scale_initializer=scale_initializer,
        dtype=dtype,
        name=layer_name)(
            lattice_input)
  else:
    raise ValueError('Unknown type of parameterization: {}'.format(
        model_config.parameterization))