def apply_delay_pattern_mask()

in parler_tts/streamer.py [0:0]


    def apply_delay_pattern_mask(self, input_ids):
        """Undo the delay pattern on ``input_ids`` and decode them to audio.

        The delay pattern mask is rebuilt from the first column of
        ``input_ids``, applied, and then every position the pattern marks as
        BOS or PAD is dropped. The surviving codes are reshaped to
        ``(1, num_codebooks, -1)`` (with one extra leading axis when
        ``self.use_4dim_audio_codes`` is set) and handed to
        ``self.audio_encoder.decode``.

        Args:
            input_ids: tensor of generated codebook token ids.
                NOTE(review): presumably ``(num_codebooks, seq_len)`` for a
                single sample -- confirm against the caller.

        Returns:
            ``output_values[0, 0]`` of the decoded audio, moved to CPU, cast
            to float and converted to a NumPy array.
        """
        gen_cfg = self.generation_config
        decoder = self.decoder

        # Rebuild the delay pattern that offsets each codebook prediction by 1
        # (this behaviour is specific to Parler).
        _, pattern = decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            bos_token_id=gen_cfg.bos_token_id,
            pad_token_id=gen_cfg.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        codes = decoder.apply_delay_pattern_mask(input_ids, pattern)

        # Revert the delay: keep only positions the pattern does not mark as
        # BOS or PAD, then restore the (1, num_codebooks, frames) layout.
        not_bos = pattern != gen_cfg.bos_token_id
        not_pad = pattern != gen_cfg.pad_token_id
        codes = codes[not_bos & not_pad].reshape(1, decoder.num_codebooks, -1)

        four_dim = self.use_4dim_audio_codes
        if four_dim:
            # Put the frame dimension back in front of the audio codes.
            codes = codes[None, ...]

        codec = self.audio_encoder
        codes = codes.to(codec.device)

        # Any leftover special token means individual frames must be filtered
        # out before decoding.
        has_special_tokens = (
            gen_cfg.bos_token_id in codes
            or gen_cfg.pad_token_id in codes
            or gen_cfg.eos_token_id in codes
        )
        if has_special_tokens:
            if four_dim:
                frames = codes[:, 0]
                out_of_vocab = (frames >= codec.config.codebook_size).sum(dim=(0, 1))
                frames = frames[:, :, out_of_vocab == 0]
            else:
                frames = codes[0]
                out_of_vocab = (frames >= codec.config.codebook_size).sum(dim=0)
                frames = frames[:, out_of_vocab == 0]
            # Only frames whose codes all lie below the codebook size survive.
            codes = frames[None, ...]

        decoded = codec.decode(audio_codes=codes, **self.audio_kwargs).audio_values
        if decoded.ndim != 3:
            decoded = decoded.unsqueeze(0)

        return decoded[0, 0].cpu().float().numpy()