def rtf_power()

in torchaudio/functional/functional.py [0:0]


def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

    Args:
        psd_s (Tensor): The complex-valued covariance matrix of target speech.
            Tensor of dimension `(..., freq, channel, channel)`
        psd_n (Tensor): The complex-valued covariance matrix of noise.
            Tensor of dimension `(..., freq, channel, channel)`
        reference_channel (int or Tensor): Indicate the reference channel.
            An ``int`` is interpreted as the index of the reference channel.
            A ``Tensor`` of dimension `(..., channel)` must be one-hot along the
            ``channel`` dimension.
        n_iter (int): number of iterations in power method. (Default: ``3``)

    Returns:
        Tensor: the estimated complex-valued RTF of target speech
        Tensor of dimension `(..., freq, channel)`
    """
    assert n_iter > 0, "The number of iteration must be greater than 0."
    # Forming psd_n^-1 @ psd_s already counts as the first power iteration.
    whitened_psd = torch.linalg.solve(psd_n, psd_s)
    # Select the reference column of the whitened PSD.
    # torch.jit.isinstance (rather than plain isinstance) keeps this scriptable.
    if torch.jit.isinstance(reference_channel, int):
        steering = whitened_psd[..., reference_channel]
    elif torch.jit.isinstance(reference_channel, Tensor):
        reference_channel = reference_channel.to(psd_n.dtype)
        steering = torch.einsum("...c,...c->...", [whitened_psd, reference_channel[..., None, None, :]])
    else:
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
    steering = steering.unsqueeze(-1)  # (..., freq, channel, 1)
    if n_iter == 1:
        # Single iteration: psd_n @ (psd_n^-1 @ psd_s)[..., ref] == psd_s[..., ref].
        return torch.matmul(psd_n, steering).squeeze(-1)
    # The whitening step above and the closing psd_s multiplication below are
    # two iterations, so only n_iter - 2 extra applications are needed here.
    for _ in range(n_iter - 2):
        steering = torch.matmul(whitened_psd, steering)
    return torch.matmul(psd_s, steering).squeeze(-1)