def _update_input_type()

in tensorflow_fold/blocks/blocks.py [0:0]


  def _update_input_type(self):
    """Validates the input type and derives the concatenated output type.

    Walks every terminal type of `self.input_type`, treating rank-0 types as
    having shape `[1]`, and checks that all of them agree on dtype and on
    every dimension other than `self._concat_dim`. The output type is a single
    tensor type whose extent along `self._concat_dim` is the sum of the
    inputs' extents along that dimension.

    Side effects: resets `self._scalar_indices` to the positions (in
    terminal-type order) of the rank-0 inputs, and calls
    `self.set_output_type` on success.

    Raises:
      TypeError: if `self.input_type.size` is None; if the input type contains
        a nested tuple type while `self._flatten` is false; if an input's rank
        is not greater than `self._concat_dim`; if shapes (ignoring the concat
        dimension) or dtypes disagree; or if there are no terminal types.
    """
    input_type = self.input_type
    if input_type.size is None:
      # Wrap the argument in a 1-tuple: `%` unpacks a bare tuple operand, which
      # would turn this diagnostic into a string-formatting error if the input
      # type is itself a tuple (sub)class.
      raise TypeError('Concat inputs must be tensors: %s' % (input_type,))
    if (not self._flatten and
        any(isinstance(t, tdt.TupleType) for t in input_type)):
      raise TypeError('input type %s contains nested tuples, expected a flat '
                      'tuple of tensors; set flatten=True in the constructor' %
                      (input_type,))

    concat_dim = self._concat_dim
    size = 0       # running extent along the concat dimension
    shape = None   # shape of the first input, with concat_dim masked out
    dtype = None   # dtype of the first input
    self._scalar_indices = []
    for (i, ty) in enumerate(input_type.terminal_types()):
      tyshape = list(ty.shape)         # clone original shape
      tyrank = len(tyshape)
      if tyrank == 0:
        tyshape = [1]                  # upgrade scalars to vectors
        tyrank = 1
        self._scalar_indices.append(i)
      if tyrank <= concat_dim:
        raise TypeError('Concat argument %d of type %s has rank less than %d.' %
                        (i, ty, concat_dim+1))
      size += tyshape[concat_dim]
      tyshape[concat_dim] = None       # for shape matching

      # Compare the sentinels by identity rather than truthiness so that a
      # valid-but-falsy value (e.g. a dtype object whose __len__ is 0) is not
      # mistaken for "nothing seen yet".
      if shape is None:
        shape = tyshape
      elif shape != tyshape:
        raise TypeError('Shapes for concat don\'t match: %s vs. %s'
                        % (shape, tyshape))
      if dtype is None:
        dtype = ty.dtype
      elif ty.dtype != dtype:
        raise TypeError('Cannot concat tensors of different dtypes: %s vs. %s'
                        % (dtype, ty.dtype))
    if dtype is None:
      raise TypeError('Concat requires at least one tensor as input')
    shape[concat_dim] = size           # restore the masked dimension
    self.set_output_type(tdt.TensorType(shape, dtype=dtype))