def forward()

in benchmarks/rnnt/ootb/inference/pytorch/parts/features.py


    def forward(self, inp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        x, seq_len = inp

        dtype = x.dtype

        seq_len = self.get_seq_len(seq_len)

        # dither
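        # (adds low-amplitude Gaussian noise to the raw waveform, a standard trick
        #  to decorrelate quantization noise before the spectral transform)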
        if self.dither > 0 and not self.use_deterministic_dithering:
            x += self.dither * torch.randn_like(x)

        # do preemphasis
        # Ideally, we would mask immediately after this... Ugh :(
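        # (pre-emphasis is a first-order high-pass filter, y[t] = x[t] - preemph * x[t-1],
        #  with the first sample passed through unchanged)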
        if self.preemph is not None:
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]),
                          dim=1)

        # do stft
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                       win_length=self.win_length,
                       center=True, window=self.window.to(dtype=torch.float),
                       return_complex=False)  # explicit for newer PyTorch; keeps the (real, imag) layout consumed below

        # get power spectrum
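        # (square the real and imaginary parts and sum over the last dimension)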
        x = x.pow(2).sum(-1)

        # deterministic dithering: add the expected noise power (dither ** 2) to the
        # spectrum instead of injecting random noise into the waveform
        if self.dither > 0 and self.use_deterministic_dithering:
            x = x + self.dither ** 2
        # dot with filterbank energies
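        # (self.fb is presumably the mel filterbank matrix of shape [n_filt, n_fft // 2 + 1];
        #  the matmul maps the power spectrum to per-frame mel-band energies)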
        x = torch.matmul(self.fb.to(x.dtype), x)

        # log features if required
        if self.log:
            x = torch.log(x + 1e-20)

        # frame splicing if required
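        # (each extra copy n is the feature sequence shifted left by n frames and
        #  zero-padded at the end; copies are stacked along the feature axis and
        #  time is then subsampled by a factor of frame_splicing)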
        if self.frame_splicing > 1:
            seq = [x]
            for n in range(1, self.frame_splicing):
                tmp = torch.zeros_like(x)
                tmp[:, :, :-n] = x[:, :, n:]
                seq.append(tmp)
            x = torch.cat(seq, dim=1)[:, :, ::self.frame_splicing]

        # normalize if required
        constant = 1e-5
        if self.normalize == "per_feature":
            x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
                                 device=x.device)
            x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
                                device=x.device)
            for i in range(x.shape[0]):
                x_mean[i, :] = x[i, :, :seq_len[i]].mean(dim=1)
                x_std[i, :] = x[i, :, :seq_len[i]].std(dim=1)
            # make sure x_std is not zero (floor applied once, after the loop)
            x_std += constant
            x = (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
        elif self.normalize == "all_features":
            x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
            x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
            for i in range(x.shape[0]):
                x_mean[i] = x[i, :, :seq_len[i].item()].mean()
                x_std[i] = x[i, :, :seq_len[i].item()].std()
            # make sure x_std is not zero (floor applied once, after the loop)
            x_std += constant
            x = (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
        else:
            # no normalization requested
            pass

        # Note: frames beyond each sequence's valid length are no longer masked to
        # zero here; the batch is only truncated to the longest valid length below.

        # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
        # max_len = x.size(-1)
        x = x[:, :, :seq_len.max()]   # rnnt loss requires lengths to match
        # mask = torch.arange(max_len).to(seq_len.dtype).to(x.device).expand(x.size(0),
        #                                                                   max_len) >= seq_len.unsqueeze(1)

        # x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
        pad_to = self.pad_to
        if pad_to != 0:
            raise NotImplementedError()
        # if pad_to == "max":
        #    x = nn.functional.pad(x, (0, self.max_length - x.size(-1)))
        # elif pad_to > 0:
        #    pad_amt = x.size(-1) % pad_to
        #    if pad_amt != 0:
        #        x = nn.functional.pad(x, (0, pad_to - pad_amt))

        return x.to(dtype)
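
A minimal usage sketch, assuming the surrounding class is the `FilterbankFeatures`
featurizer in this file (the import path and default constructor below are
assumptions for illustration, not verified against the repository):

    # Hedged usage sketch; class name, import path, and constructor defaults are assumptions.
    import torch
    from parts.features import FilterbankFeatures  # assumed import path

    featurizer = FilterbankFeatures()               # assumed default constructor
    waveform = torch.randn(2, 16000)                # batch of two 1 s clips at 16 kHz
    lengths = torch.tensor([16000, 12000])          # valid samples per clip
    feats = featurizer((waveform, lengths))         # -> [batch, n_filt, n_frames]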