in optimum/habana/transformers/models/esm/modeling_esmfold.py [0:0]
def gaudi_esmfolding_trunk_forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
    """
    Run the folding trunk with recycling and feed each pass into the structure module.

    Copied from EsmFoldingTrunk.forward:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/modeling_esmfold.py
    The change is:
    - Add an extra mark_step after every trunk block (inside the per-pass block loop, upstream's
      `trunk_iter`) when the tensors live on an "hpu" device.

    Inputs:
      seq_feats: B x L x C tensor of sequence features
      pair_feats: B x L x L x C tensor of pair features
      true_aa: passed through unchanged as the second argument of `self.structure_module`
      residx: B x L long tensor giving the position in the sequence
      mask: B x L boolean tensor indicating valid residues
      no_recycles: number of recycling passes to run in addition to the first plain forward pass.
        When None, `self.config.max_recycles` is used directly as the total number of passes.

    Output:
      The value returned by `self.structure_module` on the last pass, with the trunk's final
      sequence and pair states stored under the keys "s_s" and "s_z".

    Raises:
      ValueError: if `no_recycles` is given and is negative.
    """
    target_device = seq_feats.device
    seq_init = seq_feats
    pair_init = pair_feats

    if no_recycles is None:
        total_passes = self.config.max_recycles
    elif no_recycles < 0:
        raise ValueError("Number of recycles must not be negative.")
    else:
        # First 'recycle' is just the standard forward pass through the model.
        total_passes = no_recycles + 1

    def run_trunk_blocks(single, pair, positions, valid):
        # Positional information is folded into the pair state once, before the block stack.
        pair = pair + self.pairwise_positional_embedding(positions, mask=valid)

        for trunk_block in self.blocks:
            single, pair = trunk_block(
                single,
                pair,
                mask=valid,
                residue_index=positions,
                chunk_size=self.chunk_size,
            )
            if single.device.type != "hpu":
                continue
            # Gaudi-specific addition over upstream: mark a step after each block.
            import habana_frameworks.torch.core as htcore

            htcore.mark_step()

        return single, pair

    s_s, s_z = seq_init, pair_init
    # Recycled state starts out as zeros: nothing has been predicted yet on the first pass.
    prev_single = torch.zeros_like(s_s)
    prev_pair = torch.zeros_like(s_z)
    prev_bins = torch.zeros(*s_z.shape[:-1], device=target_device, dtype=torch.int64)

    final_pass = total_passes - 1
    for pass_idx in range(total_passes):
        # Only the last pass runs with gradients enabled; earlier passes run under no_grad.
        grad_guards = [] if pass_idx == final_pass else [torch.no_grad()]
        with ContextManagers(grad_guards):
            # --- Recycling: normalise the previous pass and add the embedded distance bins ---
            prev_single = self.recycle_s_norm(prev_single.detach()).to(target_device)
            prev_pair = self.recycle_z_norm(prev_pair.detach()).to(target_device)
            prev_pair += self.recycle_disto(prev_bins.detach()).to(target_device)

            s_s, s_z = run_trunk_blocks(seq_init + prev_single, pair_init + prev_pair, residx, mask)

            # --- Structure module ---
            sm_inputs = {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}
            structure = self.structure_module(sm_inputs, true_aa, mask.float())

            prev_single, prev_pair = s_s, s_z
            # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
            backbone_xyz = structure["positions"][-1][:, :, :3]
            prev_bins = self.distogram(backbone_xyz, 3.375, 21.375, self.recycle_bins)

    structure["s_s"] = s_s
    structure["s_z"] = s_z

    return structure