def __init__()

in tensorflow_probability/python/bijectors/glow.py [0:0]


  def __init__(self,
               output_shape=(32, 32, 3),
               num_glow_blocks=3,
               num_steps_per_block=32,
               coupling_bijector_fn=None,
               exit_bijector_fn=None,
               grab_after_block=None,
               use_actnorm=True,
               seed=None,
               validate_args=False,
               name='glow'):
    """Creates the Glow bijector.

    The constructor validates the requested geometry, computes the per-block
    channel splits, and then assembles a chain of `Reshape`, inverted
    `ExitBijector`, `GlowBlock` (optionally wrapped in `Blockwise` with an
    `Identity` for channels that bypass the block) and `Expand` bijectors.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output of the bijector's forward pass (the image).  Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. `2**num_glow_blocks` must divide evenly into
        both H and W, otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. Its length must equal `num_glow_blocks`.
        Default value: None (this will take out half of the channels after
          each block except the last, after which no channels are taken out).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.

    Raises:
      ValueError: if `output_shape` is not fully defined or is not rank 3.
      ValueError: if `num_glow_blocks` is not a statically known positive
        integer.
      ValueError: if the height or width of `output_shape` is not divisible by
        `2**num_glow_blocks`, or is too small for that many blocks.
      ValueError: if `len(grab_after_block) != num_glow_blocks`.
      ValueError: if the computed channel splits leave a block with fewer than
        one input channel, or (when no entry of `grab_after_block` is zero)
        an exit with fewer than one output channel.
    """
    # Must stay the first statement so that only the constructor arguments
    # (and `self`) are captured.
    parameters = dict(locals())
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
      raise ValueError('Shape ndims must be 3 for images.  Your shape is '
                       '{}'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
        int(num_glow_blocks_) != num_glow_blocks_ or
        num_glow_blocks_ < 1):
      raise ValueError('Argument `num_glow_blocks` must be a statically known '
                       'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
      grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Thing we know must be true: h and w are evenly divisible by 2, n times.
    # Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
      raise ValueError('Width must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
      raise ValueError('Height should be divisible by 2 at least n times.')
    # After the divisibility checks above, these branches are only reachable
    # for a zero-sized dimension; guard the logarithm so that case reports a
    # ValueError instead of failing on int(-inf).
    if h // 2**n < 1:
      max_halvings = int(np.log(h) / np.log(2.)) if h > 0 else 0
      raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, h, max_halvings))
    if w // 2**n < 1:
      max_halvings = int(np.log(w) / np.log(2.)) if w > 0 else 0
      raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, w, max_halvings))

    # Other things we want to be true:
    # - The number of times we take must be equal to the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
      raise ValueError('Length of grab_after_block ({0}) must match the number '
                       'of blocks ({1}).'.format(len(grab_after_block),
                                                 num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of blockwise splits.
    # NOTE(review): the reported index is a position in
    # `self._blockwise_splits`; confirm how it maps onto the ordering of
    # `grab_after_block`, which is passed reversed above.
    empty_input_blocks = [
        idx for idx, bs in enumerate(self._blockwise_splits) if bs[0] < 1]
    if empty_input_blocks:
      raise ValueError('At at least one exit, you are taking out all of your '
                       'channels, and therefore have no inputs to later blocks.'
                       ' Try setting grab_after_block to a lower value at index'
                       ' {}.'.format(empty_input_blocks[0]))

    # Special case: if specifically exiting no channels, then the exit is
    # just an identity bijector, so the output-channel check is skipped.
    exits_no_channels = any(np.isclose(gab, 0) for gab in grab_after_block)
    if not exits_no_channels:
      empty_exit_blocks = [
          idx for idx, bs in enumerate(self._blockwise_splits) if bs[1] < 1]
      if empty_exit_blocks:
        raise ValueError('At least one of your layers has < 1 output channels. '
                         'This means you set grab_after_block too small. '
                         'Try setting grab_after_block to a larger value at '
                         'index {}.'.format(empty_exit_blocks[0]))

    # Lets start to build our bijector. We assume that the distribution is 1
    # dimensional. First, lets reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):
      # Spatial extent at this level of the hierarchy.
      level_h = h // 2**n * 2**i
      level_w = w // 2**n * 2**i
      splits = self._blockwise_splits[i]

      # This is the shape of the current tensor
      current_shape = (level_h, level_w, c * 4**(i + 1))

      # This is the shape of the input to both the glow block and exit bijector.
      this_nchan = sum(splits[0:2])
      this_input_shape = (level_h, level_w, this_nchan)

      glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                   splits,
                                                   exit_bijector_fn)))

      glow_block = GlowBlock(input_shape=this_input_shape,
                             num_steps=nsteps,
                             coupling_bijector_fn=coupling_bijector_fn,
                             use_actnorm=use_actnorm,
                             seedstream=seedstream)

      if splits[2] == 0:
        # All channels are passed to the RealNVP
        glow_chain.append(glow_block)
      else:
        # Some channels are passed around the block.
        # This is done with the Blockwise bijector.
        glow_chain.append(
            blockwise.Blockwise(
                [glow_block, identity.Identity()],
                [this_nchan, splits[2]]))

      # Finally, lets expand the channels into spatial features.
      glow_chain.append(
          Expand(input_shape=[
              level_h,
              level_w,
              c * 4**n // 4**i,
          ]))

    # The chain was assembled in the opposite order to the one handed to
    # `Chain`, so reverse it before chaining.
    glow_chain = glow_chain[::-1]
    # To finish off, we build a bijector that chains the components together
    # sequentially.
    super(Glow, self).__init__(
        bijectors=chain.Chain(glow_chain, validate_args=validate_args),
        validate_args=validate_args,
        parameters=parameters,
        name=name)