# transformers/llm/export/utils/token2wav.py
import os
import math

import torch
import torch.nn.functional as F

from .model_mapper import ModelMapper
from .transformers import Rotary, Embedding, Decoder, Attention
from .spinner import spinner_run

torch.set_printoptions(precision=4, sci_mode=False)
class Token2Wav(torch.nn.Module):
    def __init__(self, token2wav, base):
super().__init__()
self.args = base.args
self.token2wav = token2wav.float()
self.config = base.config
self.llm_config = base.llm_config
self.rope_ratio = 1.0
self.quant_bit = 8
self.load()
def load(self):
raise NotImplementedError
def add_token_embeds(self, thinker_embeds):
raise NotImplementedError
def add_hidden_states(self, thinker_hidden_states):
raise NotImplementedError
def add_generate_ids(self, token_id):
raise NotImplementedError
    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values=None):
raise NotImplementedError
def export(self, onnx_path):
raise NotImplementedError
class UpSample1d(torch.nn.Module):
def __init__(self, upsample, channel):
super().__init__()
self.ratio = upsample.ratio
self.stride = upsample.stride
self.pad = upsample.pad
self.pad_left = upsample.pad_left
self.pad_right = upsample.pad_right
        # expand the shared lowpass filter to one copy per channel (for the grouped conv)
        self.filter = upsample.filter.expand(channel, -1, -1).clone()
self.channel = channel
def forward(self, x):
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(x, self.filter, stride=self.stride, groups=self.channel)
x = x[..., self.pad_left : -self.pad_right]
return x
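# A minimal shape sketch for UpSample1d (hypothetical parameters in the style
# of BigVGAN's anti-aliased activations: kernel_size=12, ratio=stride=2; the
# real filter and padding come from the wrapped bigvgan module).
def _demo_upsample_shapes():
    from types import SimpleNamespace
    kernel, ratio = 12, 2
    pad = kernel // ratio - ratio                          # 4
    pad_left = pad * ratio + (kernel - ratio) // 2         # 13
    pad_right = pad * ratio + (kernel - ratio + 1) // 2    # 13
    stub = SimpleNamespace(ratio=ratio, stride=ratio, pad=pad,
                           pad_left=pad_left, pad_right=pad_right,
                           filter=torch.randn(1, 1, kernel))
    up = UpSample1d(stub, channel=4)
    x = torch.randn(1, 4, 100)
    # replicate-pad, grouped transposed conv, then trim: length becomes ratio * L
    assert up(x).shape == (1, 4, 200)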
class DownSample1d(torch.nn.Module):
def __init__(self, downsample, channel):
super().__init__()
self.pad_left = downsample.pad_left
self.pad_right = downsample.pad_right
self.stride = downsample.stride
self.filter = downsample.filter.expand(channel, -1, -1).clone()
self.channel = channel
def forward(self, x):
x = F.pad(x, (self.pad_left, self.pad_right), mode="replicate")
out = F.conv1d(x, self.filter, stride=self.stride, groups=self.channel)
return out
class TorchActivation1d(torch.nn.Module):
    def __init__(self, activation):
super().__init__()
self.act = activation.act
channel = self.act.in_features
self.upsample = UpSample1d(activation.upsample, channel)
self.downsample = DownSample1d(activation.downsample, channel)
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# DiT model code
class ECAPA_TDNN(torch.nn.Module):
def __init__(self, spk_encoder):
super().__init__()
self.blocks = spk_encoder.blocks
self.mfa = spk_encoder.mfa
self.asp = spk_encoder.asp
self.fc = spk_encoder.fc
def forward(self, x):
# Minimize transpose for efficiency
x = x.transpose(1, 2)
xl = []
for layer in self.blocks:
x = layer(x)
xl.append(x)
# Multi-layer feature aggregation
x = torch.cat(xl[1:], dim=1)
x = self.mfa(x)
# Attentive Statistical Pooling
x = self.asp(x)
# Final linear transformation
x = self.fc(x)
        # x = x.squeeze(-1)  # avoid an If node when exporting to ONNX
x = x.permute(0, 2, 1)
return x
class DitRotary(Rotary):
def __init__(self):
super().__init__(None)
self.model_type = 'dit'
self.rope_theta = 10000
self.rotary_dim = 64
self.theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))
def forward(self, position_ids):
position_ids = position_ids.float().reshape(-1, 1)
idx_theta = position_ids * self.theta
rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
rotary_pos_emb = torch.stack((rotary_pos_emb, rotary_pos_emb), dim=-1)
rotary_pos_emb = rotary_pos_emb.reshape(*rotary_pos_emb.shape[:-2], -1)
rotary_pos_emb = rotary_pos_emb.unsqueeze(2).unsqueeze(1)
return rotary_pos_emb
@staticmethod
def apply_rotary_pos(x, cos, sin):
def rotate_half(x):
x = x.reshape(*x.shape[:-1], -1, 2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return x.reshape(*x.shape[:-2], -1)
x = (x * cos) + (rotate_half(x) * sin)
return x
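# A minimal sanity sketch for DitRotary (illustrative only): forward stacks
# cos and sin along dim 0 into shape [2, 1, T, 1, 64], which is how
# DiTAttention below indexes rope[0] and rope[1].
def _demo_dit_rotary():
    rope = DitRotary()(torch.arange(16))
    assert rope.shape == (2, 1, 16, 1, 64)
    q = torch.randn(1, 16, 1, 64)  # one head, matching rotary_dim=64
    q_rot = DitRotary.apply_rotary_pos(q, rope[0], rope[1])
    assert q_rot.shape == q.shape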
class DiTAttention(torch.nn.Module):
def __init__(self, attn):
super().__init__()
self.dim = attn.dim
self.heads = attn.heads
self.inner_dim = attn.inner_dim
self.to_q = attn.to_q
self.to_k = attn.to_k
self.to_v = attn.to_v
self.to_out = attn.to_out
def forward(
self,
x,
rope=None,
mask=None,
) -> torch.Tensor:
batch_size = x.shape[0]
# `sample` projections.
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim)
key = key.view(batch_size, -1, self.heads, head_dim)
value = value.view(batch_size, -1, self.heads, head_dim)
        # apply rotary position embedding
        # due to the training process, RoPE is applied to the first head only; this will be fixed in the next release
cos, sin = rope[0], rope[1]
first_query = query[:, :, :1, :]
first_key = key[:, :, :1, :]
other_query = query[:, :, 1:, :]
other_key = key[:, :, 1:, :]
first_query = DitRotary.apply_rotary_pos(first_query, cos, sin)
first_key = DitRotary.apply_rotary_pos(first_key, cos, sin)
query = torch.concat([first_query, other_query], dim=2)
key = torch.concat([first_key, other_key], dim=2)
        # turn the boolean keep-mask (e.g. padding when a batch mixes target
        # durations) into an additive bias: 0 where kept, -inf where masked
        attention_mask = (~mask) * torch.finfo(torch.float32).min
query = query.transpose(1, 2)
key = key.permute([0, 2, 3, 1])
value = value.transpose(1, 2)
attn_weights = torch.matmul(query, key) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)
        x = attn_output.transpose(1, 2)
        x = x.reshape(batch_size, -1, self.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = self.to_out[0](x)
# dropout
x = self.to_out[1](x)
return x
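# A minimal shape sketch for DiTAttention (hypothetical dimensions; the real
# projections come from the pretrained DiT). It shows the expected inputs: a
# boolean keep-mask broadcastable to [B, heads, T, T] and a rope of shape
# [2, 1, T, 1, head_dim].
def _demo_dit_attention():
    from types import SimpleNamespace
    dim, heads, head_dim = 1024, 16, 64
    inner = heads * head_dim
    stub = SimpleNamespace(
        dim=dim, heads=heads, inner_dim=inner,
        to_q=torch.nn.Linear(dim, inner),
        to_k=torch.nn.Linear(dim, inner),
        to_v=torch.nn.Linear(dim, inner),
        to_out=torch.nn.ModuleList([torch.nn.Linear(inner, dim), torch.nn.Dropout(0.0)]))
    attn = DiTAttention(stub)
    x = torch.randn(1, 48, dim)
    rope = DitRotary()(torch.arange(48))
    mask = torch.ones(1, 1, 48, 48, dtype=torch.bool)  # keep everything
    assert attn(x, rope=rope, mask=mask).shape == x.shape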
class DiTBlock(torch.nn.Module):
def __init__(self, block):
super().__init__()
self.attn_norm = block.attn_norm
self.attn = DiTAttention(block.attn)
self.attn_ = block.attn
self.look_ahead_block = block.look_ahead_block
self.look_backward_block = block.look_backward_block
self.ff_norm = block.ff_norm
self.ff = block.ff
def forward(self, x, t, rope=None, block_diff=None):
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
attn_output = self.attn(
x=norm,
rope=rope,
mask=(block_diff >= -float(self.look_backward_block)) & (block_diff <= float(self.look_ahead_block)),
)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
class DitPreprocess(torch.nn.Module):
def __init__(self, dit):
super().__init__()
self.code_embed = dit.code_embed
self.input_proj = dit.proj_in_other
self.rotary_embed = DitRotary()
self.block_size = 24
    def forward(self, cond, spk, code):
        max_duration = code.shape[1] * 2  # each code token yields two mel frames
        spk = spk.repeat(1, max_duration, 1)
        cond = cond.repeat(1, max_duration, 1)
        code_embed = self.code_embed(code)
        input_embeds = torch.cat((cond, code_embed, spk), dim=-1)
        code_embeds = self.input_proj(input_embeds)
        position_ids = torch.arange(max_duration)
        rope = self.rotary_embed(position_ids)
        # block-local banding: whether frame i may attend to frame j depends on
        # how many blocks apart they are (thresholded in DiTBlock.forward)
        block_indices = position_ids // self.block_size
        block_i = block_indices.unsqueeze(1)
        block_j = block_indices.unsqueeze(0)
        block_diff = block_j - block_i
        mask = block_diff.reshape(1, 1, max_duration, max_duration)
        return code_embeds, rope, mask
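# A minimal sketch of the block-local mask built by DitPreprocess: positions
# are grouped into blocks of block_size frames, and DiTBlock keeps entries
# with -look_backward <= block_diff <= look_ahead (the thresholds of 1 here
# are illustrative; the real values live on each DiT block).
def _demo_block_mask():
    pos = torch.arange(96)
    blocks = pos // 24
    block_diff = blocks.unsqueeze(0) - blocks.unsqueeze(1)  # block_j - block_i
    keep = (block_diff >= -1.0) & (block_diff <= 1.0)
    assert keep[0, 47] and not keep[0, 48]  # frame 0 sees blocks 0 and 1 only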
class DitWrapper(torch.nn.Module):
def __init__(self, dit):
super().__init__()
self.dit = dit
self.cfg = False
self.time_embed = dit.time_embed
self.code_embed = dit.text_embed
self.rotary_embed = DitRotary()
self.transformer_blocks = torch.nn.ModuleList()
for i in range(len(dit.transformer_blocks)):
self.transformer_blocks.append(DiTBlock(dit.transformer_blocks[i]))
self._create_block_diff = dit._create_block_diff
self.norm_out = dit.norm_out
self.proj_out = dit.proj_out
        # split input_embed.proj into two linears so that x (the 80-channel mel
        # noise) and the remaining conditioning features are projected separately
        proj_in = dit.input_embed.proj
        oc, ic = proj_in.weight.shape
        x_ic = 80
        other_ic = ic - x_ic
        self.proj_in_x = torch.nn.Linear(x_ic, oc)
        self.proj_in_x.weight.data = proj_in.weight[:, :x_ic]
        self.proj_in_x.bias = None
        self.proj_in_other = torch.nn.Linear(other_ic, oc)
        self.proj_in_other.weight.data = proj_in.weight[:, x_ic:]
        self.proj_in_other.bias = proj_in.bias
self.spk_encoder = ECAPA_TDNN(dit.input_embed.spk_encoder)
self.preprocess = DitPreprocess(self)
def spk_encode(self, spk):
return self.spk_encoder(spk)
def forward(self, x, code_embeds, rope, mask, time):
t = self.time_embed(time)
hidden = self.proj_in_x(x) + code_embeds
for block in self.transformer_blocks:
hidden = block(hidden, t, rope=rope, block_diff=mask)
hidden = self.norm_out(hidden, t)
output = self.proj_out(hidden)
return output
# end of DiT model code
class Qwen2_5OmniToken2Wav(Token2Wav):
def __init__(self, token2wav, base):
super().__init__(token2wav, base)
def load(self):
self.dit = self.token2wav.code2wav_dit_model
self.bigvgan = self.token2wav.code2wav_bigvgan_model
        # rewrite some submodules so they can be exported
        self.dit = DitWrapper(self.dit)
        # the up/downsample filters inside bigvgan.resblocks[*].activations are
        # expanded to the input channel count (see UpSample1d / DownSample1d)
for i in range(len(self.bigvgan.resblocks)):
for j in range(len(self.bigvgan.resblocks[i].activations)):
old_act = self.bigvgan.resblocks[i].activations[j]
self.bigvgan.resblocks[i].activations[j] = TorchActivation1d(old_act)
self.bigvgan.activation_post = TorchActivation1d(self.bigvgan.activation_post)
        # load the speaker dictionary; remap the raw 'cond' entry to 'spk' and
        # encode 'ref_mel' with the ECAPA-TDNN speaker encoder as the new 'cond'
        path = os.path.join(self.args.path, 'spk_dict.pt')
        self.speaker_map = {}
        for key, value in torch.load(path).items():
            spk = value['cond'].float()
            cond = value['ref_mel'].float()
            value.pop('ref_mel', None)
            value['spk'] = spk.unsqueeze(1)
            value['cond'] = self.dit.spk_encode(cond)
            self.speaker_map[key] = value
        spk = 'Chelsie'  # default speaker
        self.speaker_params = self.speaker_map[spk]
    def dit_forward(self, code, initial_noise=None):
spk = self.speaker_params["spk"].float()
cond = self.speaker_params["cond"].float()
max_duration = code.shape[1] * 2
code_embeds, rope, mask = self.dit.preprocess(cond, spk, code)
def func(t, x):
pred = self.dit(x=x, code_embeds=code_embeds, rope=rope, mask=mask, time=torch.tensor([t]))
return pred
        steps = 5
        # cosine-warped time grid: step sizes are smaller near t = 0
        t = torch.linspace(0, 1, steps, dtype=cond.dtype)
        t = 1 - torch.cos(torch.pi / 2 * t)
        if initial_noise is None:
            torch.manual_seed(42)
            y0 = torch.randn([1, max_duration, 80], dtype=cond.dtype)
        else:
            y0 = initial_noise.clone()
        # integrate the flow ODE with a four-stage Runge-Kutta update
        # using 3/8-rule weights: (k1 + 3*k2 + 3*k3 + k4) / 8
        for t0, t1 in zip(t[:-1], t[1:]):
            dt = t1 - t0
            k1 = func(t0, y0)
            k2 = func(t0 + dt * 1/3, y0 + dt * k1 * 1/3)
            k3 = func(t0 + dt * 2/3, y0 + dt * (k2 - k1 * 2/3))
            k4 = func(t1, y0 + dt * (k1 - k2 + k3))
            dy = (k1 + 3 * (k2 + k3) + k4) * dt * 0.125
            y0 += dy
generated_mel = y0.permute(0, 2, 1)
return generated_mel
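    @staticmethod
    def _demo_time_grid(steps=5):
        # A small illustrative helper (not used by the pipeline): the cosine
        # warp t -> 1 - cos(pi/2 * t) applied in dit_forward concentrates the
        # integration steps near t = 0, where the flow is noisiest.
        t = torch.linspace(0, 1, steps)
        return 1 - torch.cos(torch.pi / 2 * t)  # steps=5: [0.0000, 0.0761, 0.2929, 0.6173, 1.0000]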
@torch.no_grad()
def generate(self, code):
generated_mel = self.dit_forward(code)
waveform = self.bigvgan(generated_mel)
return waveform
@torch.no_grad()
def generate_stream(self, code):
        # Define DiT streaming parameters
dit_chunk_size = 48
dit_left_context = 24
dit_right_context = 12
dit_left_padding = 0
dit_right_padding = dit_right_context
dit_start_index = 0
dit_mel_len = 0
# Define vocoder streaming parameters
vocoder_left_context = 10
vocoder_right_context = 10
vocoder_left_pad = 0
vocoder_right_pad = vocoder_right_context
vocoder_upsample_rate = 240
torch.manual_seed(42)
initial_noise = torch.randn([1, 30000, 80], dtype=torch.float32)
code_buffer = torch.full((1, 0), 0, dtype=torch.long, device=code.device)
mel_buffer = torch.full((1, 80, 0), 0, dtype=torch.float32, device=code.device)
waveform_buffer = torch.full((0,), 0, dtype=torch.float32)
for next_code in code[0]:
code_buffer = torch.cat([code_buffer, next_code.reshape(1, 1)], dim=1)
if code_buffer.size(1) == dit_left_padding + dit_chunk_size + dit_right_padding:
                # dit: decode the chunk plus its context, then trim the context frames
generated_mel = self.dit_forward(code_buffer, initial_noise[:, dit_start_index: dit_start_index + code_buffer.size(1) * 2])
generated_mel = generated_mel[:, :, dit_left_padding * 2: -dit_right_padding * 2]
dit_left_padding = dit_left_context
code_buffer = code_buffer[:, -(dit_left_padding + dit_right_padding):]
dit_mel_len += generated_mel.size(-1)
dit_start_index = dit_mel_len - dit_left_context * 2
                # bigvgan: vocode the accumulated mel and trim the vocoder context
mel_buffer = torch.cat([mel_buffer, generated_mel], dim=-1)
waveform = self.bigvgan(mel_buffer)
waveform = waveform[vocoder_left_pad * vocoder_upsample_rate: -vocoder_right_pad * vocoder_upsample_rate]
waveform_buffer = torch.cat([waveform_buffer, waveform], dim=-1)
vocoder_left_pad = vocoder_left_context
mel_buffer = mel_buffer[:, :, -(vocoder_left_pad + vocoder_right_pad):]
        # flush whatever is left in the code buffer
        if code_buffer.size(1) > 0:
generated_mel = self.dit_forward(code_buffer, initial_noise[:, dit_start_index: dit_start_index + code_buffer.size(1) * 2])
generated_mel = generated_mel[:, :, dit_left_padding * 2:]
mel_buffer = torch.cat([mel_buffer, generated_mel], dim=-1)
waveform = self.bigvgan(mel_buffer)
waveform = waveform[vocoder_left_pad * vocoder_upsample_rate:]
waveform_buffer = torch.cat([waveform_buffer, waveform], dim=-1)
return waveform_buffer
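    @torch.no_grad()
    def _demo_streaming_consistency(self, code):
        # A rough consistency sketch (hypothetical, not called by the exporter):
        # since generate and generate_stream share the same noise seed, the
        # streamed waveform should closely track the offline one, though chunk
        # boundaries need not match bit-exactly.
        return self.generate(code), self.generate_stream(code)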
def export_spk(self):
import MNN.expr as expr
def torch_to_mnn(x):
return expr.const(x.data_ptr(), x.shape)
var_list = []
for key, value in self.speaker_map.items():
for k, v in value.items():
                if not isinstance(v, torch.Tensor):
                    v = torch.tensor(v)
mnn_var = torch_to_mnn(v.contiguous().float())
mnn_var.name = f'{key}_{k}'
var_list.append(mnn_var)
expr.save(var_list, f'{self.args.dst_path}/spk_dict.mnn')
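    def _demo_load_spk(self):
        # A minimal load-back sketch (assuming MNN's expr.load_as_dict API is
        # available); keys follow the f'{speaker}_{field}' naming used above,
        # e.g. 'Chelsie_spk' and 'Chelsie_cond'.
        import MNN.expr as expr
        return expr.load_as_dict(f'{self.args.dst_path}/spk_dict.mnn')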
    @spinner_run('export token2wav.predit to ')
def export_predit(self, onnx_path):
cond = torch.randn([1, 1, 128], dtype=torch.float32)
spk = torch.randn([1, 1, 192], dtype=torch.float32)
code = torch.ones([1, 256], dtype=torch.int32)
onnx_model = f'{onnx_path}/predit.onnx'
torch.onnx.export(self.dit.preprocess, (cond, spk, code),
onnx_model,
input_names=['cond', 'spk', 'code'],
output_names=['code_embeds', 'rope', 'mask'],
dynamic_axes={
"code": { 1: "size" },
},
do_constant_folding=True,
verbose=False,
opset_version=15)
return onnx_model
    @spinner_run('export token2wav.dit to ')
def export_dit(self, onnx_path):
x = torch.randn([1, 512, 80], dtype=torch.float32)
code_embeds = torch.randn([1, 512, 1024], dtype=torch.float32)
rope = torch.randn([2, 1, 512, 1, 64], dtype=torch.float32)
mask = torch.ones([1, 1, 512, 512], dtype=torch.int32)
time = torch.tensor([0.0])
onnx_model = f'{onnx_path}/dit.onnx'
torch.onnx.export(self.dit, (x, code_embeds, rope, mask, time),
onnx_model,
input_names=['x', 'code_embeds', 'rope', 'mask', 'time'],
output_names=['mel'],
dynamic_axes={
"x": { 1: "size" },
"code_embeds": { 1: "size" },
"rope": { 2: "size" },
"mask": { 2: "size", 3: "size" },
},
do_constant_folding=True,
verbose=False,
opset_version=15)
return onnx_model
    @spinner_run('export token2wav.bigvgan to ')
def export_bigvgan(self, onnx_path):
generated_mel = torch.randn([1, 80, 512], dtype=torch.float32)
onnx_model = f'{onnx_path}/bigvgan.onnx'
        torch.onnx.export(self.bigvgan, (generated_mel,),
onnx_model,
input_names=['generated_mel'],
output_names=['waveform'],
dynamic_axes={
"generated_mel": { 2: "size" },
},
do_constant_folding=True,
verbose=False,
opset_version=15)
return onnx_model
def export(self, onnx_path):
self.export_spk()
predit = self.export_predit(onnx_path)
dit = self.export_dit(onnx_path)
bigvgan = self.export_bigvgan(onnx_path)
return predit, dit, bigvgan
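# A minimal verification sketch for the exported DiT graph (assumes the
# onnxruntime package is installed; shapes mirror the dummy inputs used in
# export_dit above, and the expected output shape is an assumption).
def _demo_check_dit_onnx(onnx_path):
    import numpy as np
    import onnxruntime as ort
    sess = ort.InferenceSession(f'{onnx_path}/dit.onnx',
                                providers=['CPUExecutionProvider'])
    feeds = {
        'x': np.random.randn(1, 512, 80).astype(np.float32),
        'code_embeds': np.random.randn(1, 512, 1024).astype(np.float32),
        'rope': np.random.randn(2, 1, 512, 1, 64).astype(np.float32),
        'mask': np.ones((1, 1, 512, 512), dtype=np.int32),
        'time': np.array([0.0], dtype=np.float32),
    }
    mel, = sess.run(['mel'], feeds)
    assert mel.shape == (1, 512, 80)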