transformers/llm/export/utils/token2wav.py

import os
import math

import torch
import torch.nn.functional as F

torch.set_printoptions(precision=4, sci_mode=False)

from .model_mapper import ModelMapper
from .transformers import Rotary, Embedding, Decoder, Attention
from .spinner import spinner_run


class Token2Wav(torch.nn.Module):
    def __init__(self, token2wav, base):
        super().__init__()
        self.args = base.args
        self.token2wav = token2wav.float()
        self.config = base.config
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        self.quant_bit = 8
        self.load()

    def load(self):
        raise NotImplementedError

    def add_token_embeds(self, thinker_embeds):
        raise NotImplementedError

    def add_hidden_states(self, thinker_hidden_states):
        raise NotImplementedError

    def add_generate_ids(self, token_id):
        raise NotImplementedError

    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values=None):
        raise NotImplementedError

    def export(self, onnx_path):
        raise NotImplementedError


class UpSample1d(torch.nn.Module):
    def __init__(self, upsample, channel):
        super().__init__()
        self.ratio = upsample.ratio
        self.stride = upsample.stride
        self.pad = upsample.pad
        self.pad_left = upsample.pad_left
        self.pad_right = upsample.pad_right
        self.filter = upsample.filter.expand(channel, -1, -1).clone()
        self.channel = channel

    def forward(self, x):
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(x, self.filter, stride=self.stride, groups=self.channel)
        x = x[..., self.pad_left : -self.pad_right]
        return x


class DownSample1d(torch.nn.Module):
    def __init__(self, downsample, channel):
        super().__init__()
        self.pad_left = downsample.pad_left
        self.pad_right = downsample.pad_right
        self.stride = downsample.stride
        self.filter = downsample.filter.expand(channel, -1, -1).clone()
        self.channel = channel

    def forward(self, x):
        x = F.pad(x, (self.pad_left, self.pad_right), mode="replicate")
        out = F.conv1d(x, self.filter, stride=self.stride, groups=self.channel)
        return out


class TorchActivation1d(torch.nn.Module):
    def __init__(self, activation):
        super().__init__()
        self.act = activation.act
        channel = self.act.in_features
        self.upsample = UpSample1d(activation.upsample, channel)
        self.downsample = DownSample1d(activation.downsample, channel)

    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x


# DiT model code
class ECAPA_TDNN(torch.nn.Module):
    def __init__(self, spk_encoder):
        super().__init__()
        self.blocks = spk_encoder.blocks
        self.mfa = spk_encoder.mfa
        self.asp = spk_encoder.asp
        self.fc = spk_encoder.fc

    def forward(self, x):
        # Minimize transposes for efficiency
        x = x.transpose(1, 2)
        xl = []
        for layer in self.blocks:
            x = layer(x)
            xl.append(x)
        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)
        # Attentive Statistical Pooling
        x = self.asp(x)
        # Final linear transformation
        x = self.fc(x)
        # x = x.squeeze(-1)  # avoided: squeeze adds an If node when exporting to ONNX
        x = x.permute(0, 2, 1)
        return x


class DitRotary(Rotary):
    def __init__(self):
        super().__init__(None)
        self.model_type = 'dit'
        self.rope_theta = 10000
        self.rotary_dim = 64
        self.theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))

    def forward(self, position_ids):
        position_ids = position_ids.float().reshape(-1, 1)
        idx_theta = position_ids * self.theta
        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
        rotary_pos_emb = torch.stack((rotary_pos_emb, rotary_pos_emb), dim=-1)
        rotary_pos_emb = rotary_pos_emb.reshape(*rotary_pos_emb.shape[:-2], -1)
        rotary_pos_emb = rotary_pos_emb.unsqueeze(2).unsqueeze(1)
        return rotary_pos_emb

    @staticmethod
    def apply_rotary_pos(x, cos, sin):
        def rotate_half(x):
            x = x.reshape(*x.shape[:-1], -1, 2)
            x1, x2 = x.unbind(dim=-1)
            x = torch.stack((-x2, x1), dim=-1)
            return x.reshape(*x.shape[:-2], -1)
        x = (x * cos) + (rotate_half(x) * sin)
        return x
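

# Hedged sketch (hypothetical helper, not used by the export path): rebuilds the rope
# tensor layout produced by DitRotary.forward so its shape is easy to verify. The
# result stacks cos/sin on dim 0, with every angle duplicated pairwise to match the
# interleaved rotate_half convention: [2, 1, seq, 1, rotary_dim].
def _demo_dit_rope_shape(seq_len=8, rotary_dim=64, rope_theta=10000):
    theta = 1.0 / (rope_theta ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    idx_theta = torch.arange(seq_len, dtype=torch.float32).reshape(-1, 1) * theta  # [seq, rotary_dim // 2]
    rope = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])               # [2, seq, rotary_dim // 2]
    rope = torch.stack((rope, rope), dim=-1).reshape(2, seq_len, -1)               # [2, seq, rotary_dim]
    return rope.unsqueeze(2).unsqueeze(1)                                          # [2, 1, seq, 1, rotary_dim]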


class DiTAttention(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.dim = attn.dim
        self.heads = attn.heads
        self.inner_dim = attn.inner_dim
        self.to_q = attn.to_q
        self.to_k = attn.to_k
        self.to_v = attn.to_v
        self.to_out = attn.to_out

    def forward(self, x, rope=None, mask=None) -> torch.Tensor:
        batch_size = x.shape[0]
        # `sample` projections
        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)
        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim)
        key = key.view(batch_size, -1, self.heads, head_dim)
        value = value.view(batch_size, -1, self.heads, head_dim)
        # apply rotary position embedding
        # Due to the training process, RoPE is applied only to the first head; this will be fixed in the next release.
        cos, sin = rope[0], rope[1]
        first_query = query[:, :, :1, :]
        first_key = key[:, :, :1, :]
        other_query = query[:, :, 1:, :]
        other_key = key[:, :, 1:, :]
        first_query = DitRotary.apply_rotary_pos(first_query, cos, sin)
        first_key = DitRotary.apply_rotary_pos(first_key, cos, sin)
        query = torch.concat([first_query, other_query], dim=2)
        key = torch.concat([first_key, other_key], dim=2)
        # mask, e.g. when inference gets a batch with different target durations, mask out the padding
        attention_mask = (~mask) * torch.finfo(torch.float32).min
        query = query.transpose(1, 2)
        key = key.permute([0, 2, 3, 1])
        value = value.transpose(1, 2)
        attn_weights = torch.matmul(query, key) / math.sqrt(head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_output = torch.matmul(attn_weights, value)
        x = attn_output.transpose(1, 2)
        x = x.reshape(batch_size, -1, self.heads * head_dim)
        x = x.to(query.dtype)
        # linear proj
        x = self.to_out[0](x)
        # dropout
        x = self.to_out[1](x)
        return x
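

# Hedged sketch (hypothetical helper, not used by the export path): shows how the
# boolean visibility mask in DiTAttention.forward becomes an additive bias. Visible
# (True) positions contribute 0 to the logits; masked (False) positions get a huge
# negative value, so their softmax weight collapses to ~0.
def _demo_additive_mask():
    visible = torch.tensor([[True, False], [True, True]])
    bias = (~visible) * torch.finfo(torch.float32).min
    # row 0 attends only to column 0 -> weights [1, 0]; row 1 -> [0.5, 0.5]
    return torch.softmax(torch.zeros(2, 2) + bias, dim=-1)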


class DiTBlock(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.attn_norm = block.attn_norm
        self.attn = DiTAttention(block.attn)
        self.attn_ = block.attn
        self.look_ahead_block = block.look_ahead_block
        self.look_backward_block = block.look_backward_block
        self.ff_norm = block.ff_norm
        self.ff = block.ff

    def forward(self, x, t, rope=None, block_diff=None):
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
        attn_output = self.attn(
            x=norm,
            rope=rope,
            mask=(block_diff >= -float(self.look_backward_block)) & (block_diff <= float(self.look_ahead_block)),
        )
        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output
        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output
        return x


class DitPreprocess(torch.nn.Module):
    def __init__(self, dit):
        super().__init__()
        self.code_embed = dit.code_embed
        self.input_proj = dit.proj_in_other
        self.rotary_embed = DitRotary()
        self.block_size = 24

    def forward(self, cond, spk, code):
        max_duration = code.shape[1] * 2
        spk = spk.repeat(1, max_duration, 1)
        cond = cond.repeat(1, max_duration, 1)
        code_embed = self.code_embed(code)
        input_embeds = torch.cat((cond, code_embed, spk), dim=-1)
        code_embeds = self.input_proj(input_embeds)
        position_ids = torch.arange(max_duration)
        rope = self.rotary_embed(position_ids)
        # block indices for the block-local attention window
        block_indices = position_ids // self.block_size
        block_i = block_indices.unsqueeze(1)
        block_j = block_indices.unsqueeze(0)
        block_diff = block_j - block_i
        mask = block_diff.reshape(1, 1, max_duration, max_duration)
        return code_embeds, rope, mask


class DitWrapper(torch.nn.Module):
    def __init__(self, dit):
        super().__init__()
        self.dit = dit
        self.cfg = False
        self.time_embed = dit.time_embed
        self.code_embed = dit.text_embed
        self.rotary_embed = DitRotary()
        self.transformer_blocks = torch.nn.ModuleList()
        for i in range(len(dit.transformer_blocks)):
            self.transformer_blocks.append(DiTBlock(dit.transformer_blocks[i]))
        self._create_block_diff = dit._create_block_diff
        self.norm_out = dit.norm_out
        self.proj_out = dit.proj_out
        # split the input projection: the first 80 input channels take the noisy mel `x`,
        # the remaining channels take the conditioning embeddings
        proj_in = dit.input_embed.proj
        oc, ic = proj_in.weight.shape
        x_ic = 80
        other_ic = ic - x_ic
        self.proj_in_x = torch.nn.Linear(x_ic, oc)
        self.proj_in_x.weight.data = proj_in.weight[:, :x_ic]
        self.proj_in_x.bias = None
        self.proj_in_other = torch.nn.Linear(other_ic, oc)
        self.proj_in_other.weight.data = proj_in.weight[:, x_ic:]
        self.proj_in_other.bias = proj_in.bias
        self.spk_encoder = ECAPA_TDNN(dit.input_embed.spk_encoder)
        self.preprocess = DitPreprocess(self)

    def spk_encode(self, spk):
        return self.spk_encoder(spk)

    def forward(self, x, code_embeds, rope, mask, time):
        t = self.time_embed(time)
        hidden = self.proj_in_x(x) + code_embeds
        for block in self.transformer_blocks:
            hidden = block(hidden, t, rope=rope, block_diff=mask)
        hidden = self.norm_out(hidden, t)
        output = self.proj_out(hidden)
        return output
# end of DiT model code
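

# Hedged sketch (hypothetical helper, not used by the export path): the block-local
# attention window assembled from DitPreprocess's block_diff and DiTBlock's limits.
# Position i may attend to position j iff block[j] - block[i] lies in
# [-look_backward_block, look_ahead_block].
def _demo_block_diff_mask(seq_len=6, block_size=2, look_backward=1, look_ahead=0):
    block = torch.arange(seq_len) // block_size
    block_diff = block.unsqueeze(0) - block.unsqueeze(1)  # block_j - block_i
    return (block_diff >= -look_backward) & (block_diff <= look_ahead)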


class Qwen2_5OmniToken2Wav(Token2Wav):
    def __init__(self, token2wav, base):
        super().__init__(token2wav, base)

    def load(self):
        self.dit = self.token2wav.code2wav_dit_model
        self.bigvgan = self.token2wav.code2wav_bigvgan_model
        # some code changes for export
        self.dit = DitWrapper(self.dit)
        # the up/downsample inside each bigvgan activation holds conv filters expanded
        # per input channel; wrap them so they export cleanly
        for i in range(len(self.bigvgan.resblocks)):
            for j in range(len(self.bigvgan.resblocks[i].activations)):
                old_act = self.bigvgan.resblocks[i].activations[j]
                self.bigvgan.resblocks[i].activations[j] = TorchActivation1d(old_act)
        self.bigvgan.activation_post = TorchActivation1d(self.bigvgan.activation_post)
        # spk
        path = os.path.join(self.args.path, 'spk_dict.pt')
        self.speaker_map = {}
        for key, value in torch.load(path).items():
            spk = value["cond"].float()
            cond = value['ref_mel'].float()
            value.pop("ref_mel", None)
            value['spk'] = spk.unsqueeze(1)
            value['cond'] = self.dit.spk_encode(cond)
            self.speaker_map[key] = value
        spk = "Chelsie"
        self.speaker_params = self.speaker_map[spk]

    def dit_forward(self, code, initial_noise=None):
        spk = self.speaker_params["spk"].float()
        cond = self.speaker_params["cond"].float()
        max_duration = code.shape[1] * 2
        code_embeds, rope, mask = self.dit.preprocess(cond, spk, code)

        def func(t, x):
            pred = self.dit(x=x, code_embeds=code_embeds, rope=rope, mask=mask, time=torch.tensor([t]))
            return pred

        steps = 5
        t = torch.linspace(0, 1, steps, dtype=cond.dtype)
        # cosine time schedule: step sizes are finer near t = 0
        t = 1 - torch.cos(torch.pi / 2 * t)
        if initial_noise is None:
            torch.manual_seed(42)
            y0 = torch.randn([1, max_duration, 80], dtype=cond.dtype)
        else:
            y0 = initial_noise.clone()
        # fixed-step RK4-style integration with 3/8-rule weights (k1 + 3*k2 + 3*k3 + k4) / 8
        for t0, t1 in zip(t[:-1], t[1:]):
            dt = t1 - t0
            k1 = func(t0, y0)
            k2 = func(t0 + dt * 1/3, y0 + dt * k1 * 1/3)
            k3 = func(t0 + dt * 2/3, y0 + dt * (k2 - k1 * 2/3))
            k4 = func(t1, y0 + dt * (k1 - k2 + k3))
            dy = (k1 + 3 * (k2 + k3) + k4) * dt * 0.125
            y0 += dy
        generated_mel = y0.permute(0, 2, 1)
        # print('generated_mel = ', generated_mel, generated_mel.shape)
        return generated_mel

    @torch.no_grad()
    def generate(self, code):
        generated_mel = self.dit_forward(code)
        waveform = self.bigvgan(generated_mel)
        return waveform

    @torch.no_grad()
    def generate_stream(self, code):
        # Define dit streaming parameters
        dit_chunk_size = 48
        dit_left_context = 24
        dit_right_context = 12
        dit_left_padding = 0
        dit_right_padding = dit_right_context
        dit_start_index = 0
        dit_mel_len = 0
        # Define vocoder streaming parameters
        vocoder_left_context = 10
        vocoder_right_context = 10
        vocoder_left_pad = 0
        vocoder_right_pad = vocoder_right_context
        vocoder_upsample_rate = 240
        torch.manual_seed(42)
        initial_noise = torch.randn([1, 30000, 80], dtype=torch.float32)
        code_buffer = torch.full((1, 0), 0, dtype=torch.long, device=code.device)
        mel_buffer = torch.full((1, 80, 0), 0, dtype=torch.float32, device=code.device)
        waveform_buffer = torch.full((0,), 0, dtype=torch.float32)
        for next_code in code[0]:
            code_buffer = torch.cat([code_buffer, next_code.reshape(1, 1)], dim=1)
            if code_buffer.size(1) == dit_left_padding + dit_chunk_size + dit_right_padding:
                # dit: generate mel for the current chunk, then drop the context padding
                generated_mel = self.dit_forward(code_buffer, initial_noise[:, dit_start_index : dit_start_index + code_buffer.size(1) * 2])
                generated_mel = generated_mel[:, :, dit_left_padding * 2 : -dit_right_padding * 2]
                dit_left_padding = dit_left_context
                code_buffer = code_buffer[:, -(dit_left_padding + dit_right_padding):]
                dit_mel_len += generated_mel.size(-1)
                dit_start_index = dit_mel_len - dit_left_context * 2
                # bigvgan: vocode the buffered mel, then drop the context padding
                mel_buffer = torch.cat([mel_buffer, generated_mel], dim=-1)
                waveform = self.bigvgan(mel_buffer)
                waveform = waveform[vocoder_left_pad * vocoder_upsample_rate : -vocoder_right_pad * vocoder_upsample_rate]
                waveform_buffer = torch.cat([waveform_buffer, waveform], dim=-1)
                vocoder_left_pad = vocoder_left_context
                mel_buffer = mel_buffer[:, :, -(vocoder_left_pad + vocoder_right_pad):]
        # flush the remaining codes
        if code_buffer.size(1) > 0:
            generated_mel = self.dit_forward(code_buffer, initial_noise[:, dit_start_index : dit_start_index + code_buffer.size(1) * 2])
            generated_mel = generated_mel[:, :, dit_left_padding * 2:]
            mel_buffer = torch.cat([mel_buffer, generated_mel], dim=-1)
            waveform = self.bigvgan(mel_buffer)
            waveform = waveform[vocoder_left_pad * vocoder_upsample_rate:]
            waveform_buffer = torch.cat([waveform_buffer, waveform], dim=-1)
        return waveform_buffer
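
    # Hedged sketch (hypothetical helper, not called anywhere): one step of the
    # fixed-step solver used in dit_forward, factored out for dy/dt = f(t, y).
    # It mirrors the stages and the 3/8-rule weights used there exactly.
    @staticmethod
    def _demo_ode_step(f, t0, t1, y0):
        dt = t1 - t0
        k1 = f(t0, y0)
        k2 = f(t0 + dt * 1/3, y0 + dt * k1 * 1/3)
        k3 = f(t0 + dt * 2/3, y0 + dt * (k2 - k1 * 2/3))
        k4 = f(t1, y0 + dt * (k1 - k2 + k3))
        return y0 + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125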

    def export_spk(self):
        import MNN.expr as expr

        def torch_to_mnn(x):
            return expr.const(x.data_ptr(), x.shape)

        var_list = []
        for key, value in self.speaker_map.items():
            for k, v in value.items():
                if type(v) is not torch.Tensor:
                    v = torch.tensor(v)
                mnn_var = torch_to_mnn(v.contiguous().float())
                mnn_var.name = f'{key}_{k}'
                var_list.append(mnn_var)
        expr.save(var_list, f'{self.args.dst_path}/spk_dict.mnn')

    @spinner_run(f'export token2wav.predit to ')
    def export_predit(self, onnx_path):
        cond = torch.randn([1, 1, 128], dtype=torch.float32)
        spk = torch.randn([1, 1, 192], dtype=torch.float32)
        code = torch.ones([1, 256], dtype=torch.int32)
        onnx_model = f'{onnx_path}/predit.onnx'
        torch.onnx.export(self.dit.preprocess, (cond, spk, code),
                          onnx_model,
                          input_names=['cond', 'spk', 'code'],
                          output_names=['code_embeds', 'rope', 'mask'],
                          dynamic_axes={"code": {1: "size"}},
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model

    @spinner_run(f'export token2wav.dit to ')
    def export_dit(self, onnx_path):
        x = torch.randn([1, 512, 80], dtype=torch.float32)
        code_embeds = torch.randn([1, 512, 1024], dtype=torch.float32)
        rope = torch.randn([2, 1, 512, 1, 64], dtype=torch.float32)
        mask = torch.ones([1, 1, 512, 512], dtype=torch.int32)
        time = torch.tensor([0.0])
        onnx_model = f'{onnx_path}/dit.onnx'
        torch.onnx.export(self.dit, (x, code_embeds, rope, mask, time),
                          onnx_model,
                          input_names=['x', 'code_embeds', 'rope', 'mask', 'time'],
                          output_names=['mel'],
                          dynamic_axes={
                              "x": {1: "size"},
                              "code_embeds": {1: "size"},
                              "rope": {2: "size"},
                              "mask": {2: "size", 3: "size"},
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model

    @spinner_run(f'export token2wav.bigvgan to ')
    def export_bigvgan(self, onnx_path):
        generated_mel = torch.randn([1, 80, 512], dtype=torch.float32)
        onnx_model = f'{onnx_path}/bigvgan.onnx'
        torch.onnx.export(self.bigvgan, (generated_mel,),
                          onnx_model,
                          input_names=['generated_mel'],
                          output_names=['waveform'],
                          dynamic_axes={"generated_mel": {2: "size"}},
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model

    def export(self, onnx_path):
        self.export_spk()
        predit = self.export_predit(onnx_path)
        dit = self.export_dit(onnx_path)
        bigvgan = self.export_bigvgan(onnx_path)
        return predit, dit, bigvgan
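

# Hedged usage sketch (hypothetical driver; `token2wav_module`, `base`, and `code` stand
# in for the objects the surrounding exporter normally supplies, and nothing here runs
# on import):
def _demo_token2wav_usage(token2wav_module, base, code):
    t2w = Qwen2_5OmniToken2Wav(token2wav_module, base)
    waveform = t2w.generate_stream(code)   # chunked synthesis for a [1, T] code tensor
    t2w.export(base.args.dst_path)         # writes predit/dit/bigvgan ONNX plus spk_dict.mnn
    return waveform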