in models/toy_sources/compression_model.py [0:0]
def plot_quantization(self, intervals, figsize=None, **kwargs):
    """Plots the quantization induced by the model over a regular source grid.

    A grid is built from `intervals` (one `tf.linspace` axis per source
    dimension, combined with `tf.meshgrid(..., indexing="ij")`), passed through
    `self.quantize`, and the resulting codebook and decision boundaries are
    drawn on top of the source density `self.source.prob`. Codewords to which
    no grid point is assigned are not drawn.

    Args:
      intervals: Sequence with one `(start, stop, num)` triple per source
        dimension. `start`/`stop` are cast to float and `num` to int before
        being passed to `tf.linspace`.
      figsize: Optional matplotlib figure size. Defaults to `(16, 8)` for 1D
        models and `(16, 14)` for 2D models.
      **kwargs: Forwarded unchanged to `self.quantize`.

    Raises:
      ValueError: If the model is not 1D or 2D, or if the number of intervals
        differs from `self.ndim_source`.
    """
    if self.ndim_source not in (1, 2):
        raise ValueError("This method is only defined for 1D or 2D models.")
    if len(intervals) != self.ndim_source:
        raise ValueError(
            "Expected one interval per source dimension ({}), got {}.".format(
                self.ndim_source, len(intervals)))

    axes = [tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals]
    grid = tf.stack(tf.meshgrid(*axes, indexing="ij"), axis=-1)
    codebook, rates, indexes = self.quantize(grid, **kwargs)
    codebook = codebook.numpy()
    rates = rates.numpy()
    indexes = indexes.numpy()
    data_dist = self.source.prob(grid).numpy()
    # Convert once; everything below is plain NumPy/matplotlib.
    data = grid.numpy()

    # Number of grid points assigned to each codeword; unused ones are hidden.
    counts = np.bincount(np.ravel(indexes), minlength=len(codebook))
    used = counts > 0
    # Probability implied by each codeword's rate (assumes rates are in bits).
    prior = 2 ** (-rates)

    if self.ndim_source == 1:
        data = np.squeeze(data, axis=-1)
        # NOTE(review): the 2D branch indexes codebook[..., 1], suggesting a
        # (num_codewords, ndim_source) codebook. Flatten so stem/xticks get
        # 1-D codeword locations and sorting runs along the codeword axis.
        centers = np.ravel(codebook[used])
        # Decision boundaries lie halfway between neighbouring grid points
        # that are assigned to different codewords.
        changes = np.nonzero(indexes[1:] != indexes[:-1])[0]
        boundaries = (data[changes] + data[changes + 1]) / 2

        plt.figure(figsize=figsize or (16, 8))
        plt.plot(data, data_dist, label="source")
        markers, stems, base = plt.stem(centers, prior[used], label="codebook")
        plt.setp(markers, color="black")
        plt.setp(stems, color="black")
        plt.setp(base, linestyle="None")
        plt.xticks(np.sort(centers))
        plt.grid(False, axis="x")
        for i, r in enumerate(boundaries):
            # Label only the first line so the legend gets a single entry.
            plt.axvline(
                r, color="black", lw=1, ls=":",
                label="boundaries" if i == 0 else None)
        plt.xlim(np.min(data), np.max(data))
        plt.ylim(bottom=-.01)
        plt.legend(loc="upper left")
        plt.xlabel("source space")
    else:
        google_pink = (0xf4 / 255, 0x39 / 255, 0xa0 / 255)
        # imshow draws grid axis 0 (first interval, component 0) along rows,
        # i.e. vertically, and grid axis 1 (component 1) horizontally.
        x_min, x_max = data[0, 0, 1], data[0, -1, 1]
        y_min, y_max = data[0, 0, 0], data[-1, 0, 0]

        plt.figure(figsize=figsize or (16, 14))
        plt.imshow(
            data_dist, vmin=0, vmax=data_dist.max(), origin="lower",
            extent=(x_min, x_max, y_min, y_max))
        # A level at k + 0.5 separates indexes <= k from indexes > k, so every
        # change of assigned codeword crosses at least one contour level.
        plt.contour(
            data[:, :, 1], data[:, :, 0], indexes,
            np.arange(len(codebook)) + .5,
            colors=[google_pink], linewidths=.5)
        plt.plot(codebook[used, 1], codebook[used, 0], "o", color=google_pink)
        plt.axis("image")
        plt.grid(False)
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        # Labels follow the components actually plotted on each axis.
        plt.xlabel("source dimension 2")
        plt.ylabel("source dimension 1")