in tensorflow_fold/loom/loom.py
def _construct_loom_layer(self, depth, state):
"""Builds one unit of the loom's graph.
A Loom unit is a TensorFlow graph that performs all the operations scheduled
on the Loom at a given depth.
Args:
depth: An integer or integer tensor containing the current depth.
state: A list of tensors (one for each TypeShape) which will contain
batches of things of that TypeShape.
Returns:
A list of tensors (one for each TypeShape) which will contain batches of
things of that TypeShape. (The input to the next loom layer.)
Raises:
TypeError: If a LoomOp's instantiate_batch method returns something
other than a Tensor.
ValueError: If a LoomOp's instantiate_batch method returns Tensors of
the wrong DataType or shape.
"""
# Segments to be concatenated together to form the output state (indexed
# by TypeShape ID).
new_state_segments = [[] for _ in state]
# Note: `start_wire_pos` might be a tensor or an integer.
start_wire_pos = (depth - 1) * self._loom_total_args
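# For example (with hypothetical sizes): if the loom's ops take a combined
# total of 3 arguments (_loom_total_args == 3), the wiring entries for the
# layer at depth 2 occupy positions 3, 4, and 5, so start_wire_pos is
# (2 - 1) * 3 == 3.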
wire_pos_offset = 0 # `wire_pos_offset` is an integer.
for op_idx, op in enumerate(self._loom_ops):
with tf.name_scope(self._loom_op_names[op_idx]):
arg_inputs = []
for arg_idx, arg_ts in enumerate(op.input_type_shapes):
with tf.name_scope('arg_%d' % arg_idx):
# wire_pos: a tensor or integer specifying which argument's wiring
# diagram we wish to extract from `arg_wiring_concat`.
wire_pos = start_wire_pos + wire_pos_offset
wire_pos_offset += 1
# slice_start: a vector of length 1 containing the starting position
# of this argument's wiring in `arg_wiring_concat`.
slice_start = tf.slice(
self._arg_wiring_slice_starts, [wire_pos], [1])
# slice_size: a vector of length 1 containing the length of this
# argument's wiring in `arg_wiring_concat`.
slice_size = tf.slice(
self._arg_wiring_slice_sizes, [wire_pos], [1])
# arg_wiring: a tensor specifying the indices of several tensors
# (within the state vector corresponding to the TypeShape of arg).
# This batch of tensors will be passed to argument `arg_idx` of
# op `op` at depth `depth`.
#
# The contents of this tensor will be the same as the vector
# computed by Weaver::GetWiring(depth, op_idx, arg_idx) in C++.
arg_wiring = tf.slice(
self._arg_wiring_concat, slice_start, slice_size)
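# Illustrative example (hypothetical values): if slice_start == [7] and
# slice_size == [2], then arg_wiring == _arg_wiring_concat[7:9], e.g.
# [0, 3], selecting rows 0 and 3 of the state tensor for this TypeShape.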
arg_ts_idx = self._type_shape_to_idx[arg_ts]
# This tf.gather constructs sub-layer (1) of the loom layer.
# (See the class doc-string section on Implementation Details)
#
# This gather selects the batch of tensors that gets passed to argument
# `arg_idx` of op `op` at depth `depth`.
arg_input = tf.gather(state[arg_ts_idx], arg_wiring)
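# For illustration: if state[arg_ts_idx] holds rows [r0, r1, r2, r3] and
# arg_wiring == [0, 3], then arg_input == [r0, r3]; row i of arg_input is
# element i of the batch fed to this argument of `op`.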
# We make sure the inputs are tagged with the correct shape before
# attempting to concatenate.
arg_input.set_shape((None,) + arg_ts.shape)
arg_inputs.append(arg_input)
# This call to op.instantiate_batch constructs sub-layer (2) of the loom
# layer.
op_outputs = op.instantiate_batch(arg_inputs)
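# op_outputs is expected to be a list of Tensors, one per entry in
# op.output_type_shapes, each with a leading batch dimension.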
for output_idx, (output, output_ts) in enumerate(
zip(op_outputs, op.output_type_shapes)):
# Do some sanity checking to make sure instantiate_batch returned
# Tensors of the right type and shape.
if not isinstance(output, tf.Tensor):
raise TypeError('Op %s returns non-Tensor output %r' %
(self._loom_op_names[op_idx], output))
try:
output.set_shape((None,) + output_ts.shape) # Check shape.
except ValueError as e:
raise ValueError('Op %s output %d: %s' % (
self._loom_op_names[op_idx], output_idx, e))
if output.dtype.base_dtype.name != output_ts.dtype:
raise ValueError('Op %s output %d: expected dtype %s got %s' % (
self._loom_op_names[op_idx], output_idx,
output_ts.dtype, output.dtype.base_dtype.name))
# Append this output of the op to the list of segments of the
# appropriate TypeShape.
#
# Note: The segments of a given TypeShape will end up sorted lexically
# by (op_idx, op_output_idx). weaver.cc depends on this fact when
# computing offsets in order to transform its graph into a wiring
# diagram (see Weaver::Finalize).
output_ts_idx = self._type_shape_to_idx[output_ts]
new_state_segments[output_ts_idx].append(output)
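# Illustrative example (hypothetical ops): if op 0 and op 1 each emit one
# output of this TypeShape, the segment list is [op0_out, op1_out]; after
# the concat below, op 1's rows begin at offset tf.shape(op0_out)[0],
# which is the offset arithmetic Weaver::Finalize mirrors in C++.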
with tf.name_scope('concat'):
# This concat constructs sub-layer (3) of the loom layer.
#
# We need to concatenate all the outputs of the same TypeShape
# together so that the next layer can gather over them.
# This allows any LoomOp with an input of some TypeShape to get its
# input from any output of any LoomOp (provided it is of the same
# TypeShape).
return [
tf.concat(
s, 0, name=self._type_shapes[ts_idx].tensor_flow_name())
for ts_idx, s in enumerate(new_state_segments)
]
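# A minimal sketch of how these layers compose when the graph is unrolled
# to a fixed depth (`initial_state` and `max_depth` are illustrative
# names, not necessarily Loom's API):
#
#   state = initial_state
#   for depth in range(1, max_depth + 1):
#     state = self._construct_loom_layer(depth, state)
#
# Each layer gathers from the previous layer's concatenated outputs, so
# the final state holds one batch tensor per TypeShape.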