in jukebox/vqvae/vqvae.py [0:0]
def __init__(self, input_shape, levels, downs_t, strides_t,
             emb_width, l_bins, mu, commit, spectral, multispectral,
             multipliers=None, use_bottleneck=True, **block_kwargs):
    """Create one encoder/decoder pair per level plus the bottleneck.

    Args:
        input_shape: Indexable shape; the last entry is used as the
            channel count, the leading entries are stored as ``x_shape``
            and the first entry as ``sample_length``.
        levels: Number of encoder/decoder pairs to build.
        downs_t: Per-level sequence; passed to ``calculate_strides`` and
            sliced as ``downs_t[:level + 1]`` for each level's modules.
        strides_t: Per-level sequence; handled the same way as
            ``downs_t``.
        emb_width: Forwarded to every ``Encoder``/``Decoder`` and to
            ``Bottleneck``.
        l_bins: Forwarded to ``Bottleneck`` and stored on the instance.
        mu: Forwarded to ``Bottleneck``.
        commit: Stored unchanged on the instance.
        spectral: Stored unchanged on the instance.
        multispectral: Stored unchanged on the instance.
        multipliers: Optional per-level factors applied to the
            ``"width"`` and ``"depth"`` block kwargs. Defaults to ``1``
            for every level.
        use_bottleneck: When true a ``Bottleneck`` is created, otherwise
            a ``NoBottleneck``.
        **block_kwargs: Extra keyword arguments forwarded to every
            ``Encoder``/``Decoder``; must contain ``"width"`` and
            ``"depth"``.

    Raises:
        AssertionError: If ``multipliers`` is given and its length is
            not ``levels``.
        KeyError: If ``block_kwargs`` lacks ``"width"`` or ``"depth"``.
    """
    super().__init__()

    # Split the input shape into its leading dims and the channel count.
    x_channels = input_shape[-1]
    self.x_shape = x_shape = input_shape[:-1]
    self.sample_length = input_shape[0]

    # Values that are only recorded on the instance.
    self.levels = levels
    self.downs_t = downs_t
    self.strides_t = strides_t
    self.l_bins = l_bins
    self.commit = commit
    self.spectral = spectral
    self.multispectral = multispectral

    # Per-level downsampling factors and their running product.
    self.downsamples = calculate_strides(strides_t, downs_t)
    self.hop_lengths = np.cumprod(self.downsamples)
    self.z_shapes = [
        (x_shape[0] // self.hop_lengths[lvl],) for lvl in range(levels)
    ]

    if multipliers is not None:
        assert len(multipliers) == levels, "Invalid number of multipliers"
        self.multipliers = multipliers
    else:
        self.multipliers = [1] * levels

    def scaled_kwargs(lvl):
        # Work on a copy so the shared block_kwargs dict is never rebound.
        scaled = dict(block_kwargs)
        for key in ("width", "depth"):
            scaled[key] *= self.multipliers[lvl]
        return scaled

    def make(module_cls, lvl):
        # Level `lvl` uses the first `lvl + 1` entries of downs_t/strides_t.
        n_stages = lvl + 1
        return module_cls(x_channels, emb_width, n_stages,
                          downs_t[:n_stages], strides_t[:n_stages],
                          **scaled_kwargs(lvl))

    self.encoders = nn.ModuleList()
    self.decoders = nn.ModuleList()
    # The encoder and decoder of a level are created back to back, level
    # by level. NOTE(review): presumably this order matters for seeded
    # weight initialisation, so it is kept as is.
    for lvl in range(levels):
        self.encoders.append(make(Encoder, lvl))
        self.decoders.append(make(Decoder, lvl))

    self.bottleneck = (
        Bottleneck(l_bins, emb_width, mu, levels)
        if use_bottleneck
        else NoBottleneck(levels)
    )