in src/diarizers/models/pyannet.py [0:0]
def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
    """Run the network on a batch of raw waveforms.

    Parameters
    ----------
    waveforms : (batch, channel, sample)
        Input audio batch.

    Returns
    -------
    scores : (batch, frame, classes)
        ``self.activation`` applied to the output of ``self.classifier``.
    """
    # SincNet output is laid out as (batch, feature, frame); both LSTM
    # variants consume frames along dim 1, so move features last up front.
    features = rearrange(
        self.sincnet(waveforms), "batch feature frame -> batch frame feature"
    )
    lstm_hparams = self.hparams.lstm

    if not lstm_hparams["monolithic"]:
        # Stack of separate LSTM modules: dropout follows every layer whose
        # 1-based position is below ``num_layers`` (i.e. not after the last
        # one when ``num_layers`` equals the stack length).
        for depth, layer in enumerate(self.lstm, start=1):
            features, _ = layer(features)
            if depth < lstm_hparams["num_layers"]:
                features = self.dropout(features)
    else:
        # Single LSTM module; its second return value is discarded.
        features, _ = self.lstm(features)

    # ``self.linear`` is only consulted when linear layers are configured.
    if self.hparams.linear["num_layers"] > 0:
        for dense in self.linear:
            features = F.leaky_relu(dense(features))

    logits = self.classifier(features)
    return self.activation(logits)