in models/toy_sources/ntc.py [0:0]
def plot_transfer(self, intervals, figsize=None, soft_round=None, **kwargs):
    """Plots the analysis and synthesis transforms of a 1D model.

    Evaluates the model on a regular grid of source values and draws, in one
    figure, the analysis transform (source -> latent), the synthesis transform
    (latent -> source), the codebook points and the boundaries between
    neighboring codebook entries.

    Args:
      intervals: Sequence containing exactly one `(start, stop, num)` triple
        that defines the grid of source values (passed to `tf.linspace`).
      figsize: Optional figure size handed to `plt.figure`. Defaults to
        `(16, 14)`.
      soft_round: Value forwarded to `self.encode_decode`. Defaults to
        `self.soft_round[1]` when `None`.
      **kwargs: Additional keyword arguments forwarded to
        `self.encode_decode`.

    Raises:
      ValueError: If `intervals` does not have exactly one entry, or if the
        model's source or latent space is not one-dimensional.
    """
    if not len(intervals) == self.ndim_source == self.ndim_latent == 1:
        raise ValueError("This method is only defined for 1D models.")
    if soft_round is None:
        soft_round = self.soft_round[1]

    # Build the grid of source values with a trailing axis of size one.
    x = [tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals]
    x = tf.meshgrid(*x, indexing="ij")
    x = tf.stack(x, axis=-1)

    y_hat, _, _ = self.encode_decode(x, False, False, soft_round, **kwargs)
    y = self.analysis(x)
    # We feed y here so we can visualize the full behavior of the synthesis
    # transform (not just at the quantized latent values).
    x_hat = self.synthesis(y)

    # Drop the trailing axis so everything is a plain 1D NumPy array.
    x = np.squeeze(x.numpy(), -1)
    y = np.squeeze(y.numpy(), -1)
    x_hat = np.squeeze(x_hat.numpy(), -1)
    y_hat = np.squeeze(y_hat.numpy(), -1)

    ylim = np.min(y), np.max(y)

    # Grid indices at which y_hat changes value between neighboring samples;
    # each boundary is drawn at the midpoint of the two adjacent samples.
    boundaries = np.nonzero(y_hat[1:] != y_hat[:-1])[0]
    lboundaries = (y_hat[boundaries] + y_hat[boundaries + 1]) / 2
    dboundaries = (x[boundaries] + x[boundaries + 1]) / 2

    # Codebook: distinct y_hat values and their images under the synthesis
    # transform, restricted to entries strictly inside the plotted latent
    # range.
    lcodebook = np.unique(y_hat)
    dcodebook = self.synthesis(lcodebook[:, None]).numpy()
    dcodebook = np.squeeze(dcodebook, -1)
    mask = np.logical_and(ylim[0] < lcodebook, lcodebook < ylim[1])
    lcodebook = lcodebook[mask]
    dcodebook = dcodebook[mask]

    plt.figure(figsize=figsize or (16, 14))
    plt.plot(x, y, label="analysis transform")
    plt.plot(x_hat, y, label="synthesis transform")
    plt.gca().set_aspect("equal", "box")
    # Flip y axis if latent space is reversed.
    if y[0] > y[-1]:
        plt.gca().invert_yaxis()
    plt.xticks(dcodebook)
    plt.yticks(lcodebook)
    plt.grid(False)
    plt.xlabel("source space")
    plt.ylabel("latent space")

    # Capture the axis limits once, before the guide lines are added, so that
    # all guides are anchored to the same left/bottom edges.
    axis_limits = plt.axis()
    xmin = axis_limits[0]
    ymin = axis_limits[2]

    # The legend label is attached to the first marker only, selected by
    # position rather than by value so that repeated (or NaN) coordinates
    # cannot produce duplicate or missing legend entries.
    for index, (x_point, y_point) in enumerate(zip(dcodebook, lcodebook)):
        plt.plot(
            [xmin, x_point, x_point], [y_point, y_point, ymin], "black", lw=1)
        plt.plot(
            [x_point], [y_point], "black", marker="o", ms=5, lw=1,
            label="codebook" if index == 0 else None)
    for index, (x_point, y_point) in enumerate(zip(dboundaries, lboundaries)):
        plt.plot(
            [xmin, x_point, x_point], [y_point, y_point, ymin], "black", lw=1,
            ls=":")
        plt.plot(
            [x_point], [y_point], "black", marker="o", ms=3, lw=1, ls=":",
            label="boundaries" if index == 0 else None)
    plt.legend(loc="upper left")