in torchaudio/functional/functional.py
from typing import Union

import torch
from torch import Tensor


def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.
Args:
psd_s (Tensor): The complex-valued covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
is one-hot.
n_iter (int): number of iterations in power method. (Default: ``3``)
Returns:
Tensor: the estimated complex-valued RTF of target speech
Tensor of dimension `(..., freq, channel)`
"""
    assert n_iter > 0, "The number of iterations must be greater than 0."
    # phi is regarded as the first iteration
    phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    if torch.jit.isinstance(reference_channel, int):
        rtf = phi[..., reference_channel]
    elif torch.jit.isinstance(reference_channel, Tensor):
        reference_channel = reference_channel.to(psd_n.dtype)
        # Select the column of phi indicated by the one-hot reference vector.
        rtf = torch.einsum("...c,...c->...", [phi, reference_channel[..., None, None, :]])
    else:
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
    rtf = rtf.unsqueeze(-1)  # (..., freq, channel, 1)
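    # Unrolling both branches below, the returned vector always equals
    # psd_n @ phi**n_iter @ e_ref: `n_iter` applications of
    # phi = psd_n^{-1} @ psd_s to the reference vector (the power method),
    # followed by a final multiplication by psd_n, which drives the estimate
    # toward the principal generalized eigenvector of the pencil (psd_s, psd_n).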
    if n_iter >= 2:
        # The number of iterations in the for loop is `n_iter - 2`
        # because the `phi` above and `torch.matmul(psd_s, rtf)` are regarded as
        # two iterations.
        for _ in range(n_iter - 2):
            rtf = torch.matmul(phi, rtf)
        rtf = torch.matmul(psd_s, rtf)
    else:
        # With a single iteration, the result is psd_s[..., reference_channel],
        # since psd_n @ phi @ e_ref = psd_s @ e_ref.
        rtf = torch.matmul(psd_n, rtf)
    return rtf.squeeze(-1)
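
A minimal usage sketch, not part of the library source: the shapes and the synthetic data below are illustrative assumptions. It builds complex covariance matrices from random multichannel spectra and estimates the RTF both ways (the function is exposed as torchaudio.functional.rtf_power).

import torch

batch, freq, channel, time = 2, 257, 4, 100
# Hypothetical synthetic inputs: outer products of random complex spectra give
# positive semi-definite covariance matrices of shape (..., freq, channel, channel).
speech = torch.randn(batch, freq, channel, time, dtype=torch.cfloat)
noise = torch.randn(batch, freq, channel, time, dtype=torch.cfloat)
psd_s = torch.einsum("...ct,...et->...ce", speech, speech.conj())
psd_n = torch.einsum("...ct,...et->...ce", noise, noise.conj())
psd_n = psd_n + 1e-4 * torch.eye(channel)  # regularize so psd_n is invertible

# Integer reference channel.
rtf = rtf_power(psd_s, psd_n, reference_channel=0, n_iter=3)
print(rtf.shape)  # torch.Size([2, 257, 4])

# Equivalent one-hot Tensor reference of shape (..., channel).
one_hot = torch.zeros(batch, channel)
one_hot[:, 0] = 1.0
rtf = rtf_power(psd_s, psd_n, reference_channel=one_hot, n_iter=3)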