def discrete_bottleneck()

in tensor2tensor/layers/discretization.py [0:0]


def discrete_bottleneck(inputs,
                        hidden_size,
                        z_size,
                        filter_size,
                        mode=None,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        beta=0.25,
                        ema=True,
                        means=None,
                        ema_count=None,
                        ema_means=None,
                        epsilon=1e-5,
                        decay=0.999,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        softmax_k=0,
                        temperature_warmup_steps=150000,
                        do_hard_gumbel_softmax=False,
                        num_flows=0,
                        approximate_gs_entropy=False,
                        sum_over_latents=False,
                        discrete_mix=0.5,
                        noise_dev=1.,
                        startup_steps=50000,
                        summary=True,
                        name=None,
                        cond=True):
  """Discretization bottleneck.

  Args:
    inputs: Input to the bottleneck, a Tensor of shape [..., channels].
    hidden_size: Dimension of the dense output.
    z_size: Number of bits, where discrete codes range from 0 to
      2**z_size - 1.
    filter_size: Filter size in the embedding function.
    mode: tf.estimator.ModeKeys. Used by gumbel-softmax and semhash.
    bottleneck_kind: Kind of discretization bottleneck. One of dense, dvq
      (decomposed vector quantization), gumbel-softmax, gumbel-softmax-dvq,
      semhash, or vae.
    num_blocks: Number of blocks. Used only if bottleneck_kind is DVQ.
    num_residuals: Number of residual units used to compute nearest
      neighbors. Used only if bottleneck_kind is DVQ.
    reshape_method: Method to reshape, "slice" or "project". Used only if
      bottleneck_kind is DVQ.
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project.
    beta: Scale factor for codebook loss and EMA. Used only if bottleneck_kind
      is DVQ.
    ema: Whether to update embeddings using exponential moving averages. Used
      only if bottleneck_kind is DVQ.
    means: The embedding table, indexed per residual (means[i]). Required if
      bottleneck_kind is dvq or gumbel-softmax-dvq (whether or not ema is
      True); it is also overwritten in place when ema is True.
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to. Used only if ema is True.
    ema_means: Exponentially averaged version of the embeddings. Used only if
      ema is True.
    epsilon: Small value to avoid dividing by zero in EMA update. Used only if
      ema is True.
    decay: Decay factor for the exponential moving average. Used only if ema is
      True.
    random_top_k: Noisy top-k. Used only if bottleneck_kind is DVQ.
    soft_em: Whether to use soft EM or hard EM. Used only if bottleneck_kind is
      DVQ.
    num_samples: Number of samples for soft EM. Used only if soft_em is True.
    softmax_k: If > 0 then do top-k softmax. Used only if bottleneck_kind
      is gumbel-softmax.
    temperature_warmup_steps: Number of steps it takes to decay temperature to
      0. Used only if bottleneck_kind is gumbel-softmax or gumbel-softmax-dvq.
    do_hard_gumbel_softmax: Whether to use hard or soft Gumbel-Softmax
      samples. Used only if bottleneck_kind is gumbel-softmax-dvq.
    num_flows: Number of inverse autoregressive flows. Used only if
      bottleneck_kind is gumbel-softmax-dvq.
    approximate_gs_entropy: Whether to approximate the Gumbel-Softmax density
      as a categorical distribution when calculating the sample entropy. Used
      only if bottleneck_kind is gumbel-softmax-dvq.
    sum_over_latents: Whether to sum over all non-batch dimensions before
      taking mean of entropy loss term. Used only if bottleneck kind is DVQ
      or gumbel-softmax-dvq.
    discrete_mix: Factor for mixing discrete and non-discrete input. Used only
      if bottleneck_kind is semhash.
    noise_dev: Noise stddev. Used only if bottleneck_kind is semhash.
    startup_steps: Schedule length for the probability of using the discrete
      (rather than continuous) semhash output; passed as
      `inverse_exp_decay(startup_steps * 2)`. Used only if bottleneck_kind is
      semhash.
    summary: Whether to write summaries.
    name: Name for the bottleneck scope.
    cond: A tf.bool condition on whether to update the codebook.

  Returns:
    outputs_dense: Tensor of shape [..., output_dim]. The output dimension is
      hidden_size if bottleneck_kind is gumbel-softmax, DVQ, or semhash;
      filter_size if bottleneck_kind is dense or vae. If bottleneck_kind is
      DVQ, outputs_dense represents the codebook (means) indexed by
      outputs_discrete (with a straight-through gradient to inputs).
    outputs_discrete: Tensor of shape [...]. Discrete codes, each an index in
      [0, 2**z_size). It uses the hot representation if soft_em is True. For
      dense and vae it is the continuous z_size-dim latent instead.
    extra_loss: Scalar Tensor. Sum of codebook and commitment losses if
      bottleneck_kind is DVQ; the loss returned by the gumbel_softmax / vae
      helper for gumbel-softmax / vae; zero for dense and semhash.
    embed_fn: Function embed with arguments partially filled in.
    neg_q_entropy: Scalar Tensor representing negative entropy of variational
      approximation (0 if it is deterministic). For DVQ kinds this is the
      value reported by the last residual's embedding lookup.

  Raises:
    ValueError: If means is None for a DVQ kind, projection_tensors is None
    for reshape_method project, ema_count or ema_means is None if ema is
    True, or unknown args.
  """
  if bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and the failure would then surface later as an opaque
    # TypeError on `means[i]`.
    if means is None:
      raise ValueError(
          "means is None but bottleneck_kind is {}".format(bottleneck_kind))
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")

    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide embedding table size")
    z_size_per_residual = int(z_size / num_residuals)

    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide embedding table size")
    # Each block of each residual gets its own codebook of this many entries.
    block_v_size = 2**int(z_size_per_residual / num_blocks)

    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")
  else:
    block_v_size = None

  with tf.variable_scope(
      name, default_name="discrete_bottleneck", reuse=tf.AUTO_REUSE):
    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        bottleneck_kind=bottleneck_kind,
        soft_em=soft_em,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means,
        name=name)

    if bottleneck_kind == "dense":
      # Note discrete output is continuous here.
      outputs_discrete = tf.layers.dense(inputs, z_size, name="vcc")
      outputs_dense = tf.layers.dense(
          outputs_discrete, filter_size, name="vch1")
      extra_loss = tf.constant(0.0)
      neg_q_entropy = tf.constant(0.0)
    elif bottleneck_kind in ["dvq", "gumbel-softmax-dvq"]:
      inputs_3d = inputs
      # Rank-4 inputs are squeezed on axis 2 (presumably a singleton spatial
      # axis -- TODO confirm against callers).
      if len(inputs.shape) == 4:
        inputs_3d = tf.squeeze(inputs, axis=2)
      if reshape_method == "slice":
        x_reshaped = slice_hidden(
            inputs_3d, hidden_size=hidden_size, num_blocks=num_blocks)
      elif reshape_method == "project":
        if projection_tensors is None:
          raise ValueError(
              "Projection tensors is None for reshape_method project")
        x_reshaped = project_hidden(
            inputs_3d,
            projection_tensors=projection_tensors,
            hidden_size=hidden_size,
            num_blocks=num_blocks)
      else:
        raise ValueError("Unknown reshape_method")

      # Flatten the two leading dims; x_res is the running residual that
      # each successive codebook quantizes.
      x_res = tf.reshape(x_reshaped,
                         [-1] + common_layers.shape_list(x_reshaped)[2:])
      x_means_hot = []
      x_means = 0
      extra_loss = 0
      for i in range(num_residuals):
        # NOTE(review): neg_q_entropy is overwritten on every iteration, so
        # only the last residual's value is returned.
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res, neg_q_entropy = (
            embedding_lookup(
                x_reshaped,
                means=means[i],
                num_blocks=num_blocks,
                block_v_size=block_v_size,
                bottleneck_kind=bottleneck_kind,
                random_top_k=random_top_k,
                soft_em=soft_em,
                num_samples=num_samples,
                temperature_warmup_steps=temperature_warmup_steps,
                do_hard_gumbel_softmax=do_hard_gumbel_softmax,
                num_flows=num_flows,
                approximate_gs_entropy=approximate_gs_entropy,
                sum_over_latents=sum_over_latents))
        # Update the EMA variables.
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          # When cond is False the EMA target is the current value, so the
          # moving average is pulled toward itself (no change).
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.where(cond,
                       tf.reduce_sum(
                           tf.reshape(x_means_hot_res,
                                      shape=[-1, num_blocks, block_v_size]),
                           axis=0), ema_count[i]),
              decay,
              zero_debias=False)

          # Per-block sum of the inputs assigned to each code.
          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))

          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], tf.where(cond, dw, ema_means[i]),
              decay, zero_debias=False)
          # Additive smoothing of the counts (keeps the total n unchanged) so
          # no code has a zero count before dividing below.
          # NOTE(review): the smoothing uses 2**z_size entries rather than the
          # per-block codebook size block_v_size -- confirm this is intended.
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = (
              (updated_ema_count_res + epsilon) / (n + 2**z_size * epsilon) * n)
          # pylint: disable=g-no-augmented-assignment
          updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
              updated_ema_count_res, axis=-1)
          # pylint: enable=g-no-augmented-assignment

          # The loss add op is created under a control dependency on the
          # codebook assignment, so evaluating extra_loss also runs the
          # update (after the commitment loss has been computed).
          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i],
                                         tf.where(cond,
                                                  updated_ema_means_res,
                                                  means[i]))
            with tf.control_dependencies([update_means_res]):
              extra_loss += beta * e_loss_res
        else:
          extra_loss += q_loss_res + beta * e_loss_res

        # Update the residuals.
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation.
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation.
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      outputs_discrete = bit_to_int(
          tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of discrete outputs.
      inputs_shape = common_layers.shape_list(inputs)
      outputs_discrete = tf.reshape(outputs_discrete, inputs_shape[:-1])

      # If we're using soft EM then set discretes to the hot representation.
      # NOTE(review): this reshape only has a matching element count when
      # num_residuals * num_blocks == 1 -- confirm callers respect that.
      if soft_em:
        outputs_discrete = x_means_hot
        outputs_discrete = tf.reshape(outputs_discrete,
                                      inputs_shape[:-1] + [block_v_size])

      # Reshape back to the input shape, assuming hidden_size equals
      # inputs_shape[-1].
      x_means = tf.reshape(x_means, inputs_shape)
      # Straight-through estimator: the forward value is x_means, while the
      # gradient flows to inputs unchanged.
      outputs_dense = inputs + tf.stop_gradient(x_means - inputs)
    elif bottleneck_kind == "gumbel-softmax":
      _, outputs_hot, extra_loss = gumbel_softmax(
          inputs,
          z_size=z_size,
          mode=mode,
          softmax_k=softmax_k,
          temperature_warmup_steps=temperature_warmup_steps,
          summary=summary,
          name=name)
      outputs_discrete = tf.argmax(outputs_hot, axis=-1)
      outputs_dense = tf.layers.dense(
          outputs_hot, hidden_size, name="dae_dense")
      neg_q_entropy = tf.constant(0.0)
    elif bottleneck_kind == "semhash":
      outputs_discrete = tf.layers.dense(inputs, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(outputs_discrete)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      # Noise is only injected during training.
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(outputs_discrete),
            mean=0.0,
            stddev=noise_dev)
        y = common_layers.saturating_sigmoid(outputs_discrete + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      # Straight-through: forward value is the hard bit d, gradient is y's.
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      # Per-example probability of using the discrete bits; always 1.0
      # outside training.
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      outputs_dense_a = tf.layers.dense(c, filter_size, name="vch1a")
      outputs_dense_b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      outputs_dense = outputs_dense_a + outputs_dense_b
      outputs_dense = tf.layers.dense(outputs_dense, hidden_size,
                                      name="vch_final_linear")

      dx = tf.to_int32(tf.stop_gradient(d))
      outputs_discrete = bit_to_int(dx, z_size)
      extra_loss = tf.constant(0.0)
      neg_q_entropy = tf.constant(0.0)
    elif bottleneck_kind == "vae":
      outputs_discrete, extra_loss, _, _ = vae(inputs, z_size, name="vae")
      outputs_dense = tf.layers.dense(
          outputs_discrete, filter_size, name="vch1")
      neg_q_entropy = tf.constant(0.0)
    else:
      raise ValueError("Unknown discretization method.")

  return outputs_dense, outputs_discrete, extra_loss, embed_fn, neg_q_entropy