def plot_transfer()

in models/toy_sources/ntc.py [0:0]


  def plot_transfer(self, intervals, figsize=None, soft_round=None, **kwargs):
    """Plots the analysis and synthesis transfer functions of a 1D model.

    Draws, on one set of axes:
      * the analysis transform (source on the horizontal axis, latent on the
        vertical axis),
      * the synthesis transform, evaluated on the *unquantized* latents so its
        full behavior is visible,
      * the codebook: the synthesized unique quantized latents that fall
        strictly inside the range of the analysis output,
      * the boundaries: midpoints between consecutive grid samples at which
        the quantized latent `y_hat` changes value.

    Args:
      intervals: Sequence containing exactly one `(start, stop, num)` entry;
        the source grid is `tf.linspace(float(start), float(stop), int(num))`.
      figsize: Figure size passed to `plt.figure`. Defaults to `(16, 14)`.
      soft_round: Forwarded to `self.encode_decode`. Defaults to
        `self.soft_round[1]`.
      **kwargs: Forwarded to `self.encode_decode`.

    Raises:
      ValueError: If `intervals`, `self.ndim_source` and `self.ndim_latent`
        are not all of size/value 1.
    """
    if not len(intervals) == self.ndim_source == self.ndim_latent == 1:
      raise ValueError("This method is only defined for 1D models.")
    if soft_round is None:
      soft_round = self.soft_round[1]

    x = [tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals]
    x = tf.meshgrid(*x, indexing="ij")
    x = tf.stack(x, axis=-1)

    y_hat, _, _ = self.encode_decode(x, False, False, soft_round, **kwargs)

    y = self.analysis(x)
    # We feed y here so we can visualize the full behavior of the synthesis
    # transform (not just at the quantized latent values).
    x_hat = self.synthesis(y)

    # Drop the trailing size-1 feature axis so everything is a flat 1D array.
    x = np.squeeze(x.numpy(), -1)
    y = np.squeeze(y.numpy(), -1)
    x_hat = np.squeeze(x_hat.numpy(), -1)
    y_hat = np.squeeze(y_hat.numpy(), -1)

    ylim = np.min(y), np.max(y)

    # Indices i where the quantized latent differs between samples i and i+1;
    # the boundary is placed halfway between those two samples.
    boundaries = np.nonzero(y_hat[1:] != y_hat[:-1])[0]
    lboundaries = (y_hat[boundaries] + y_hat[boundaries + 1]) / 2
    dboundaries = (x[boundaries] + x[boundaries + 1]) / 2

    lcodebook = np.unique(y_hat)
    dcodebook = self.synthesis(lcodebook[:, None]).numpy()
    dcodebook = np.squeeze(dcodebook, -1)
    # Keep only codebook entries strictly inside the range of the latents.
    mask = np.logical_and(ylim[0] < lcodebook, lcodebook < ylim[1])
    lcodebook = lcodebook[mask]
    dcodebook = dcodebook[mask]

    plt.figure(figsize=figsize or (16, 14))
    ax = plt.gca()
    plt.plot(x, y, label="analysis transform")
    plt.plot(x_hat, y, label="synthesis transform")

    ax.set_aspect("equal", "box")
    # Flip y axis if latent space is reversed.
    if y[0] > y[-1]:
      ax.invert_yaxis()
    plt.xticks(dcodebook)
    plt.yticks(lcodebook)
    plt.grid(False)
    plt.xlabel("source space")
    plt.ylabel("latent space")

    # plt.axis() returns (xmin, xmax, ymin, ymax); with an inverted y axis
    # "ymin" is the limit drawn at the bottom of the plot.
    xmin, _, ymin, _ = plt.axis()
    # Freeze the view limits. Otherwise every guide line below (which reaches
    # out to xmin/ymin) re-triggers autoscaling with margins, and the lines
    # would no longer end at the axes' edges.
    ax.set_autoscale_on(False)

    # Label only the first marker of each group so the legend gets a single
    # entry; selecting by index (not by float value) is robust to NaNs and to
    # repeated values.
    for idx, (src, lat) in enumerate(zip(dcodebook, lcodebook)):
      plt.plot([xmin, src, src], [lat, lat, ymin], "black", lw=1)
      plt.plot(
          [src], [lat], "black", marker="o", ms=5, lw=1,
          label="codebook" if idx == 0 else None)
    for idx, (src, lat) in enumerate(zip(dboundaries, lboundaries)):
      plt.plot([xmin, src, src], [lat, lat, ymin], "black", lw=1, ls=":")
      plt.plot(
          [src], [lat], "black", marker="o", ms=3, lw=1, ls=":",
          label="boundaries" if idx == 0 else None)

    plt.legend(loc="upper left")