def plot_quantization()

in models/toy_sources/compression_model.py [0:0]


  def plot_quantization(self, intervals, figsize=None, **kwargs):
    """Plots the source density together with the quantizer it induces.

    The source space is sampled on a regular grid. Each grid point is passed
    through `self.quantize`, and the result is drawn on a new matplotlib
    figure:

    - 1D models: the density curve, a stem plot of the codebook entries that
      at least one grid point maps to (height `2 ** -rate`), and dotted
      vertical lines at the decision boundaries.
    - 2D models: a density image, the quantization cell boundaries as
      contour lines, and the used codebook vectors as points.

    Args:
      intervals: One `(start, stop, num)` triple per source dimension. Each
        one defines the grid along that dimension as
        `linspace(start, stop, num)`.
      figsize: Optional figure size passed to `plt.figure`. Defaults to
        `(16, 8)` for 1D models and `(16, 14)` for 2D models.
      **kwargs: Forwarded to `self.quantize`.

    Raises:
      ValueError: If the model is neither 1D nor 2D, or if `intervals` does
        not have exactly one entry per source dimension.
    """
    if self.ndim_source not in (1, 2):
      raise ValueError("This method is only defined for 1D or 2D models.")
    if len(intervals) != self.ndim_source:
      raise ValueError(
          f"Expected one interval per source dimension ({self.ndim_source}), "
          f"got {len(intervals)}.")

    # Regular grid over the source space. The trailing axis holds the
    # coordinates of each grid point, so data[..., d] is source dimension d.
    data = [tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals]
    data = tf.meshgrid(*data, indexing="ij")
    data = tf.stack(data, axis=-1)

    codebook, rates, indexes = self.quantize(data, **kwargs)
    codebook = codebook.numpy()
    rates = rates.numpy()
    indexes = indexes.numpy()

    data_dist = self.source.prob(data).numpy()
    # Only codebook entries that at least one grid point maps to are drawn.
    counts = np.bincount(np.ravel(indexes), minlength=len(codebook))
    used = counts > 0
    # Presumably `rates` is in bits, which makes 2 ** -rate the probability
    # implied for each entry. TODO confirm against `self.quantize`.
    prior = 2 ** (-rates)

    if self.ndim_source == 1:
      data = np.squeeze(data, axis=-1)
      # Decision boundaries lie midway between adjacent grid points that are
      # mapped to different codebook entries.
      changes = np.nonzero(indexes[1:] != indexes[:-1])[0]
      boundaries = (data[changes] + data[changes + 1]) / 2
      plt.figure(figsize=figsize or (16, 8))
      plt.plot(data, data_dist, label="source")
      markers, stems, base = plt.stem(
          codebook[used], prior[used], label="codebook")
      plt.setp(markers, color="black")
      plt.setp(stems, color="black")
      plt.setp(base, linestyle="None")
      plt.xticks(np.sort(codebook[used]))
      plt.grid(False, axis="x")
      for n, r in enumerate(boundaries):
        # Label only the first line so the legend gets a single entry.
        plt.axvline(
            r, color="black", lw=1, ls=":",
            label="boundaries" if n == 0 else None)
      plt.xlim(np.min(data), np.max(data))
      plt.ylim(bottom=-.01)
      plt.legend(loc="upper left")
      plt.xlabel("source space")
    else:
      google_pink = (0xf4/255, 0x39/255, 0xa0/255)
      plt.figure(figsize=figsize or (16, 14))
      vmax = data_dist.max()
      # With "ij" indexing, image rows (y-axis) follow data[..., 0] and image
      # columns (x-axis) follow data[..., 1]. The extent is (left, right,
      # bottom, top).
      plt.imshow(
          data_dist, vmin=0, vmax=vmax, origin="lower",
          extent=(
              data[0, 0, 1], data[0, -1, 1], data[0, 0, 0], data[-1, 0, 0]))
      # Contour levels at half-integers trace the lines where the codebook
      # index changes, i.e. the quantization cell boundaries.
      plt.contour(
          data[:, :, 1], data[:, :, 0], indexes,
          np.arange(len(codebook)) + .5,
          colors=[google_pink], linewidths=.5)
      plt.plot(
          codebook[used, 1], codebook[used, 0],
          "o", color=google_pink)
      plt.axis("image")
      plt.grid(False)
      plt.xlim(data[0, 0, 1], data[0, -1, 1])
      plt.ylim(data[0, 0, 0], data[-1, 0, 0])
      # The x-axis shows the second source dimension and the y-axis the
      # first (see the imshow extent above).
      plt.xlabel("source dimension 2")
      plt.ylabel("source dimension 1")