in tensorflow_fold/loom/loom.py [0:0]
def _setup_network(self):
"""Build the TensorFlow network that can emulate Loom graphs."""
if self._dry_run:
self._output = [tf.constant(np.zeros((1,)+ts.shape, dtype=ts.dtype))
for ts in self._type_shapes]
return
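# The wiring inputs can come from two places: with direct_feed_dict they
# are ordinary placeholders filled in by the caller at run time; otherwise
# they are outputs of the weaver op (see REGISTER_WEAVER_OP in
# weaver_op_base.h).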
if self._direct_feed_dict:
self._arg_wiring_concat = tf.placeholder(
TENSOR_IDX_T, name='arg_wiring_concat')
self._arg_wiring_slice_starts = tf.placeholder(
TENSOR_IDX_T, name='arg_wiring_slice_starts')
self._arg_wiring_slice_sizes = tf.placeholder(
TENSOR_IDX_T, name='arg_wiring_slice_sizes')
self._output_wirings = [
tf.placeholder(TENSOR_IDX_T, name='output_wirings_%d' % ts_idx)
for ts_idx in xrange(len(self._type_shapes))]
self._constants = [
tf.placeholder(ts.dtype, name='constants_%d' % ts_idx)
for ts_idx, ts in enumerate(self._type_shapes)]
else:
# See REGISTER_WEAVER_OP in weaver_op_base.h for the definitions of the
# outputs in the destructuring assignment below.
(self._arg_wiring_concat,
self._arg_wiring_slice_starts,
self._arg_wiring_slice_sizes,
self._output_wirings,
self._constants) = self._weaver_op(
metadata=self._loom_metadata_str,
constant_types=[tf.as_dtype(ts.dtype) for ts in self._type_shapes],
num_type_shapes=len(self._type_shapes))
# _arg_wiring_concat: an integer vector Tensor containing all the wirings
# for the current schedule concatenated together. They are sorted
# lexicographically by (depth, op_idx, arg_idx). This means that
# _arg_wiring_concat consists of max_depth*self._loom_total_args vectors
# concatenated together. (Here max_depth refers to the final max_depth of
# the emulated graph, not to -1, even when the Loom was instantiated with a
# dynamic max_depth and therefore uses a tf.while_loop.)
#
# _arg_wiring_slice_starts and _arg_wiring_slice_sizes: these are integer
# vector Tensors of length max_depth*self._loom_total_args that specify how
# to split _arg_wiring_concat back apart into wirings for each (depth,
# op_idx, arg_idx).
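#
# For example, with self._loom_total_args == 3 and a schedule whose
# max_depth is 2, _arg_wiring_concat holds 2*3 == 6 wiring vectors back to
# back; the wiring for a given (depth, op_idx, arg_idx) is recovered from
# the corresponding entries of _arg_wiring_slice_starts and
# _arg_wiring_slice_sizes.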
#
# The rationale for concatenating all the wiring diagrams together
# like this is that, in order to support tf.while_loop, we need a single
# tensor from which the wiring diagram for the current depth can be
# extracted (this is accomplished with tf.slice in
# _construct_loom_layer).
#
# _output_wirings: A list of integer vector Tensors, one for each TypeShape.
# These vectors select which elements of the final state tensor end up in
# the Loom's `output_tensor`s.
#
# _constants: A list of Tensors, one for each TypeShape. Each of these
# Tensors should have the dtype of the corresponding TypeShape. The
# contents should be the stacked set of constants declared for that
# TypeShape.
# For each TypeShape, if it is in batched input mode, we use the
# user-provided tensor as the input. Otherwise, we take the constants from
# the weaver.
inputs = self._constants
for ts_idx, ts in enumerate(self._type_shapes):
if ts in self._batch_inputs:
inputs[ts_idx] = self._batch_inputs[ts]
# state: a list of tensors, one per TypeShape. At each iteration of
# building up the graph, state contains tensors whose rows are the objects
# of the appropriate shapes passed from one depth to the next.
state = []
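# The initial entry for each TypeShape is the stacked named tensors
# followed by the constants (or the batch input), so named tensors occupy
# the lowest row indices.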
for inputs_tensor, named_tensors in (
zip(inputs, self._ts_idx_to_named_tensors)):
if not named_tensors:
state.append(inputs_tensor)
else:
state.append(tf.concat([tf.stack(named_tensors), inputs_tensor], 0))
# This block builds up the static graph that consumes Loom's wiring
# diagrams and emulates the dynamic network.
#
# Note: the code that computes wiring diagrams lives in scheduler.cc for
# efficiency reasons.
if self._max_depth == -1:  # For dynamic max_depth we use tf.while_loop.
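# _arg_wiring_slice_starts has one entry per (depth, op_idx, arg_idx), so
# dividing its size by the total number of op arguments gives the depth of
# the current schedule.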
current_max_depth = (
tf.size(self._arg_wiring_slice_starts) // self._loom_total_args)
def loop_conditional(depth, *unused_state):
return tf.less_equal(depth, current_max_depth)
def loop_body(depth, *state):
new_depth = tf.add(depth, 1, name='increment_depth')
new_state = self._construct_loom_layer(depth, state)
return [new_depth] + new_state
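# Depth 0 is the initial state assembled above; the loop applies loom
# layers for depths 1 through current_max_depth inclusive.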
initial_depth = tf.constant(1, name='initial_depth')
state = tf.while_loop(loop_conditional, loop_body,
[initial_depth] + state,
parallel_iterations=self._parallel_iterations,
back_prop=self._back_prop,
swap_memory=self._swap_memory)[1:]
else: # For explicit max_depth we unroll the loop.
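# Each unrolled depth gets its own name scope (loom_depth_001,
# loom_depth_002, ...), with one _construct_loom_layer call per depth.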
for depth in xrange(1, self._max_depth+1):
with tf.name_scope('loom_depth_%03d' % depth):
state = self._construct_loom_layer(depth, state)
# _output: The output tensors of the loom, indexed by TypeShape.
with tf.name_scope('output_gathers'):
self._output = [
tf.gather(s, w, name=self._type_shapes[ts_idx].tensor_flow_name())
for ts_idx, (s, w) in enumerate(zip(state, self._output_wirings))]
# Make sure the output tensors know what shape they're supposed to be.
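# (tf.while_loop and tf.gather can lose static shape information, so we
# re-assert the TypeShape dimensions, leaving the batch dimension unknown.)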
for type_shape, output in zip(self._type_shapes, self._output):
output.set_shape((None,) + type_shape.shape)