in jukebox/prior/prior.py [0:0]
def __init__(self, z_shapes, l_bins, encoder, decoder, level,
             downs_t, strides_t, labels, prior_kwargs, x_cond_kwargs, y_cond_kwargs,
             prime_kwargs, copy_input, labels_v3=False,
             merged_decoder=False, single_enc_dec=False):
    """Build the autoregressive prior for one level of the code hierarchy.

    Args:
        z_shapes: Per-level code shapes. ``len(z_shapes)`` is the number of
            levels, and ``z_shapes[level]`` is the shape modelled here.
        l_bins: Code vocabulary size. It is passed as ``bins`` to the
            upper-level Conditioner, and is used as the prime vocabulary
            when ``copy_input`` is true.
        encoder, decoder: Stored as plain attributes. Per the original
            design, callables are passed instead of the VQ-VAE module so
            its parameters are not registered here.
        level: Index of this prior's level. Must be < ``len(z_shapes)``.
        downs_t, strides_t: Per-level downsampling settings. They are
            forwarded to Conditioner and ``calculate_strides``.
        labels: Truthy to enable label (Y) conditioning and a Labeller.
            Otherwise an EmptyLabeller is used.
        prior_kwargs: Keyword arguments for the main
            ``ConditionalAutoregressive2D``. Must contain ``'width'``.
            With ``single_enc_dec`` it must also contain
            ``'input_shape'`` and ``'bins'``, which are consumed here.
            With prime tokens in the separate mode it must contain
            ``'init_scale'``.
        x_cond_kwargs: Extra keyword arguments for ``Conditioner``.
        y_cond_kwargs: Extra keyword arguments for ``LabelConditioner``.
        prime_kwargs: Must contain ``'use_tokens'``, ``'n_tokens'`` and
            ``'prime_loss_fraction'``. These are consumed here; the rest is
            forwarded to the prime transformer in the separate mode.
        copy_input: If true, the prime vocabulary size is set to ``l_bins``.
        labels_v3: Forwarded to ``Labeller`` as ``v3``.
        merged_decoder: Forwarded to the main prior (separate mode only).
        single_enc_dec: If true, a single transformer models the
            concatenated [prime tokens, codes] sequence. Otherwise an
            encoder-only prime transformer is built alongside the prior.

    The caller's ``prior_kwargs`` / ``prime_kwargs`` dicts are not modified.
    """
    super().__init__()
    # Work on shallow copies. The pops and the 'bins' override below would
    # otherwise mutate the caller's dicts, and reusing the same dict for a
    # second prior would then fail with KeyError on the popped keys.
    prior_kwargs = dict(prior_kwargs)
    prime_kwargs = dict(prime_kwargs)

    # Validate the level before it is used as an index, so an invalid level
    # reports this message instead of a bare IndexError from z_shapes[level].
    self.z_shapes = z_shapes
    self.levels = len(self.z_shapes)
    assert level < self.levels, f"Total levels {self.levels}, got level {level}"

    # Prime (token) settings that are not transformer kwargs.
    self.use_tokens = prime_kwargs.pop('use_tokens')
    self.n_tokens = prime_kwargs.pop('n_tokens')
    self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')
    self.copy_input = copy_input
    if self.copy_input:
        prime_kwargs['bins'] = l_bins

    self.z_shape = self.z_shapes[level]
    self.level = level
    self.l_bins = l_bins
    # Passing functions instead of the vqvae module to avoid getting params
    self.encoder = encoder
    self.decoder = decoder
    # X conditioning: every level except the last is conditioned on the next one.
    self.x_cond = (level != (self.levels - 1))
    self.cond_level = level + 1
    # Y conditioning
    self.y_cond = labels
    self.single_enc_dec = single_enc_dec

    # X conditioning
    if self.x_cond:
        self.conditioner_blocks = nn.ModuleList()
        if dist.get_rank() == 0: print("Conditioning on 1 above level(s)")
        self.conditioner_blocks.append(
            Conditioner(input_shape=z_shapes[self.cond_level],
                        bins=l_bins,
                        down_t=downs_t[self.cond_level],
                        stride_t=strides_t[self.cond_level],
                        **x_cond_kwargs))

    # Y conditioning
    if self.y_cond:
        self.n_time = self.z_shape[0]  # Assuming STFT=TF order and raw=T1 order, so T is first dim
        self.y_emb = LabelConditioner(n_time=self.n_time, include_time_signal=not self.x_cond, **y_cond_kwargs)

    # Lyric conditioning
    if single_enc_dec:
        # Single encoder-decoder transformer over the concatenation of the
        # prime token sequence and the code sequence.
        self.prior_shapes = [(self.n_tokens,), prior_kwargs.pop('input_shape')]
        self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
        self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
        # Per-segment vocabulary offsets [0, prime_bins]; presumably added to
        # each segment's ids so both vocabularies share one embedding table.
        self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
        self.prior_width = prior_kwargs['width']
        print_once(f'Creating cond. autoregress with prior bins {self.prior_bins}, ')
        print_once(f'dims {self.prior_dims}, ')
        print_once(f'shift {self.prior_bins_shift}')
        print_once(f'input shape {sum(self.prior_dims)}')
        print_once(f'input bins {sum(self.prior_bins)}')
        print_once(f'Self copy is {self.copy_input}')
        self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[0], self.prior_dims[1]
        self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
        self.prior = ConditionalAutoregressive2D(input_shape=(sum(self.prior_dims),),
                                                 bins=sum(self.prior_bins),
                                                 x_cond=(self.x_cond or self.y_cond), y_cond=True,
                                                 prime_len=self.prime_loss_dims,
                                                 **prior_kwargs)
    else:
        # Separate encoder-decoder transformer
        if self.n_tokens != 0 and self.use_tokens:
            from jukebox.transformer.ops import Conv1D
            prime_input_shape = (self.n_tokens,)
            self.prime_loss_dims = np.prod(prime_input_shape)
            self.prime_acts_width, self.prime_state_width = prime_kwargs['width'], prior_kwargs['width']
            # Encoder-only transformer over the prime tokens.
            self.prime_prior = ConditionalAutoregressive2D(input_shape=prime_input_shape, x_cond=False, y_cond=False,
                                                           only_encode=True,
                                                           **prime_kwargs)
            # Project prime activations from the prime width to the prior width.
            self.prime_state_proj = Conv1D(self.prime_acts_width, self.prime_state_width, init_scale=prime_kwargs['init_scale'])
            self.prime_state_ln = LayerNorm(self.prime_state_width)
            self.prime_bins = prime_kwargs['bins']
            # Bias-free output head over the prime vocabulary.
            self.prime_x_out = nn.Linear(self.prime_state_width, self.prime_bins, bias=False)
            nn.init.normal_(self.prime_x_out.weight, std=0.02 * prior_kwargs['init_scale'])
        else:
            self.prime_loss_dims = 0
        self.gen_loss_dims = np.prod(self.z_shape)
        self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
        self.prior = ConditionalAutoregressive2D(x_cond=(self.x_cond or self.y_cond), y_cond=self.y_cond,
                                                 encoder_dims=self.prime_loss_dims, merged_decoder=merged_decoder,
                                                 **prior_kwargs)

    self.n_ctx = self.gen_loss_dims
    self.downsamples = calculate_strides(strides_t, downs_t)
    self.cond_downsample = self.downsamples[level+1] if level != self.levels - 1 else None
    # Product of the downsampling factors up to and including this level.
    self.raw_to_tokens = np.prod(self.downsamples[:level+1])
    self.sample_length = self.n_ctx*self.raw_to_tokens
    if labels:
        self.labels_v3 = labels_v3
        # self.y_emb exists here because self.y_cond is `labels`.
        self.labeller = Labeller(self.y_emb.max_bow_genre_size, self.n_tokens, self.sample_length, v3=self.labels_v3)
    else:
        self.labeller = EmptyLabeller()

    print(f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}")