in tensorflow_lattice/python/estimators.py [0:0]
def _create_rtl_submodel_kronecker_factored_lattice_nodes(
    sess, ops, graph, flattened_calibration_nodes, submodel_idx, submodel_key):
  """Returns next key and map from key+unit to KroneckerFactoredLatticeNode.

  For every KroneckerFactoredLattice kernel op found under the RTL layer of
  submodel `submodel_idx`, the kernel, scale, bias, lattice sizes, dims,
  number of terms and the per-unit input indices are read from the graph, and
  one `KroneckerFactoredLatticeNode` is created per unit.

  Args:
    sess: Session-like object; `sess.run(tensor)` is used to evaluate the
      output of each looked-up operation.
    ops: Passed through to `_match_op` together with the kernel op regex
      (presumably the operation names of `graph` -- confirm against caller).
    graph: Graph-like object providing `get_operation_by_name`.
    flattened_calibration_nodes: Container indexed by the entries of the
      `inputs_for_units` rows to obtain the input nodes of each unit.
    submodel_idx: Index of the submodel; formatted into the op names.
    submodel_key: Key assigned to the first created node. It is incremented
      by one for every node created.

  Returns:
    A tuple `(next_submodel_key, submodel_kfl_nodes)` where
    `next_submodel_key` is the first key not used by this call and
    `submodel_kfl_nodes` maps each used key to its
    `KroneckerFactoredLatticeNode`.
  """

  def _fetch(op_name):
    """Evaluates the first output of the graph operation named `op_name`."""
    return sess.run(graph.get_operation_by_name(op_name).outputs[0])

  submodel_kfl_nodes = {}
  # KroneckerFactoredLattice kernel weights
  # {RTL_LAYER_NAME}_{submodel_idx}/
  #   {RTL_KFL_NAME}_{monotonicities}/{KFL_KERNEL_NAME}
  kfl_kernel_op_re = '^{}_{}/{}_(.*)/{}/Read/ReadVariableOp$'.format(
      premade_lib.RTL_LAYER_NAME,
      submodel_idx,
      rtl_layer.RTL_KFL_NAME,
      kfll.KFL_KERNEL_NAME,
  )
  for kfl_kernel_op, monotonicities in _match_op(ops, kfl_kernel_op_re):
    # Every op of this KFL group shares the scope
    # {RTL_LAYER_NAME}_{submodel_idx}/{RTL_KFL_NAME}_{monotonicities}
    kfl_scope = '{}_{}/{}_{}'.format(
        premade_lib.RTL_LAYER_NAME,
        submodel_idx,
        rtl_layer.RTL_KFL_NAME,
        monotonicities,
    )
    variable_op_template = kfl_scope + '/{}/Read/ReadVariableOp'
    constant_op_template = kfl_scope + '/{}'

    # Variables are read flat and reshaped further below.
    kfl_kernel = _fetch(kfl_kernel_op).flatten()
    # {kfl_scope}/{KFL_SCALE_NAME}
    kfl_scale = _fetch(
        variable_op_template.format(kfll.KFL_SCALE_NAME)).flatten()
    # {kfl_scope}/{KFL_BIAS_NAME}
    kfl_bias = _fetch(
        variable_op_template.format(kfll.KFL_BIAS_NAME)).flatten()
    # {kfl_scope}/{LATTICE_SIZES_NAME}
    lattice_sizes = _fetch(
        constant_op_template.format(kfll.LATTICE_SIZES_NAME))
    # {kfl_scope}/{DIMS_NAME}
    dims = _fetch(constant_op_template.format(kfll.DIMS_NAME))
    # {kfl_scope}/{NUM_TERMS_NAME}
    num_terms = _fetch(constant_op_template.format(kfll.NUM_TERMS_NAME))
    # inputs_for_units
    # {RTL_LAYER_NAME}_{submodel_idx}/
    #   {INPUTS_FOR_UNITS_PREFIX}_{monotonicities}
    inputs_for_units = _fetch('{}_{}/{}_{}'.format(
        premade_lib.RTL_LAYER_NAME,
        submodel_idx,
        rtl_layer.INPUTS_FOR_UNITS_PREFIX,
        monotonicities,
    ))

    units = inputs_for_units.shape[0]
    if units == 0:
      # No unit means no node; skipping also avoids reshaping to a
      # zero-sized unit dimension.
      continue

    # Shape the flat weights, scale, and bias parameters based on the
    # calculated lattice_sizes, units, dims, and num_terms. The shapes do not
    # depend on the unit, so this is done once per group rather than per unit.
    weights = np.reshape(kfl_kernel,
                         (1, lattice_sizes, units * dims, num_terms))
    scale = np.reshape(kfl_scale, (units, num_terms))
    bias = np.reshape(kfl_bias, (units,))

    # Make a kfl node for each unit.
    # NOTE(review): every node receives the full group-wide weights, scale
    # and bias arrays (not a per-unit slice) -- confirm this is what
    # KroneckerFactoredLatticeNode consumers expect.
    for indices in inputs_for_units:
      # Gather input nodes for lattice node.
      input_nodes = [flattened_calibration_nodes[index] for index in indices]
      submodel_kfl_nodes[submodel_key] = (
          model_info.KroneckerFactoredLatticeNode(
              input_nodes=input_nodes,
              weights=weights,
              scale=scale,
              bias=bias))
      submodel_key += 1
  return submodel_key, submodel_kfl_nodes