def rotate_dihedral()

in mmcif_utils.py [0:0]


def rotate_dihedral(angles, par, child, pos, pos_exist, chis, chi_valid):
    """Apply chi-angle rotations to a batch of amino-acid atom coordinates.

    N is the number of amino acids in the batch and 20 is the number of atom
    slots per amino acid. All angles are given in degrees.

    For each chi index ``i`` in ``0..3``, the bond from the parent of atom
    ``4 + i`` to atom ``4 + i`` defines a rotation axis. Every atom from slot
    ``5 + i`` onward is rotated about that axis by
    ``chis[:, i] - angles[:, 4 + i]``. The rotations are applied in order, so
    chi ``i + 1`` acts on coordinates already moved by chi ``i``.

    Args:
        angles: N x 20 per-atom angles (degrees). Column ``4 + i`` is
            subtracted from ``chis[:, i]`` to form the rotation for chi ``i``
            (presumably the atom's current dihedral -- TODO confirm).
        par: N x 20 integer encoding of the relative offset of the parent of
            each atom. For example, the amino acid glycine would be represented
            as [-18 -1 -1 -1 0, ...]. The parent of atom ``4 + i`` is the atom
            in slot ``4 + i + par[:, 4 + i]``; only columns 4..7 are read.
        child: N x 20 encoding of the child of each atom. For example, glycine
            would be represented as [1 1 18 0 0 0 ..]. Not read by this
            function; kept for interface compatibility.
        pos: N x 20 x 3 atom coordinates (assumed Cartesian xyz -- TODO
            confirm). Not modified; a rotated copy is returned.
        pos_exist: N x 20 mask of which atoms are valid for each amino acid,
            e.g. glycine would be [1 1 1 1 0 0 ...]. Not read by this
            function; kept for interface compatibility.
        chis: N x (at least 4) chi angles (degrees). Column ``i`` is used for
            chi ``i``; only the first four columns are read.
        chi_valid: N x 5 0/1 mask of which chi angles are valid, e.g. glycine
            would be [0 0 0 0 0]. Only the first four columns are read. The
            mask is applied multiplicatively, so values other than 0/1 would
            blend rotated and original coordinates.

    Returns:
        A new array shaped like ``pos``. For residues whose chi ``i`` is valid,
        atoms ``5 + i`` onward are rotated. Coordinates that become NaN during a
        rotation are replaced with the sentinel value 10000.
    """
    angles = angles / 180 * np.pi
    chis = chis / 180 * np.pi
    # Work on a copy so the caller's coordinate array is left untouched.
    pos = pos.copy()

    # Only four chi angles are applied here, even though chi_valid has five
    # columns.
    for i in range(4):
        atom = 4 + i

        # Axis of chi i: the bond from the parent of `atom` (p1) to `atom` (p2).
        p2 = pos[:, atom]
        parent_idx = atom + par[:, atom : atom + 1]
        p1 = np.take_along_axis(pos, parent_idx[:, :, None], axis=1)[:, 0, :]

        axis_vec = p2 - p1
        # The epsilon keeps a zero-length axis (e.g. par == 0) from dividing by 0.
        unit_axis = axis_vec / (np.linalg.norm(axis_vec, axis=1, keepdims=True) + 1e-10)

        rot_angle = chis[:, i] - angles[:, atom]
        rot_points = _rotate_points_about_axis(pos[:, atom + 1 :], p1, unit_axis, rot_angle)

        # Sentinel for coordinates that became NaN during the rotation.
        rot_points[np.isnan(rot_points)] = 10000

        # Keep the rotated coordinates only where chi i is valid.
        # NOTE(review): a NaN already present in pos survives this blend even in
        # valid rows, because NaN * 0 is NaN.
        valid = chi_valid[:, i : i + 1, None]
        pos[:, atom + 1 :] = rot_points * valid + pos[:, atom + 1 :] * (1 - valid)

    return pos


def _rotate_points_about_axis(points, origin, unit_axis, angle):
    """Rotate points about the line through ``origin`` along ``unit_axis``.

    This is Rodrigues' rotation formula. Take each point relative to
    ``origin``; its component parallel to the axis is kept, and its
    perpendicular component is rotated by ``angle`` radians.

    Args:
        points: (N, M, 3) coordinates to rotate.
        origin: (N, 3) point on each rotation axis.
        unit_axis: (N, 3) axis directions, expected to be (near) unit length.
        angle: (N,) rotation angles in radians.

    Returns:
        (N, M, 3) rotated coordinates (a new array; ``points`` is not modified).
    """
    rel = points - origin[:, None, :]
    k = unit_axis[:, None, :]

    parallel = (rel * k).sum(axis=2, keepdims=True) * k
    perp = rel - parallel

    # The epsilon keeps points lying on the axis (zero perpendicular part) finite.
    perp_norm = np.linalg.norm(perp, axis=2, keepdims=True) + 1e-10
    # k x perp_hat; rescaled by perp_norm below to give k x perp.
    k_cross_perp = np.cross(k, perp / perp_norm)

    return (
        perp * np.cos(angle)[:, None, None]
        + np.sin(angle)[:, None, None] * k_cross_perp * perp_norm
        + parallel
        + origin[:, None, :]
    )