in tensorflow_fold/loom/loom.py [0:0]
def build_feed_dict(self, outputs=None):
  """Converts this Weaver's diagram into a `feed_dict` for `tf.Session.run`.

  Any `outputs` given are first registered via `add_output`. If the Loom is
  in dry-run mode, an empty dict is returned immediately and `outputs` is
  ignored.

  When the Loom is fed directly, the Weaver is finalized and these are fed
  straight into the Loom's placeholders:
    - its argument wirings,
    - its output wirings,
    - its constants.
  Otherwise the serialized Weaver is fed into the Loom's single input tensor.

  Warning: on the direct-feed path this calls `Weaver::Finalize`, which
  freezes the Weaver's output wirings. Changes made to this Weaver after the
  first call are therefore not reflected in later results.

  Args:
    outputs: Optional iterable of additional nodes which should be sent to
      the output tensors (these can also be set using `add_output`).

  Returns:
    A dictionary which can be passed as the `feed_dict` argument of
    `tf.Session.run()` and which makes this Weaver's Loom behave like the
    diagram.

  Raises:
    TypeError: If the Loom is not fed directly and has no
      `_loom_input_tensor`, i.e. it uses a custom weaver op.
  """
  # Dry runs never touch the weaver (and silently ignore `outputs`).
  if self._loom._dry_run:
    return {}

  if outputs is not None:
    for node in outputs:
      self.add_output(node)

  loom = self._loom
  if not loom._direct_feed_dict:
    # Serialized path: the whole Weaver is fed as one value.
    if loom._loom_input_tensor is None:
      raise TypeError('You cannot call build_feed_dict on a LoomInput if '
                      'its Loom has a custom weaver op.')
    return {loom._loom_input_tensor: self.serialize()}

  weaver = self._weaver
  weaver.Finalize()

  # Every (depth, op, arg) wiring is packed into one flat list. For each
  # wiring, its offset and its length within that list are recorded.
  flat_wiring, offsets, lengths = [], [], []
  for depth in xrange(1, weaver.MaxDepth() + 1):
    for op_idx, op in enumerate(loom._loom_ops):
      for arg_idx in xrange(len(op.input_type_shapes)):
        wiring = list(weaver.GetWiring(depth, op_idx, arg_idx))
        offsets.append(len(flat_wiring))
        lengths.append(len(wiring))
        flat_wiring.extend(wiring)

  feed_dict = {
      loom._arg_wiring_concat: flat_wiring,
      loom._arg_wiring_slice_starts: offsets,
      loom._arg_wiring_slice_sizes: lengths,
  }
  feed_dict.update(
      (placeholder, weaver.GetOutputWiring(ts_idx))
      for ts_idx, placeholder in enumerate(loom._output_wirings))

  for ts_idx, placeholder in enumerate(loom._constants):
    values = self._constants[ts_idx]
    type_shape = loom._type_shapes[ts_idx]
    const_array = np.array(values, dtype=type_shape.dtype)
    if not values:
      # np.array([]) has shape (0,). Give the empty batch the type-shape's
      # trailing dimensions instead.
      const_array = const_array.reshape((0,) + type_shape.shape)
    feed_dict[placeholder] = const_array
  return feed_dict