in tensorflow_fold/loom/loom.py [0:0]
def __init__(self, max_depth=None, named_tensors=None, named_ops=None,
             batch_inputs=None, extra_type_shapes=None, dry_run=False,
             parallel_iterations=None, back_prop=None, swap_memory=None,
             direct_feed_dict=False, loom_input_tensor=None, weaver_op=None):
    """Constructs a Loom.

    Most callers only need `named_ops`, `named_tensors`, `dry_run`,
    `loom_input_tensor` and, occasionally, `weaver_op`; the remaining
    arguments are tuning knobs.

    `named_ops` (a dictionary mapping strings to `LoomOps`) is the only
    mandatory argument. It defines the collection of operations the Loom
    supports.

    `named_tensors` lets the `Loom` build graphs that refer to existing
    TensorFlow tensors. Unlike a Loom constant, a named tensor can be
    backpropped through.

    The schedules (`WeaverMessages`) that wire the loom can come from one of
    three mutually exclusive places:

      * `loom_input_tensor`: they are read from an external string tensor.
      * `weaver_op`: they are computed on the fly in C++ by a custom op.
      * `direct_feed_dict`: no loom input tensor is created; placeholders
        for the wiring diagrams are created instead.

    See the class docstring section named "Bypass Modes" for the motivation
    behind these options.

    `dry_run` creates the Loom without constructing the associated
    TensorFlow graph, which is useful when this loom is only used to build
    `WeaverMessages` that drive another instance of the same loom.

    Args:
      max_depth: An optional integer depth to unroll the generic network
        to. If absent, Loom uses a `tf.while_loop`. `max_depth` is provided
        for compatibility with old versions of TensorFlow with bad support
        for `tf.while_loop` and for debugging purposes.
      named_tensors: An optional dictionary mapping strings to Tensors.
        (Named tensors are effectively zero argument LoomOps.) Each value
        must be either a tf.Tensor or a tuple of the form (tf.Tensor, str)
        with the string specifying a TypeShape tag.
      named_ops: A mandatory dictionary mapping strings to LoomOp objects
        (the set of operations the Loom should support.)
      batch_inputs: An optional dictionary mapping TypeShapes to Tensors.
        Each Tensor should have the type and shape to contain a batch of
        things of that TypeShape stacked along dimension 0.
      extra_type_shapes: An optional iterable containing extra TypeShapes
        that may not be inputs or outputs of LoomOps but that the Loom
        should support anyway.
      dry_run: Boolean. If true, don't build the TensorFlow graph (and make
        the output tensors be dummy constants.) This is useful for rapid
        testing in situations where building the TensorFlow graph is
        expensive (e.g. large max_depth) or when the objective is to
        construct schedules and serialize them as `WeaverMessages` for
        later use.
      parallel_iterations: Integer. tf.while_loop's parallel_iterations
        option, which caps the number of different depths at which ops
        could run in parallel. Only applies when max_depth=None.
        Default: 10.
      back_prop: Boolean. tf.while_loop's back_prop option, which enables
        gradients. Only applies when max_depth=None. Default: True.
      swap_memory: Boolean. Whether to use tf.while_loop's swap_memory
        option, which enables swapping memory between GPU and CPU at the
        possible expense of some performance. Only applies when
        max_depth=None. Default: False.
      direct_feed_dict: Boolean. If true, this loom doesn't create a
        loom_input tensor for WeaverMessages, and instead creates
        placeholders for the wiring diagrams. Incompatible with both
        `loom_input_tensor` and `weaver_op`. Default: False.
      loom_input_tensor: An optional string Tensor from which to read
        WeaverMessages which specify how to wire the loom. If more than one
        is present they will be merged (auto-merge is provided so that
        WeaverMessages for individual inputs can be cached in advance while
        still using random mini-batches at run-time.) Mutually exclusive
        with `weaver_op`.
      weaver_op: An optional callable which constructs a TensorFlow op to
        produce inputs for the loom. Mutually exclusive with
        `loom_input_tensor`. If absent, the loom acts as though `weaver_op`
        were a function creating a `deserializing_weaver` op which consumes
        `WeaverMessages` from `loom_input_tensor`. The callable will be
        called with three keyword arguments named `metadata`,
        `constant_types`, and `num_type_shapes` (because these are the
        three attributes any op descending from `WeaverOpBase` requires to
        be instantiated.)

    Raises:
      TypeError: If `named_ops` is not provided.
      TypeError: If `direct_feed_dict` is combined with `loom_input_tensor`
        or with `weaver_op`.
      TypeError: If both `loom_input_tensor` and `weaver_op` are given.
      TypeError: If more than one tagged TypeShape has the same tag.
    """
    if named_ops is None:
        raise TypeError('named_ops is a mandatory argument.')

    # _max_depth: the maximum operation depth supported by the loom.
    #
    # The value ends up in the LoomMetadata proto, where -1 is the sentinel
    # meaning "built with a while loop, so there is no fixed maximum depth".
    # Any other value is the deepest nesting the loom can emulate: with f a
    # loom operation and c a constant, f(f(c)) needs a depth of at least 2.
    self._max_depth = -1 if max_depth is None else max_depth

    # Only handed on to _setup_named_tensors below; not stored here.
    named_tensors = {} if named_tensors is None else named_tensors

    # _batch_inputs: maps each TypeShape to a tensor holding a batch of
    # values of that TypeShape, to be used as inputs.
    self._batch_inputs = {} if batch_inputs is None else batch_inputs

    # _dry_run: when true the TF graph is not built (all output tensors get
    # replaced with one set of zeros.)
    self._dry_run = dry_run

    # Options forwarded to tf.while_loop; None selects the documented
    # default for each of them.
    self._parallel_iterations = (
        10 if parallel_iterations is None else parallel_iterations)
    self._back_prop = True if back_prop is None else back_prop
    self._swap_memory = False if swap_memory is None else swap_memory

    # _direct_feed_dict: whether to construct a graph which bypasses the
    # deserializing_weaver_op.
    self._direct_feed_dict = direct_feed_dict

    if direct_feed_dict:
        # Direct feeding replaces every other source of wiring diagrams, so
        # neither of the alternatives may be supplied alongside it.
        if loom_input_tensor is not None:
            raise TypeError(
                'direct_feed_dict and loom_input_tensor are incompatible.')
        if weaver_op is not None:
            raise TypeError(
                'direct_feed_dict and weaver_op are incompatible.')
    elif weaver_op is not None:
        # A caller-supplied weaver produces the schedules itself; an input
        # tensor on top of that would be contradictory.
        if loom_input_tensor is not None:
            raise TypeError('You can specify at most one of loom_input_tensor '
                            'or weaver_op.')
    else:
        # Default mode: deserialize WeaverMessages from a string tensor,
        # creating a placeholder when the caller did not provide one.
        if loom_input_tensor is None:
            loom_input_tensor = tf.placeholder('string', name='LoomInput')

        def weaver_from_input_tensor(**weaver_attrs):
            # Reads the attribute (assigned just below) at call time rather
            # than capturing the local variable.
            return deserializing_weaver_op.deserializing_weaver(
                self._loom_input_tensor, **weaver_attrs)

        weaver_op = weaver_from_input_tensor

    # _loom_input_tensor: a tensor which ought to hold a single serialized
    # WeaverMessage specifying the loom's wiring diagram (None when the
    # caller chose direct_feed_dict or supplied a weaver_op.)
    self._loom_input_tensor = loom_input_tensor
    # _weaver_op: the callable used to create the weaver op (None when
    # direct_feed_dict is set.)
    self._weaver_op = weaver_op

    # The remaining construction work is delegated to the setup helpers,
    # invoked in this fixed order.
    self._setup_type_shapes(named_ops, extra_type_shapes)
    self._setup_named_tensors(named_tensors)
    self._setup_loom_ops(named_ops)
    self._setup_metadata()
    self._setup_network()