in optimum/exporters/onnx/model_patcher.py [0:0]
def patched_build_delay_pattern_mask(self, input_ids: torch.Tensor, pad_token_id: int, max_length: int = None):
    """Build the codebook delay-pattern mask with operations chosen to survive ONNX export.

    Each codebook row is shifted right by its (per-channel) codebook index, the
    positions before and after the shifted prompt are filled with
    ``pad_token_id``, and positions that still have to be generated keep the
    placeholder value ``-1``.

    Args:
        input_ids: Token ids; reshaped below to ``(-1, self.num_codebooks, last_dim)``,
            i.e. treated as ``(bsz * num_codebooks, seq_len)``.
        pad_token_id: Value written at the padded positions of the pattern.
        max_length: Length of the pattern along the sequence axis. Falls back to
            ``self.generation_config.max_length`` when ``None``.

    Returns:
        A dict with two entries, both flattened to ``(bsz * num_codebooks, -1)``:
        ``"input_ids_edited"`` (the masked ids cut off at the first ``-1`` of the
        first codebook) and ``"delay_pattern_mask"`` (the full-length pattern).

    Raises:
        NotImplementedError: If ``max_length < 2 * channel_codebooks - 1``; this
            short-sequence case is not handled by the export-friendly version.
    """
    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
    codes = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
    bsz, num_codebooks, seq_len = codes.shape
    device = codes.device

    if max_length is None:
        max_length = self.generation_config.max_length

    # Every slot starts out as -1, the placeholder searched for further down.
    shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=device) * -1

    stereo = self.config.audio_channels == 2
    channel_codebooks = num_codebooks // 2 if stereo else num_codebooks

    # Sequences shorter than 2 * channel_codebooks - 1 are rejected outright here
    # instead of being passed through unchanged.
    if max_length < 2 * channel_codebooks - 1:
        raise NotImplementedError("Not supported in ONNX export. Please open an issue in Optimum repository.")

    # Copy the prompt into the shifted buffer, delaying each codebook by its index.
    # Mono: one row per codebook. Otherwise left/right rows are interleaved, so the
    # two rows belonging to the same codebook share one offset (left first, then right).
    mono = self.config.audio_channels == 1
    for offset in range(channel_codebooks):
        rows = (offset,) if mono else (2 * offset, 2 * offset + 1)
        for row in rows:
            shifted[:, row, offset : seq_len + offset] = codes[:, row]

    # Pattern of padded positions per codebook, built from two triangular parts.
    # Upper triangle = trailing (EOS) padding.
    # NOTE: torch.bool would be the natural dtype, but the export then fails with
    # `The exported ONNX model failed ONNX shape inference.`; int8 instead fails
    # with `Could not find an implementation for Where`. Hence int32.
    eos_padding = triu_onnx(
        torch.ones((channel_codebooks, max_length), dtype=torch.int32),
        diagonal=max_length - channel_codebooks + 1,
    )

    # Lower triangle = leading (BOS) padding.
    # NOTE: same torch.bool limitation as above; int32 fails here with
    # `Could not find an implementation for Trilu`, hence int64.
    bos_padding = torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64))

    padded_positions = (eos_padding + bos_padding).to(torch.bool)

    if stereo:
        # Left/right channels: duplicate every pattern row so the rows stay interleaved.
        padded_positions = padded_positions.repeat_interleave(2, dim=0)

    # True where the shifted ids are kept, False where the pad token goes.
    keep = ~padded_positions.to(device)
    kept_part = keep * shifted
    pad_part = ~keep * pad_token_id
    patterned = kept_part + pad_part

    # Generation starts at the first -1; the first codebook has no delay, so it is
    # the row inspected to find that position.
    first_row = patterned[:, 0, :]
    open_columns = (first_row == -1).nonzero()[:, 1]

    # TODO: Is this OK?
    # NOTE(review): unlike a guarded lookup, this reduces unconditionally; if no -1
    # is left in the first codebook, `open_columns` is empty and `.min()` is expected
    # to raise - confirm whether callers can hit that case.
    first_start_id = open_columns.min()

    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
    pattern_mask = patterned.reshape(bsz * num_codebooks, -1)

    # Keep only the columns before the first position that still has to be generated.
    input_ids_edited = patterned[..., :first_start_id].reshape(bsz * num_codebooks, -1)

    return {"input_ids_edited": input_ids_edited, "delay_pattern_mask": pattern_mask}