def get_custom_objects()

in tensorflow_lattice/python/premade.py [0:0]


def get_custom_objects(custom_objects=None):
  """Builds the name -> object mapping used to deserialize TFL models.

  A new dictionary is constructed on every call, so callers may mutate the
  result freely.

  Args:
    custom_objects: Optional dictionary mapping names (strings) to extra
      custom classes or functions that should also be available during
      deserialization. Its entries are merged into the result; when a name
      is present in both, the entry from `custom_objects` takes precedence.

  Returns:
    A dictionary mapping names (strings) to tensorflow lattice custom objects,
    plus any entries supplied through `custom_objects`.
  """
  # Keyword arguments keep the same key strings and insertion order as a
  # literal mapping would.
  registry = dict(
      AggregateFunction=AggregateFunction,
      AggregateFunctionConfig=configs.AggregateFunctionConfig,
      Aggregation=aggregation_layer.Aggregation,
      BiasInitializer=kfll.BiasInitializer,
      CalibratedLatticeEnsemble=CalibratedLatticeEnsemble,
      CalibratedLattice=CalibratedLattice,
      CalibratedLatticeConfig=configs.CalibratedLatticeConfig,
      CalibratedLatticeEnsembleConfig=configs.CalibratedLatticeEnsembleConfig,
      CalibratedLinear=CalibratedLinear,
      CalibratedLinearConfig=configs.CalibratedLinearConfig,
      CategoricalCalibration=(
          categorical_calibration_layer.CategoricalCalibration),
      CategoricalCalibrationConstraints=(
          categorical_calibration_layer.CategoricalCalibrationConstraints),
      DominanceConfig=configs.DominanceConfig,
      FeatureConfig=configs.FeatureConfig,
      KFLRandomMonotonicInitializer=kfll.KFLRandomMonotonicInitializer,
      KroneckerFactoredLattice=kfll.KroneckerFactoredLattice,
      KroneckerFactoredLatticeConstraints=(
          kfll.KroneckerFactoredLatticeConstraints),
      LaplacianRegularizer=lattice_layer.LaplacianRegularizer,
      Lattice=lattice_layer.Lattice,
      LatticeConstraints=lattice_layer.LatticeConstraints,
      Linear=linear_layer.Linear,
      LinearConstraints=linear_layer.LinearConstraints,
      LinearInitializer=lattice_layer.LinearInitializer,
      NaiveBoundsConstraints=pwl_calibration_layer.NaiveBoundsConstraints,
      ParallelCombination=parallel_combination_layer.ParallelCombination,
      PWLCalibration=pwl_calibration_layer.PWLCalibration,
      PWLCalibrationConstraints=(
          pwl_calibration_layer.PWLCalibrationConstraints),
      RandomMonotonicInitializer=lattice_layer.RandomMonotonicInitializer,
      RegularizerConfig=configs.RegularizerConfig,
      RTL=rtl_layer.RTL,
      ScaleConstraints=kfll.ScaleConstraints,
      ScaleInitializer=kfll.ScaleInitializer,
      TorsionRegularizer=lattice_layer.TorsionRegularizer,
      TrustConfig=configs.TrustConfig,
  )
  # Caller-supplied entries are applied last so they override built-ins.
  if custom_objects is not None:
    registry.update(custom_objects)
  return registry