in jukebox/align.py [0:0]
def get_alignment(x, zs, labels, prior, fp16, hps):
level = hps.levels - 1 # Top level used
n_ctx, n_tokens = prior.n_ctx, prior.n_tokens
z = zs[level]
bs, total_length = z.shape[0], z.shape[1]
if total_length < n_ctx:
padding_length = n_ctx - total_length
z = t.cat([z, t.zeros(bs, n_ctx - total_length, dtype=z.dtype, device=z.device)], dim=1)
total_length = z.shape[1]
else:
padding_length = 0
hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
n_head = prior.prior.transformer.n_head
alignment_head, alignment_layer = prior.alignment_head, prior.alignment_layer
attn_layers = set([alignment_layer])
alignment_hops = {}
indices_hops = {}
prior.cuda()
empty_cache()
for start in get_starts(total_length, n_ctx, hop_length):
end = start + n_ctx
# set y offset, sample_length and lyrics tokens
y, indices_hop = prior.get_y(labels, start, get_indices=True)
assert len(indices_hop) == bs
for indices in indices_hop:
assert len(indices) == n_tokens
z_bs = t.chunk(z, bs, dim=0)
y_bs = t.chunk(y, bs, dim=0)
w_hops = []
for z_i, y_i in zip(z_bs, y_bs):
w_hop = prior.z_forward(z_i[:,start:end], [], y_i, fp16=fp16, get_attn_weights=attn_layers)
assert len(w_hop) == 1
w_hops.append(w_hop[0][:, alignment_head])
del w_hop
w = t.cat(w_hops, dim=0)
del w_hops
assert_shape(w, (bs, n_ctx, n_tokens))
alignment_hop = w.float().cpu().numpy()
assert_shape(alignment_hop, (bs, n_ctx, n_tokens))
del w
# alignment_hop has shape (bs, n_ctx, n_tokens)
# indices_hop is a list of len=bs, each entry of len hps.n_tokens
indices_hops[start] = indices_hop
alignment_hops[start] = alignment_hop
prior.cpu()
empty_cache()
# Combine attn for each hop into attn for full range
# Use indices to place them into correct place for corresponding source tokens
alignments = []
for item in range(bs):
# Note each item has different length lyrics
full_tokens = labels['info'][item]['full_tokens']
alignment = np.zeros((total_length, len(full_tokens) + 1))
for start in reversed(get_starts(total_length, n_ctx, hop_length)):
end = start + n_ctx
alignment_hop = alignment_hops[start][item]
indices = indices_hops[start][item]
assert len(indices) == n_tokens
assert alignment_hop.shape == (n_ctx, n_tokens)
alignment[start:end,indices] = alignment_hop
alignment = alignment[:total_length - padding_length,:-1] # remove token padding, and last lyric index
alignments.append(alignment)
return alignments