in tensorflow_fold/blocks/blocks.py [0:0]
def _update_input_type(self):
  """Validates the Concat input type and sets the resulting output type.

  Each terminal type of `self.input_type` is treated as a tensor. Rank-0
  (scalar) inputs are upgraded to shape [1], and their terminal indices are
  recorded in `self._scalar_indices`. All inputs must share a dtype and must
  agree on every dimension except `self._concat_dim`. The output type is a
  `tdt.TensorType` with the common shape, where dimension `self._concat_dim`
  is the sum of the inputs' sizes along that dimension.

  `self._scalar_indices` is only replaced once validation succeeds, so a
  failed type check does not leave a partially built list behind.

  Raises:
    TypeError: If `self.input_type.size` is None (presumably because some
      terminal type is not a tensor -- see the error message). Also raised if
      nested tuples are present while `self._flatten` is false, if an input
      has rank <= `self._concat_dim`, if shapes (ignoring the concat
      dimension) or dtypes differ between inputs, or if there are no tensor
      inputs at all.
  """
  if self.input_type.size is None:
    raise TypeError('Concat inputs must be tensors: %s' % self.input_type)
  if (not self._flatten and
      any(isinstance(t, tdt.TupleType) for t in self.input_type)):
    raise TypeError('input type %s contains nested tuples, expected a flat '
                    'tuple of tensors; set flatten=True in the constructor' %
                    self.input_type)
  concat_dim = self._concat_dim
  size = 0
  shape = None  # common shape with the concat dim masked out; None until set
  dtype = None  # common dtype; None until the first input is seen
  # Collected locally and assigned to self only after all checks pass.
  scalar_indices = []
  for i, ty in enumerate(self.input_type.terminal_types()):
    tyshape = list(ty.shape)  # copy, since the concat dim is overwritten below
    if not tyshape:
      tyshape = [1]  # upgrade scalars to vectors
      scalar_indices.append(i)
    if len(tyshape) <= concat_dim:
      raise TypeError('Concat argument %d of type %s has rank less than %d.' %
                      (i, ty, concat_dim + 1))
    size += tyshape[concat_dim]
    # Mask out the concat dimension so the remaining dims can be compared.
    tyshape[concat_dim] = None
    if shape is None:
      shape = tyshape
    elif shape != tyshape:
      raise TypeError('Shapes for concat don\'t match: %s vs. %s'
                      % (shape, tyshape))
    if dtype is None:
      dtype = ty.dtype
    elif ty.dtype != dtype:
      raise TypeError('Cannot concat tensors of different dtypes: %s vs. %s'
                      % (dtype, ty.dtype))
  if dtype is None:
    raise TypeError('Concat requires at least one tensor as input')
  shape[concat_dim] = size
  self._scalar_indices = scalar_indices
  self.set_output_type(tdt.TensorType(shape, dtype=dtype))