def _finalize_model_structure()

in tensorflow_lattice/python/estimators.py [0:0]


def _finalize_model_structure(model_config, label_dimension, feature_columns,
                              head, prefitting_input_fn, prefitting_optimizer,
                              prefitting_steps, model_dir, config,
                              warm_start_from, dtype):
  """Resolves a named lattice-ensemble strategy into a concrete structure.

  Nothing happens unless `model_config` is a
  `configs.CalibratedLatticeEnsembleConfig` whose `lattices` attribute is not
  already a list. Some processes then build the structure:
  - the chief, or any process when `config` is None;
  - and only when no ensemble structure file exists under `model_dir` yet.

  Such a process builds the structure for the strategy named by
  `model_config.lattices` ('random', 'crystals' or 'rtl_layer'). It then
  saves `model_config.lattices` as JSON to that file. Every other process
  (non-chief workers, or a chief that finds the file already present) waits
  for the file and loads `model_config.lattices` from it.

  Args:
    model_config: Model config. Only ensemble configs with a non-list
      `lattices` value are processed.
    label_dimension: Forwarded to `_set_crystals_lattice_ensemble`. Only used
      by the 'crystals' strategy.
    feature_columns: If truthy, feature names come from each column's `name`.
      Otherwise they come from `model_config.feature_configs`.
    head: Forwarded to `_set_crystals_lattice_ensemble` ('crystals' only).
    prefitting_input_fn: Forwarded to `_set_crystals_lattice_ensemble`
      ('crystals' only).
    prefitting_optimizer: Forwarded to `_set_crystals_lattice_ensemble`
      ('crystals' only).
    prefitting_steps: Forwarded to `_set_crystals_lattice_ensemble`
      ('crystals' only).
    model_dir: Directory that holds the ensemble structure file.
    config: Object whose `is_chief` attribute decides whether this process
      builds the structure. None is treated like the chief.
    warm_start_from: Must be falsy. Warm starting is rejected here.
    dtype: Forwarded to `_set_crystals_lattice_ensemble` ('crystals' only).

  Raises:
    ValueError: If `warm_start_from` is set, if `lattice_rank` exceeds the
      number of features, if `num_lattices * lattice_rank` is smaller than
      the number of features, or (only when building) if the strategy name
      is not supported.
  """
  is_ensemble = isinstance(model_config,
                           configs.CalibratedLatticeEnsembleConfig)
  # Explicitly listed lattices need no finalization.
  if not is_ensemble or isinstance(model_config.lattices, list):
    return

  # TODO: If warmstarting, look for the previous ensemble file.
  if warm_start_from:
    raise ValueError('Warm starting lattice ensembles without explicitly '
                     'defined lattices is not supported yet.')

  name_sources = (
      feature_columns if feature_columns else model_config.feature_configs)
  feature_names = [source.name for source in name_sources]
  num_features = len(feature_names)

  rank = model_config.lattice_rank
  if rank > num_features:
    raise ValueError(
        'lattice_rank {} cannot be larger than the number of features: {}'
        .format(rank, feature_names))

  num_lattices = model_config.num_lattices
  # Total lattice inputs must be able to cover every feature at least once.
  if num_lattices * rank < num_features:
    raise ValueError(
        'Model with {}x{}d lattices is not large enough for all features: {}'
        .format(num_lattices, rank, feature_names))

  structure_path = os.path.join(model_dir, _ENSEMBLE_STRUCTURE_FILE)
  builds_structure = ((config is None or config.is_chief) and
                      not tf.io.gfile.exists(structure_path))

  if not builds_structure:
    # Non-chief workers, and a chief that finds the structure already
    # persisted, load the lattices from file.
    _poll_for_file(structure_path)
    with tf.io.gfile.GFile(structure_path) as structure_file:
      model_config.lattices = json.loads(structure_file.read())
  else:
    strategy = model_config.lattices
    if strategy not in ['random', 'crystals', 'rtl_layer']:
      raise ValueError('Unsupported ensemble structure: {}'.format(strategy))

    if strategy == 'crystals':
      _set_crystals_lattice_ensemble(
          feature_names=feature_names,
          label_dimension=label_dimension,
          feature_columns=feature_columns,
          head=head,
          model_config=model_config,
          prefitting_input_fn=prefitting_input_fn,
          prefitting_optimizer=prefitting_optimizer,
          prefitting_steps=prefitting_steps,
          config=config,
          dtype=dtype)
    elif strategy == 'random':
      premade_lib.set_random_lattice_ensemble(model_config, feature_names)
    # 'rtl_layer' computes no structure here; its string value is saved as-is.

    # 2D-constraint fixing is skipped for RTL, which currently only supports
    # monotonicity and bound constraints.
    if (model_config.fix_ensemble_for_2d_constraints and
        model_config.lattices != 'rtl_layer'):
      _fix_ensemble_for_2d_constraints(model_config, feature_names)

    # The chief writes to a staging file and then renames it into place.
    # Workers polling for `structure_path` should therefore not see a partial
    # write. NOTE(review): this relies on rename being atomic on the target
    # filesystem.
    staging_path = structure_path + 'tmp'
    with tf.io.gfile.GFile(staging_path, 'w') as staging_file:
      staging_file.write(json.dumps(model_config.lattices, indent=2))
    tf.io.gfile.rename(staging_path, structure_path)

  logging.info('Finalized model structure: %s', str(model_config.lattices))