def _create_rtl_submodel_kronecker_factored_lattice_nodes()

in tensorflow_lattice/python/estimators.py [0:0]


def _create_rtl_submodel_kronecker_factored_lattice_nodes(
    sess, ops, graph, flattened_calibration_nodes, submodel_idx, submodel_key):
  """Returns next key and map from key+unit to KroneckerFactoredLatticeNode.

  For every KroneckerFactoredLattice (KFL) kernel read op of RTL submodel
  `submodel_idx` matched in `ops`, evaluates that layer's kernel, scale, bias,
  lattice sizes, dims, num terms and inputs-for-units tensors, then creates
  one `model_info.KroneckerFactoredLatticeNode` per unit (per row of the
  inputs-for-units tensor).

  Args:
    sess: Session whose `run` evaluates the fetched graph tensors.
    ops: Collection searched by `_match_op` for KFL kernel read ops
      (presumably operation names -- confirm against `_match_op`).
    graph: Graph whose operations are looked up by name.
    flattened_calibration_nodes: Indexable collection of nodes; each unit's
      inputs-for-units row holds indices into it.
    submodel_idx: Index of the RTL submodel, embedded in every op name.
    submodel_key: First key assigned to a created node; incremented by one
      for each node.

  Returns:
    Tuple `(next_key, submodel_kfl_nodes)` where `submodel_kfl_nodes` maps
    consecutive keys starting at `submodel_key` to the created nodes and
    `next_key` is the first key not used.
  """

  def _evaluate(op_name):
    # Value of the first output tensor of the graph operation `op_name`.
    return sess.run(graph.get_operation_by_name(op_name).outputs[0])

  def _kfl_op_name(monotonicities, leaf):
    # {RTL_LAYER_NAME}_{submodel_idx}/{RTL_KFL_NAME}_{monotonicities}/{leaf}
    return '{}_{}/{}_{}/{}'.format(premade_lib.RTL_LAYER_NAME, submodel_idx,
                                   rtl_layer.RTL_KFL_NAME, monotonicities,
                                   leaf)

  def _variable_read(var_name):
    # Leaf `{var_name}/Read/ReadVariableOp` used for the variable-backed
    # parameters (kernel, scale, bias).
    return '{}/Read/ReadVariableOp'.format(var_name)

  # KroneckerFactoredLattice kernel weights; the captured group is the
  # layer's monotonicities suffix, reused to locate its sibling tensors.
  # {RTL_LAYER_NAME}_{submodel_idx}/
  # {RTL_KFL_NAME}_{monotonicities}/{KFL_KERNEL_NAME}
  kfl_kernel_op_re = '^{}_{}/{}_(.*)/{}/Read/ReadVariableOp$'.format(
      premade_lib.RTL_LAYER_NAME,
      submodel_idx,
      rtl_layer.RTL_KFL_NAME,
      kfll.KFL_KERNEL_NAME,
  )

  submodel_kfl_nodes = {}
  for kfl_kernel_op, monotonicities in _match_op(ops, kfl_kernel_op_re):
    kfl_kernel = _evaluate(kfl_kernel_op).flatten()
    kfl_scale = _evaluate(
        _kfl_op_name(monotonicities,
                     _variable_read(kfll.KFL_SCALE_NAME))).flatten()
    kfl_bias = _evaluate(
        _kfl_op_name(monotonicities,
                     _variable_read(kfll.KFL_BIAS_NAME))).flatten()
    lattice_sizes = _evaluate(
        _kfl_op_name(monotonicities, kfll.LATTICE_SIZES_NAME))
    dims = _evaluate(_kfl_op_name(monotonicities, kfll.DIMS_NAME))
    num_terms = _evaluate(_kfl_op_name(monotonicities, kfll.NUM_TERMS_NAME))

    # {RTL_LAYER_NAME}_{submodel_idx}/{INPUTS_FOR_UNITS_PREFIX}_{monotonicities}
    inputs_for_units = _evaluate('{}_{}/{}_{}'.format(
        premade_lib.RTL_LAYER_NAME, submodel_idx,
        rtl_layer.INPUTS_FOR_UNITS_PREFIX, monotonicities))

    units = inputs_for_units.shape[0]
    if not units:
      # Nothing to emit for this layer; skip reshaping unused parameters.
      continue

    # Shape the flat weights, scale, and bias parameters based on the
    # calculated lattice_sizes, units, dims, and num_terms. This does not
    # depend on the unit, so it is done once per layer.
    # NOTE(review): every unit's node receives the parameters of all `units`
    # units (the shapes include `units`), not a per-unit slice; confirm this
    # is what KroneckerFactoredLatticeNode consumers expect.
    weights = np.reshape(kfl_kernel,
                         (1, lattice_sizes, units * dims, num_terms))
    scale = np.reshape(kfl_scale, (units, num_terms))
    bias = np.reshape(kfl_bias, (units,))

    # One node per unit; each row of inputs_for_units holds the indices of
    # the calibration nodes that feed that unit.
    for indices in inputs_for_units:
      input_nodes = [flattened_calibration_nodes[index] for index in indices]
      submodel_kfl_nodes[submodel_key] = (
          model_info.KroneckerFactoredLatticeNode(
              input_nodes=input_nodes, weights=weights, scale=scale,
              bias=bias))
      submodel_key += 1
  return submodel_key, submodel_kfl_nodes