in tensorflow_lattice/python/estimators.py [0:0]
def _finalize_model_structure(model_config, label_dimension, feature_columns,
                              head, prefitting_input_fn, prefitting_optimizer,
                              prefitting_steps, model_dir, config,
                              warm_start_from, dtype):
  """Sets up the lattice ensemble in model_config with requested algorithm.

  Only acts on a `configs.CalibratedLatticeEnsembleConfig` whose `lattices`
  field is not already an explicit list. Any other config is left untouched.

  When `config` is `None` or `config.is_chief` is true, and the ensemble
  structure file does not yet exist under `model_dir`, this process builds the
  structure and writes it there as JSON. In every other case (non-chief
  workers, or a chief that finds the file already present), the process waits
  for that file and loads `model_config.lattices` from it. This way all
  workers share a single structure.

  Args:
    model_config: Model configuration. `model_config.lattices` is updated in
      place (directly when loaded from file; via the ensemble helpers when
      built on the chief).
    label_dimension: Forwarded to the 'crystals' ensemble search.
    feature_columns: Optional feature columns. When non-empty, their `name`s
      are used as the feature names. Otherwise the names are taken from
      `model_config.feature_configs`.
    head: Forwarded to the 'crystals' ensemble search.
    prefitting_input_fn: Forwarded to the 'crystals' ensemble search.
    prefitting_optimizer: Forwarded to the 'crystals' ensemble search.
    prefitting_steps: Forwarded to the 'crystals' ensemble search.
    model_dir: Directory that holds the ensemble structure file.
    config: Run configuration (presumably a `tf.estimator.RunConfig`). When it
      is `None` or `config.is_chief` is true, this process may build the
      structure.
    warm_start_from: Must be falsy; warm starting is not supported here yet.
    dtype: Forwarded to the 'crystals' ensemble search.

  Raises:
    ValueError: If warm starting is requested, if `lattice_rank` exceeds the
      number of features, if `num_lattices * lattice_rank` cannot cover every
      feature, or if `model_config.lattices` names an unsupported structure.
  """
  if (not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig) or
      isinstance(model_config.lattices, list)):
    return
  # TODO: If warmstarting, look for the previous ensemble file.
  if warm_start_from:
    raise ValueError('Warm starting lattice ensembles without explicitly '
                     'defined lattices is not supported yet.')
  if feature_columns:
    feature_names = [feature_column.name for feature_column in feature_columns]
  else:
    feature_names = [
        feature_config.name for feature_config in model_config.feature_configs
    ]
  if model_config.lattice_rank > len(feature_names):
    raise ValueError(
        'lattice_rank {} cannot be larger than the number of features: {}'
        .format(model_config.lattice_rank, feature_names))
  if model_config.num_lattices * model_config.lattice_rank < len(feature_names):
    raise ValueError(
        'Model with {}x{}d lattices is not large enough for all features: {}'
        .format(model_config.num_lattices, model_config.lattice_rank,
                feature_names))
  # Validate on every worker, not only on the chief. The chief raises before
  # writing the structure file, so a non-chief worker that skipped this check
  # would otherwise wait in _poll_for_file for a file that never appears.
  if model_config.lattices not in ['random', 'crystals', 'rtl_layer']:
    raise ValueError('Unsupported ensemble structure: {}'.format(
        model_config.lattices))

  ensemble_structure_filename = os.path.join(model_dir,
                                             _ENSEMBLE_STRUCTURE_FILE)
  if ((config is None or config.is_chief) and
      not tf.io.gfile.exists(ensemble_structure_filename)):
    if model_config.lattices == 'random':
      premade_lib.set_random_lattice_ensemble(model_config, feature_names)
    elif model_config.lattices == 'crystals':
      _set_crystals_lattice_ensemble(
          feature_names=feature_names,
          label_dimension=label_dimension,
          feature_columns=feature_columns,
          head=head,
          model_config=model_config,
          prefitting_input_fn=prefitting_input_fn,
          prefitting_optimizer=prefitting_optimizer,
          prefitting_steps=prefitting_steps,
          config=config,
          dtype=dtype)
    # 'rtl_layer' needs no construction here; the string is saved as-is.
    if (model_config.fix_ensemble_for_2d_constraints and
        model_config.lattices != 'rtl_layer'):
      # Note that we currently only support monotonicity and bound constraints
      # for RTL.
      _fix_ensemble_for_2d_constraints(model_config, feature_names)
    # Save lattices to file as the chief worker. Writing to a temporary name
    # and then renaming is presumably meant to keep pollers from reading a
    # partially written file.
    tmp_ensemble_structure_filename = ensemble_structure_filename + 'tmp'
    with tf.io.gfile.GFile(tmp_ensemble_structure_filename,
                           'w') as ensemble_structure_file:
      ensemble_structure_file.write(json.dumps(model_config.lattices, indent=2))
    tf.io.gfile.rename(tmp_ensemble_structure_filename,
                       ensemble_structure_filename)
  else:
    # Non-chief workers (or a chief whose file already exists) read the
    # lattices from file.
    _poll_for_file(ensemble_structure_filename)
    with tf.io.gfile.GFile(
        ensemble_structure_filename) as ensemble_structure_file:
      model_config.lattices = json.loads(ensemble_structure_file.read())
  logging.info('Finalized model structure: %s', str(model_config.lattices))