def patched_build_delay_pattern_mask()

in optimum/exporters/onnx/model_patcher.py [0:0]


def patched_build_delay_pattern_mask(self, input_ids: torch.Tensor, pad_token_id: int, max_length: int = None):
    """Build the codebook delay pattern mask in a form intended for ONNX export.

    The prompt ids of each codebook are written into a `-1`-filled buffer of length `max_length`,
    shifted right by the codebook index. Positions selected by the delay pattern (a lower-triangular
    part plus an upper-triangular part built with `triu_onnx`) are then overwritten with `pad_token_id`.
    The remaining `-1` entries mark positions that have not been filled yet.

    Args:
        input_ids: Tensor that is reshaped to `(bsz, num_codebooks, seq_len)` using `self.num_codebooks`;
            per the inline comment it is expected as `(bsz * num_codebooks, seq_len)`.
        pad_token_id: Value written at every position selected by the delay pattern.
        max_length: Length of the last dimension of the built mask. Falls back to
            `self.generation_config.max_length` when `None`.

    Returns:
        A dict with two entries:
        - `"input_ids_edited"`: the padded ids truncated at the first column of the first codebook that
          still holds `-1`, reshaped to `(bsz * num_codebooks, -1)`.
        - `"delay_pattern_mask"`: the full padded ids reshaped to `(bsz * num_codebooks, -1)`.

    Raises:
        NotImplementedError: If `max_length < 2 * channel_codebooks - 1`, where `channel_codebooks` is
            `num_codebooks // 2` when `self.config.audio_channels == 2` and `num_codebooks` otherwise.
    """
    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
    input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
    bsz, num_codebooks, seq_len = input_ids.shape

    max_length = max_length if max_length is not None else self.generation_config.max_length
    # buffer of -1 placeholders, shape (bsz, num_codebooks, max_length), on the same device as the input
    input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1

    # with two audio channels the codebooks are split evenly between the channels
    channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
    # the delay pattern needs max_length >= 2 * channel_codebooks - 1; the shorter case is rejected here
    # with an explicit error rather than returning the ids unchanged
    if max_length < 2 * channel_codebooks - 1:
        raise NotImplementedError("Not supported in ONNX export. Please open an issue in Optimum repository.")

    # fill the shifted ids with the prompt entries, offset by the codebook idx
    for codebook in range(channel_codebooks):
        if self.config.audio_channels == 1:
            # mono channel - loop over the codebooks one-by-one
            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
        else:
            # left/right channels are interleaved in the generated codebooks, so handle one then the other
            input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
            input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]

    # construct a pattern mask that indicates the positions of padding tokens for each codebook
    # first fill the upper triangular part (the EOS padding)
    # NOTE: We could use torch.bool here, but PyTorch then complains with `The exported ONNX model failed ONNX shape inference.`
    # Using int8 leads to `Could not find an implementation for Where`
    # NOTE(review): `triu_onnx` is defined outside this block; presumably an export-friendly equivalent
    # of `torch.triu` - confirm against its definition.
    delay_pattern = triu_onnx(
        torch.ones((channel_codebooks, max_length), dtype=torch.int32), diagonal=max_length - channel_codebooks + 1
    )

    # NOTE: We could use torch.bool here, but PyTorch then complains with `The exported ONNX model failed ONNX shape inference.`
    # Using int32 leads to `Could not find an implementation for Trilu`, hence int64 here

    # then fill the lower triangular part (the BOS padding)
    delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64))
    # any non-zero entry (upper or lower triangular part) becomes True, i.e. a padding position
    delay_pattern = delay_pattern.to(torch.bool)

    if self.config.audio_channels == 2:
        # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
        delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

    # `mask` is True where the shifted ids are kept; everywhere else `pad_token_id` is written.
    # The select is expressed with multiply/add on the boolean mask and broadcasts over the batch dim.
    mask = ~delay_pattern.to(input_ids.device)
    input_ids = mask * input_ids_shifted + ~mask * pad_token_id

    # find the first position to start generating - this is the first place we have the -1 token
    # and will always be in the first codebook (since it has no codebook offset)
    first_codebook_ids = input_ids[:, 0, :]
    # column indices of every -1 entry in the first codebook, across all batch rows
    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]

    # TODO: Is this OK?
    # NOTE(review): if the first codebook contains no -1, `start_ids` is empty and `.min()` is expected
    # to fail; there is no fallback branch here, presumably to avoid data-dependent control flow - confirm.
    first_start_id = start_ids.min()

    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
    # keep only the columns before the first position that still has to be generated
    input_ids_edited = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
    return {"input_ids_edited": input_ids_edited, "delay_pattern_mask": pattern_mask}