in models/toy_sources/ntc.py [0:0]
def plot_jacobians(self, which, intervals, arrow_intervals, scale=2,
                   figsize=None):
    """Plots the source density overlaid with Jacobian-derived arrows.

    The density `self.source.prob` is evaluated on a regular grid and shown as
    an image. On a second grid, one arrow per latent dimension is drawn at each
    grid point, showing the source-space direction associated with that latent
    dimension:

    - `which="synthesis"`: grid points are in latent space; arrows are the
      columns of the Jacobian of `self.synthesis`, drawn at the synthesized
      data points.
    - `which="analysis"`: grid points are in source space; arrows are the
      columns of the inverse of the Jacobian of `self.analysis`.

    Arrows for latent dimension 0 are pink, for latent dimension 1 purple.

    Args:
      which: Either "analysis" or "synthesis".
      intervals: Two `(start, stop, num)` triples, one per source dimension,
        defining the grid on which the source density is plotted.
      arrow_intervals: Two `(start, stop, num)` triples defining the grid of
        arrow locations (in source space for "analysis", in latent space for
        "synthesis").
      scale: Passed to `plt.quiver` with `scale_units="xy"`; larger values
        draw shorter arrows.
      figsize: Figure size passed to `plt.figure`; defaults to `(16, 14)`.

    Raises:
      ValueError: If the model/intervals are not 2D, or `which` is invalid.
    """
    if not (len(intervals) == len(arrow_intervals) ==
            self.ndim_source == self.ndim_latent == 2):
        raise ValueError("This method is only defined for 2D models.")
    if which not in ("analysis", "synthesis"):
        raise ValueError("`which` must be 'analysis' or 'synthesis'.")

    def grid(bounds):
        # Regular grid with "ij" indexing: element [i, j] holds the point
        # (axis_0[i], axis_1[j]), stacked along the last dimension.
        axes = [tf.linspace(float(b[0]), float(b[1]), int(b[2]))
                for b in bounds]
        return tf.stack(tf.meshgrid(*axes, indexing="ij"), axis=-1)

    data = grid(intervals)
    data_dist = self.source.prob(data).numpy()

    arrow_points = grid(arrow_intervals)
    arrow_points = tf.reshape(arrow_points, (-1, arrow_points.shape[-1]))

    if which == "analysis":
        arrow_data = arrow_points
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(arrow_data)
            arrow_latents = self.analysis(arrow_data)
        # First dimension is batch, second is latent dim, third is source dim.
        jacobian = tape.batch_jacobian(arrow_latents, arrow_data)
        # Inverting gives (batch, source, latent), like the synthesis branch.
        jacobian = tf.linalg.inv(jacobian)
    else:
        arrow_latents = arrow_points
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(arrow_latents)
            arrow_data = self.synthesis(arrow_latents)
        # First dimension is batch, second is source dim, third is latent dim.
        jacobian = tape.batch_jacobian(arrow_data, arrow_latents)
    # (batch, latent, source): jacobian[:, k] is the source-space arrow for
    # latent dimension k.
    jacobian = tf.transpose(jacobian, (0, 2, 1))

    google_pink = (0xf4 / 255, 0x39 / 255, 0xa0 / 255)
    google_purple = (0xa1 / 255, 0x42 / 255, 0xf4 / 255)

    # Plot layout follows the "ij" grid: horizontal axis is source coordinate
    # index 1, vertical axis is source coordinate index 0.
    x_range = (data[0, 0, 1], data[0, -1, 1])
    y_range = (data[0, 0, 0], data[-1, 0, 0])

    plt.figure(figsize=figsize or (16, 14))
    plt.imshow(
        data_dist, vmin=0, vmax=data_dist.max(), origin="lower",
        extent=x_range + y_range)
    for latent_dim, color in enumerate((google_pink, google_purple)):
        plt.quiver(
            arrow_data[:, 1], arrow_data[:, 0],
            jacobian[:, latent_dim, 1], jacobian[:, latent_dim, 0],
            pivot="tail", angles="xy", headlength=4, headaxislength=4,
            units="dots", color=color, scale_units="xy", scale=scale,
        )
    plt.axis("image")
    plt.grid(False)
    plt.xlim(*x_range)
    plt.ylim(*y_range)
    # Labels are 1-based: x shows the second source dimension (index 1),
    # y shows the first (index 0).
    plt.xlabel("source dimension 2")
    plt.ylabel("source dimension 1")