in benchmarks/rnnt/ootb/inference/pytorch/parts/features.py [0:0]
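# NOTE: this method relies on module-level imports (`import torch`,
# `from typing import Tuple`) and on attributes (self.dither, self.preemph,
# self.window, self.fb, self.normalize, ...) set up elsewhere in the class,
# presumably in its __init__.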
def forward(self, inp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
x, seq_len = inp
dtype = x.dtype
seq_len = self.get_seq_len(seq_len)
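    # get_seq_len (defined elsewhere in this class) presumably converts per-utterance
    # sample counts into STFT frame counts, roughly ceil(samples / hop_length)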
# dither
if self.dither > 0 and not self.use_deterministic_dithering:
x += self.dither * torch.randn_like(x)
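        # (random dithering: tiny zero-mean Gaussian noise scaled by self.dither,
        # which avoids exactly-zero bins before the log further below)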
# do preemphasis
    # NOTE: ideally, the padded region beyond each utterance's valid length would be
    # masked immediately after this step; no masking is applied here (see the end of
    # this method).
if self.preemph is not None:
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]),
dim=1)
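        # (pre-emphasis: y[0] = x[0], y[t] = x[t] - preemph * x[t-1];
        # boosts high-frequency content before the FFT)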
# do stft
    # newer PyTorch versions require return_complex to be set explicitly;
    # view_as_real keeps the trailing real/imaginary dim expected by the power spectrum
    x = torch.view_as_real(torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                                      win_length=self.win_length, center=True,
                                      window=self.window.to(dtype=torch.float),
                                      return_complex=True))
# get power spectrum
x = x.pow(2).sum(-1)
if self.dither > 0 and self.use_deterministic_dithering:
x = x + self.dither ** 2
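        # (deterministic alternative to random dithering: a constant offset on the
        # order of the dither noise power is added to the power spectrum instead)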
# dot with filterbank energies
x = torch.matmul(self.fb.to(x.dtype), x)
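    # (self.fb holds the mel filterbank, n_filt x (n_fft // 2 + 1) possibly with a
    # leading broadcast dim; the matmul maps (B, n_fft // 2 + 1, T) -> (B, n_filt, T))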
# log features if required
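    # (the small 1e-20 floor below keeps the log finite for all-zero bins)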
if self.log:
x = torch.log(x + 1e-20)
# frame splicing if required
if self.frame_splicing > 1:
seq = [x]
for n in range(1, self.frame_splicing):
tmp = torch.zeros_like(x)
tmp[:, :, :-n] = x[:, :, n:]
seq.append(tmp)
x = torch.cat(seq, dim=1)[:, :, ::self.frame_splicing]
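        # (splicing stacks `frame_splicing` time-shifted copies along the feature axis
        # and keeps every `frame_splicing`-th frame; seq_len is not rescaled here, so
        # this path appears to assume frame_splicing == 1 in practice)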
# normalize if required
constant = 1e-5
if self.normalize == "per_feature":
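        # per-utterance, per-feature mean/std over the valid frames only
        # (frames beyond seq_len[i] are excluded from the statistics)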
x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
device=x.device)
x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
device=x.device)
for i in range(x.shape[0]):
x_mean[i, :] = x[i, :, :seq_len[i]].mean(dim=1)
x_std[i, :] = x[i, :, :seq_len[i]].std(dim=1)
# make sure x_std is not zero
x_std += constant
x = (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
elif self.normalize == "all_features":
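        # a single mean/std per utterance, pooled over all features and valid frames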
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
for i in range(x.shape[0]):
x_mean[i] = x[i, :, :seq_len[i].item()].mean()
x_std[i] = x[i, :, :seq_len[i].item()].std()
# make sure x_std is not zero
x_std += constant
x = (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
    else:
        pass  # no normalization applied
    # NOTE: values beyond seq_len are no longer explicitly masked here; only the
    # truncation to the batch max length below is applied, which may be a concern.
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
# max_len = x.size(-1)
x = x[:, :, :seq_len.max()] # rnnt loss requires lengths to match
# mask = torch.arange(max_len).to(seq_len.dtype).to(x.device).expand(x.size(0),
# max_len) >= seq_len.unsqueeze(1)
# x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
pad_to = self.pad_to
if pad_to != 0:
        raise NotImplementedError("non-zero pad_to is not supported here")
# if pad_to == "max":
# x = nn.functional.pad(x, (0, self.max_length - x.size(-1)))
# elif pad_to > 0:
# pad_amt = x.size(-1) % pad_to
# if pad_amt != 0:
# x = nn.functional.pad(x, (0, pad_to - pad_amt))
return x.to(dtype)
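# Usage sketch (illustrative only; assumes the enclosing module -- presumably
# FilterbankFeatures in this file -- is an nn.Module constructed with its usual
# arguments and fed 16 kHz audio; the names below are examples, not part of this file):
#
#   featurizer = FilterbankFeatures(...)      # hypothetical construction
#   audio = torch.randn(2, 16000)             # (batch, samples): 1 s at 16 kHz
#   lens = torch.tensor([16000, 12000])       # valid samples per utterance
#   feats = featurizer((audio, lens))         # -> (batch, n_filt, frames)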