def _construct_loom_layer()

in tensorflow_fold/loom/loom.py [0:0]


  def _construct_loom_layer(self, depth, state):
    """Assembles the TensorFlow sub-graph for a single loom layer.

    For every LoomOp, the layer gathers each argument's batch out of `state`
    using the wiring recorded for `depth`, runs the op on those batches, and
    finally concatenates all outputs that share a TypeShape.

    Args:
      depth: The current depth, as an integer or an integer tensor.
      state: A list of tensors, one per TypeShape, each holding a batch of
        values of that TypeShape.

    Returns:
      A list of tensors, one per TypeShape, each holding a batch of values of
      that TypeShape. (This is the state fed to the next loom layer.)

    Raises:
      TypeError: If a LoomOp's instantiate_batch method returns something that
        is not a tf.Tensor.
      ValueError: If a LoomOp's instantiate_batch method returns a Tensor of
        the wrong DataType or shape.
    """
    # segments_by_ts[i] collects the output tensors whose TypeShape ID is i;
    # each list is concatenated at the end to form the new state.
    segments_by_ts = []
    for _ in state:
      segments_by_ts.append([])

    # Position of this layer's first argument wiring.  This is a tensor when
    # `depth` is a tensor, and a plain integer otherwise.
    layer_wire_base = (depth - 1) * self._loom_total_args

    # Running count (always a Python integer) of the arguments handled so far
    # in this layer, across all ops.
    args_seen = 0

    for op_idx, loom_op in enumerate(self._loom_ops):
      op_name = self._loom_op_names[op_idx]
      with tf.name_scope(op_name):
        batched_args = []
        for arg_idx, arg_ts in enumerate(loom_op.input_type_shapes):
          with tf.name_scope('arg_%d' % arg_idx):
            # wire_pos selects which argument's wiring diagram to extract from
            # `_arg_wiring_concat`; it is a tensor or an integer.
            wire_pos = layer_wire_base + args_seen
            args_seen += 1

            # Length-1 vector: where this argument's wiring starts within
            # `_arg_wiring_concat`.
            wiring_begin = tf.slice(
                self._arg_wiring_slice_starts, [wire_pos], [1])

            # Length-1 vector: how many entries this argument's wiring
            # occupies within `_arg_wiring_concat`.
            wiring_len = tf.slice(
                self._arg_wiring_slice_sizes, [wire_pos], [1])

            # Indices (into the state tensor for this argument's TypeShape) of
            # the batch of tensors that will be passed as argument `arg_idx`
            # of this op at depth `depth`.
            #
            # The contents of this tensor will be the same as the vector
            # computed by Weaver::GetWiring(depth, op_idx, arg_idx) in C++.
            arg_wiring = tf.slice(
                self._arg_wiring_concat, wiring_begin, wiring_len)

            source_batch = state[self._type_shape_to_idx[arg_ts]]

            # Sub-layer (1) of the loom layer: gather the batch of inputs for
            # this argument.  (See the class doc-string section on
            # Implementation Details.)
            gathered = tf.gather(source_batch, arg_wiring)

            # Tag the gathered batch with its static shape (unknown batch
            # dimension followed by the TypeShape's shape) before handing it
            # to the op.
            gathered.set_shape((None,) + arg_ts.shape)
            batched_args.append(gathered)

        # Sub-layer (2) of the loom layer: run the op on its batched inputs.
        op_outputs = loom_op.instantiate_batch(batched_args)

        for output_idx, (output, output_ts) in enumerate(
            zip(op_outputs, loom_op.output_type_shapes)):
          # Sanity-check that instantiate_batch produced Tensors of the
          # declared type and shape.
          if not isinstance(output, tf.Tensor):
            raise TypeError('Op %s returns non-Tensor output %r' %
                            (op_name, output))
          try:
            # set_shape doubles as the shape check: it raises ValueError when
            # the declared shape is incompatible with the output's shape.
            output.set_shape((None,) + output_ts.shape)
          except ValueError as e:
            raise ValueError('Op %s output %d: %s' % (
                op_name, output_idx, e))
          actual_dtype = output.dtype.base_dtype.name
          if actual_dtype != output_ts.dtype:
            raise ValueError('Op %s output %d: expected dtype %s got %s' % (
                op_name, output_idx, output_ts.dtype, actual_dtype))

          # File this output under its TypeShape.
          #
          # Note: The segments of a given TypeShape end up sorted lexically by
          # (op_idx, output_idx).  weaver.cc depends on this fact when
          # computing offsets in order to transform its graph into a wiring
          # diagram (see Weaver::Finalize), so the append order must not
          # change.
          segments_by_ts[self._type_shape_to_idx[output_ts]].append(output)

    with tf.name_scope('concat'):
      # Sub-layer (3) of the loom layer.
      #
      # All outputs of the same TypeShape are concatenated together so that
      # the next layer can gather over them.  This lets any LoomOp input of a
      # given TypeShape be fed from any LoomOp output of that same TypeShape.
      next_state = []
      for ts_idx, segments in enumerate(segments_by_ts):
        concat_name = self._type_shapes[ts_idx].tensor_flow_name()
        next_state.append(tf.concat(segments, 0, name=concat_name))
      return next_state