in tensorflow_lattice/python/premade.py [0:0]
def get_custom_objects(custom_objects=None):
  """Returns a name -> object mapping of TensorFlow Lattice custom objects.

  A new dictionary is constructed on every call, so callers may modify the
  returned mapping without affecting later calls.

  Args:
    custom_objects: Optional dictionary mapping names (strings) to additional
      classes or functions to be considered during deserialization. When
      given, its entries are merged into the result; an entry whose name
      matches a built-in TensorFlow Lattice entry replaces that entry.

  Returns:
    A dictionary mapping names (strings) to TensorFlow Lattice custom objects,
    extended with the contents of `custom_objects` when it is not None.
  """
  objects = {
      'AggregateFunction': AggregateFunction,
      'AggregateFunctionConfig': configs.AggregateFunctionConfig,
      'Aggregation': aggregation_layer.Aggregation,
      'BiasInitializer': kfll.BiasInitializer,
      'CalibratedLatticeEnsemble': CalibratedLatticeEnsemble,
      'CalibratedLattice': CalibratedLattice,
      'CalibratedLatticeConfig': configs.CalibratedLatticeConfig,
      'CalibratedLatticeEnsembleConfig':
          configs.CalibratedLatticeEnsembleConfig,
      'CalibratedLinear': CalibratedLinear,
      'CalibratedLinearConfig': configs.CalibratedLinearConfig,
      'CategoricalCalibration':
          categorical_calibration_layer.CategoricalCalibration,
      'CategoricalCalibrationConstraints':
          categorical_calibration_layer.CategoricalCalibrationConstraints,
      'DominanceConfig': configs.DominanceConfig,
      'FeatureConfig': configs.FeatureConfig,
      'KFLRandomMonotonicInitializer': kfll.KFLRandomMonotonicInitializer,
      'KroneckerFactoredLattice': kfll.KroneckerFactoredLattice,
      'KroneckerFactoredLatticeConstraints':
          kfll.KroneckerFactoredLatticeConstraints,
      'LaplacianRegularizer': lattice_layer.LaplacianRegularizer,
      'Lattice': lattice_layer.Lattice,
      'LatticeConstraints': lattice_layer.LatticeConstraints,
      'Linear': linear_layer.Linear,
      'LinearConstraints': linear_layer.LinearConstraints,
      'LinearInitializer': lattice_layer.LinearInitializer,
      'NaiveBoundsConstraints': pwl_calibration_layer.NaiveBoundsConstraints,
      'ParallelCombination': parallel_combination_layer.ParallelCombination,
      'PWLCalibration': pwl_calibration_layer.PWLCalibration,
      'PWLCalibrationConstraints':
          pwl_calibration_layer.PWLCalibrationConstraints,
      'RandomMonotonicInitializer': lattice_layer.RandomMonotonicInitializer,
      'RegularizerConfig': configs.RegularizerConfig,
      'RTL': rtl_layer.RTL,
      'ScaleConstraints': kfll.ScaleConstraints,
      'ScaleInitializer': kfll.ScaleInitializer,
      'TorsionRegularizer': lattice_layer.TorsionRegularizer,
      'TrustConfig': configs.TrustConfig,
  }
  if custom_objects is None:
    return objects
  # dict.update gives caller-supplied entries precedence over built-in ones
  # that share the same name.
  objects.update(custom_objects)
  return objects