def _setup_network()

in tensorflow_fold/loom/loom.py [0:0]


  def _setup_network(self):
    """Build the TensorFlow network that can emulate Loom graphs.

    In dry-run mode this only sets `self._output` to one zero-valued constant
    per TypeShape (with a leading dimension of 1) and returns.

    Otherwise it:
      1. Creates the wiring and constant tensors, either as placeholders (when
         `self._direct_feed_dict` is set) or as the outputs of
         `self._weaver_op`.
      2. Builds the initial per-TypeShape state from the named tensors
         followed by the constants (or the user-supplied batch inputs).
      3. Emulates the Loom graph depth by depth. It uses a `tf.while_loop`
         when `self._max_depth == -1` and an unrolled Python loop otherwise.
      4. Gathers the selected rows of the final state into `self._output`.

    Outside dry-run mode this sets `self._arg_wiring_concat`,
    `self._arg_wiring_slice_starts`, `self._arg_wiring_slice_sizes`,
    `self._output_wirings`, `self._constants` and `self._output`.
    """
    if self._dry_run:
      self._output = [tf.constant(np.zeros((1,)+ts.shape, dtype=ts.dtype))
                      for ts in self._type_shapes]
      return

    if self._direct_feed_dict:
      self._arg_wiring_concat = tf.placeholder(
          TENSOR_IDX_T, name='arg_wiring_concat')
      self._arg_wiring_slice_starts = tf.placeholder(
          TENSOR_IDX_T, name='arg_wiring_slice_starts')
      self._arg_wiring_slice_sizes = tf.placeholder(
          TENSOR_IDX_T, name='arg_wiring_slice_sizes')
      self._output_wirings = [
          tf.placeholder(TENSOR_IDX_T, name='output_wirings_%d' % ts_idx)
          for ts_idx in xrange(len(self._type_shapes))]
      self._constants = [
          tf.placeholder(ts.dtype, name='constants_%d' % ts_idx)
          for ts_idx, ts in enumerate(self._type_shapes)]
    else:
      # See REGISTER_WEAVER_OP in weaver_op_base.h for the definitions of the
      # outputs in the destructuring assignment below.
      (self._arg_wiring_concat,
       self._arg_wiring_slice_starts,
       self._arg_wiring_slice_sizes,
       self._output_wirings,
       self._constants) = self._weaver_op(
           metadata=self._loom_metadata_str,
           constant_types=[tf.as_dtype(ts.dtype) for ts in self._type_shapes],
           num_type_shapes=len(self._type_shapes))
    # _arg_wiring_concat: an integer vector Tensor containing all the wirings
    # for the current schedule concatenated together.  They are sorted
    # lexically, by (depth, op_idx, arg_idx).  This means that
    # _arg_wiring_concat consists of max_depth*self._loom_total_args vectors
    # concatenated together.  (Here max_depth refers to the final max_depth of
    # the emulated graph, not -1 in the event that the Loom was instantiated
    # with a while_loop.)
    #
    # _arg_wiring_slice_starts and _arg_wiring_slice_sizes: these are integer
    # vector Tensors of length max_depth*self._loom_total_args that specify how
    # to split _arg_wiring_concat back apart into wirings for each (depth,
    # op_idx, arg_idx).
    #
    # The rationale for concatenating all the wiring diagrams together
    # like this is that in order to support tf.while_loop, we need to create a
    # tensor which produces the appropriate wiring diagram in a way that depends
    # on the current depth (this is accomplished using tf.slice in
    # _construct_loom_layer.)
    #
    # _output_wirings: A list of integer vector Tensors, one for each TypeShape.
    # These vectors select which elements of the final state tensor end up in
    # the Loom's `output_tensor`s.
    #
    # _constants: A list of Tensors, one for each TypeShape.  Each of these
    # Tensors should have the dtype of the corresponding TypeShape.  The
    # contents should be the stacked set of constants declared for that
    # TypeShape.

    # For each TypeShape, if it's in batched input mode, we use the user
    # provided tensor as the input.  Otherwise, we take the constants from the
    # weaver.
    # Copy the list so that substituting batch inputs below does not overwrite
    # entries of self._constants.  The copy also works if the weaver op
    # returned an immutable sequence.
    inputs = list(self._constants)
    for ts_idx, ts in enumerate(self._type_shapes):
      if ts in self._batch_inputs:
        inputs[ts_idx] = self._batch_inputs[ts]

    # At each iteration of building up the graph, `state` will contain tensors
    # whose rows will be the objects passed from each depth to the next one of
    # the appropriate shapes.  Initially each entry holds the TypeShape's
    # named tensors (if any) stacked ahead of its inputs.
    state = []
    for inputs_tensor, named_tensors in (
        zip(inputs, self._ts_idx_to_named_tensors)):
      if not named_tensors:
        state.append(inputs_tensor)
      else:
        state.append(tf.concat([tf.stack(named_tensors), inputs_tensor], 0))

    # This block builds up the static graph that consumes Loom's wiring
    # diagrams and emulates the dynamic network.
    #
    # Note: the code that computes wiring diagrams lives in scheduler.cc for
    # efficiency reasons.
    if self._max_depth == -1:  # For dynamic max_depth we use tf.while.
      # The number of depths in this schedule is recovered from the number
      # of per-(depth, op_idx, arg_idx) slices.
      current_max_depth = (
          tf.size(self._arg_wiring_slice_starts) // self._loom_total_args)
      def loop_conditional(depth, *unused_state):
        return tf.less_equal(depth, current_max_depth)
      def loop_body(depth, *state):
        new_depth = tf.add(depth, 1, name='increment_depth')
        new_state = self._construct_loom_layer(depth, state)
        return [new_depth] + new_state
      initial_depth = tf.constant(1, name='initial_depth')
      # Drop the depth counter; keep only the per-TypeShape state.
      state = tf.while_loop(loop_conditional, loop_body,
                            [initial_depth] + state,
                            parallel_iterations=self._parallel_iterations,
                            back_prop=self._back_prop,
                            swap_memory=self._swap_memory)[1:]
    else:  # For explicit max_depth we unroll the loop.
      for depth in xrange(1, self._max_depth+1):
        with tf.name_scope('loom_depth_%03d' % depth):
          state = self._construct_loom_layer(depth, state)

    # _output: The output tensors of the loom, indexed by TypeShape.
    with tf.name_scope('output_gathers'):
      self._output = [
          tf.gather(s, w, name=self._type_shapes[ts_idx].tensor_flow_name())
          for ts_idx, (s, w) in enumerate(zip(state, self._output_wirings))]

    # Make sure the output tensors know what shape they're supposed to be.
    for type_shape, output in zip(self._type_shapes, self._output):
      output.set_shape((None,) + type_shape.shape)