def gaudi_esmfolding_trunk_forward()

in optimum/habana/transformers/models/esm/modeling_esmfold.py [0:0]


def gaudi_esmfolding_trunk_forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
    """
    Run the ESMFold folding trunk with recycling, marking an HPU graph step after every trunk block.

    Inputs:
        seq_feats: B x L x C tensor of sequence features.
        pair_feats: B x L x L x C tensor of pair features.
        true_aa: forwarded unchanged to `self.structure_module` as its second positional argument
            (presumably per-residue amino-acid indices -- TODO confirm against the caller).
        residx: B x L long tensor giving the position in the sequence.
        mask: B x L boolean tensor indicating valid residues.
        no_recycles: number of extra recycling passes run after the first forward pass. When None,
            `self.config.max_recycles` is used directly as the *total* number of passes (no +1 is
            applied in that case, matching the upstream implementation).

    Output:
        The object returned by `self.structure_module` on the final pass (indexed like a mapping; its
        "positions" entry feeds the distogram), with two entries added: "s_s" (final single
        representation) and "s_z" (final pair representation).

    Raises:
        ValueError: if `no_recycles` is negative.

    Copied from EsmFoldingTrunk.forward:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/modeling_esmfold.py
    The change is:
    - Add extra mark_step in trunk_iter for each block.
    """

    device = seq_feats.device
    s_s_0 = seq_feats
    s_z_0 = pair_feats

    if no_recycles is None:
        # NOTE: upstream behaviour -- the config value is the total pass count; no +1 here.
        no_recycles = self.config.max_recycles
    else:
        if no_recycles < 0:
            raise ValueError("Number of recycles must not be negative.")
        no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

    # Resolve the HPU graph-break hook once per call rather than re-running the import statement
    # after every trunk block. trunk_iter's inputs are built from `seq_feats` and tensors moved to
    # `device`, so checking `device` here replaces the former per-block check of `s.device`.
    mark_step = None
    if device.type == "hpu":
        import habana_frameworks.torch.core as htcore

        mark_step = htcore.mark_step

    def trunk_iter(s, z, residx, mask):
        # One full pass through every trunk block, starting from the positionally-embedded pair rep.
        z = z + self.pairwise_positional_embedding(residx, mask=mask)

        for block in self.blocks:
            s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            if mark_step is not None:
                # Gaudi-specific change: cut the lazy-mode graph after each block.
                mark_step()
        return s, z

    s_s = s_s_0
    s_z = s_z_0
    # Zero-initialised recycling inputs for the first pass.
    recycle_s = torch.zeros_like(s_s)
    recycle_z = torch.zeros_like(s_z)
    recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

    for recycle_idx in range(no_recycles):
        # Only the final pass records gradients; earlier passes run under torch.no_grad().
        with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
            # === Recycling ===
            recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
            recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
            recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

            s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

            # === Structure module ===
            structure = self.structure_module(
                {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                true_aa,
                mask.float(),
            )

            recycle_s = s_s
            recycle_z = s_z
            # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
            recycle_bins = self.distogram(
                structure["positions"][-1][:, :, :3],
                3.375,
                21.375,
                self.recycle_bins,
            )

    structure["s_s"] = s_s
    structure["s_z"] = s_z

    return structure