def plot_jacobians()

in models/toy_sources/ntc.py [0:0]


  def plot_jacobians(self, which, intervals, arrow_intervals, scale=2,
                     figsize=None):
    """Plots per-latent-axis source-space directions over the source density.

    The density `self.source.prob` is drawn as an image on the grid given by
    `intervals`. At every point of the grid given by `arrow_intervals`, one
    arrow per latent dimension is drawn. Each arrow is the source-space vector
    associated with that latent axis, taken from the Jacobian of the selected
    transform. Only defined for models with 2D sources and 2D latents.

    Args:
      which: "analysis" or "synthesis".
        - "analysis": arrow positions are a grid in source space. The Jacobian
          of `self.analysis` at those points is inverted.
        - "synthesis": arrow positions are a grid in latent space mapped
          through `self.synthesis`. Its Jacobian is used directly.
      intervals: Two `(start, stop, num)` triples, one per source dimension,
        defining the grid on which the source density is evaluated.
      arrow_intervals: Two `(start, stop, num)` triples defining the grid of
        arrow anchors. The grid is in source space for "analysis" and in
        latent space for "synthesis".
      scale: Forwarded to `plt.quiver` as `scale` (with `scale_units="xy"`).
        Larger values draw shorter arrows.
      figsize: Matplotlib figure size; `(16, 14)` when not given.

    Raises:
      ValueError: If the intervals, source, or latents are not 2D, or if
        `which` is not "analysis" or "synthesis".
    """
    if not (len(intervals) == len(arrow_intervals) ==
            self.ndim_source == self.ndim_latent == 2):
      raise ValueError("This method is only defined for 2D models.")
    if which not in ("analysis", "synthesis"):
      raise ValueError("`which` must be 'analysis' or 'synthesis'.")

    def grid(spec):
      # One `tf.linspace` per `(start, stop, num)` triple, combined with "ij"
      # indexing so that axis k of the result follows dimension k. With the
      # two triples enforced above, the result has shape (num_0, num_1, 2).
      axes = [tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in spec]
      return tf.stack(tf.meshgrid(*axes, indexing="ij"), axis=-1)

    data = grid(intervals)
    data_dist = self.source.prob(data).numpy()

    if which == "analysis":
      arrow_data = grid(arrow_intervals)
      arrow_data = tf.reshape(arrow_data, (-1, arrow_data.shape[-1]))
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(arrow_data)
        arrow_latents = self.analysis(arrow_data)
      # First dimension is batch, second is latent dim, third is source dim.
      jacobian = tape.batch_jacobian(arrow_latents, arrow_data)
      # Invert to get d(source)/d(latent), with axes (batch, source, latent).
      # This is the source-space displacement for a unit step along each
      # latent axis.
      jacobian = tf.linalg.inv(jacobian)
    else:
      arrow_latents = grid(arrow_intervals)
      arrow_latents = tf.reshape(arrow_latents, (-1, arrow_latents.shape[-1]))
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(arrow_latents)
        arrow_data = self.synthesis(arrow_latents)
      # First dimension is batch, second is source dim, third is latent dim.
      jacobian = tape.batch_jacobian(arrow_data, arrow_latents)
    # Both branches now hold axes (batch, source, latent). Transpose so that
    # jacobian[:, k] is the source-space vector for latent dimension k.
    jacobian = tf.transpose(jacobian, (0, 2, 1))

    google_pink = (0xf4/255, 0x39/255, 0xa0/255)
    google_purple = (0xa1/255, 0x42/255, 0xf4/255)

    # With "ij" indexing, rows of `data` follow source index 0 and columns
    # follow source index 1. Index 1 is therefore drawn horizontally and
    # index 0 vertically, both in the image and in the arrows below.
    x_min, x_max = data[0, 0, 1], data[0, -1, 1]
    y_min, y_max = data[0, 0, 0], data[-1, 0, 0]

    plt.figure(figsize=figsize or (16, 14))
    plt.imshow(
        data_dist, vmin=0, vmax=data_dist.max(), origin="lower",
        extent=(x_min, x_max, y_min, y_max))
    # One arrow field per latent dimension: pink for latent 0, purple for
    # latent 1.
    for k, color in enumerate((google_pink, google_purple)):
      plt.quiver(
          arrow_data[:, 1], arrow_data[:, 0],
          jacobian[:, k, 1], jacobian[:, k, 0],
          pivot="tail", angles="xy", headlength=4, headaxislength=4,
          units="dots", color=color, scale_units="xy", scale=scale,
      )
    plt.axis("image")
    plt.grid(False)
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    # Labels are 1-based. The horizontal axis shows source index 1, which is
    # the second dimension, and the vertical axis shows index 0, the first.
    plt.xlabel("source dimension 2")
    plt.ylabel("source dimension 1")