optimum/habana/transformers/models/esm/modeling_esmfold.py (205 lines of code) (raw):

# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from transformers.models.esm.modeling_esmfold import EsmForProteinFoldingOutput, categorical_lddt
from transformers.models.esm.openfold_utils import (
    compute_predicted_aligned_error,
    compute_tm,
    make_atom14_masks,
)
from transformers.utils import (
    ContextManagers,
)


def gaudi_esmfolding_trunk_forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
    """
    Run the folding trunk blocks with recycling, then the structure module.

    Inputs:
      seq_feats: B x L x C tensor of sequence features
      pair_feats: B x L x L x C tensor of pair features
      true_aa: amino-acid tensor, forwarded unchanged to ``self.structure_module``
      residx: B x L long tensor giving the position in the sequence
      mask: B x L boolean tensor indicating valid residues
      no_recycles: number of extra recycling iterations. When ``None``,
        ``self.config.max_recycles`` is used as the total iteration count;
        otherwise one extra iteration (the plain forward pass) is added.

    Output:
      The dict returned by ``self.structure_module`` on the last iteration,
      with the final trunk states stored under ``"s_s"`` and ``"s_z"``.

    Raises:
      ValueError: if ``no_recycles`` is negative.

    Copied from EsmFoldingTrunk.forward:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/modeling_esmfold.py
    The change is:
    - Add extra mark_step in trunk_iter for each block.
    """
    device = seq_feats.device
    s_s_0 = seq_feats
    s_z_0 = pair_feats

    if no_recycles is None:
        no_recycles = self.config.max_recycles
    else:
        if no_recycles < 0:
            raise ValueError("Number of recycles must not be negative.")
        no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

    def trunk_iter(s, z, residx, mask):
        z = z + self.pairwise_positional_embedding(residx, mask=mask)

        for block in self.blocks:
            s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            # Flush the lazily accumulated HPU graph after every block so each block is
            # compiled/executed separately instead of as one huge graph.
            if s.device.type == "hpu":
                import habana_frameworks.torch.core as htcore

                htcore.mark_step()
        return s, z

    s_s = s_s_0
    s_z = s_z_0
    recycle_s = torch.zeros_like(s_s)
    recycle_z = torch.zeros_like(s_z)
    recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

    for recycle_idx in range(no_recycles):
        # Only the last iteration tracks gradients; earlier ones run under no_grad.
        with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
            # === Recycling ===
            recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
            recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
            recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

            s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

            # === Structure module ===
            structure = self.structure_module(
                {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                true_aa,
                mask.float(),
            )

            recycle_s = s_s
            recycle_z = s_z
            # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
            recycle_bins = self.distogram(
                structure["positions"][-1][:, :, :3],
                3.375,
                21.375,
                self.recycle_bins,
            )

    structure["s_s"] = s_s
    structure["s_z"] = s_z

    return structure


def gaudi_esm_for_protein_folding_forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    masking_pattern: Optional[torch.Tensor] = None,
    num_recycles: Optional[int] = None,
) -> EsmForProteinFoldingOutput:
    r"""
    Returns:
        EsmForProteinFoldingOutput built from the trunk's structure dict plus the
        distogram, LM, lDDT and pTM / predicted-aligned-error heads.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, EsmForProteinFolding

    >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
    >>> outputs = model(**inputs)
    >>> folded_positions = outputs.positions
    ```

    Copied from EsmForProteinFolding.forward:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/modeling_esmfold.py
    The change is:
    - rewrite (softmax().unsqueeze() @ esm_s).squeeze() with an equivalent algorithm that uses
      fewer tensor dims on HPU.
    """
    cfg = self.config.esmfold_config

    aa = input_ids  # B x L
    B = aa.shape[0]
    L = aa.shape[1]
    device = input_ids.device
    if attention_mask is None:
        attention_mask = torch.ones_like(aa, device=device)
    if position_ids is None:
        position_ids = torch.arange(L, device=device).expand_as(input_ids)

    # === ESM ===
    esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)

    if masking_pattern is not None:
        masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
    else:
        masked_aa = aa
        mlm_targets = None

    # We get sequence and pair representations from whatever version of ESM /
    # configuration we are using. The sequence representation esm_s is always
    # present. The pair embedding esm_z may be present depending on the
    # configuration of the model. If esm_z is not used by the model then it
    # is returned as None here.
    esm_s = self.compute_language_model_representations(esmaa)

    # Convert esm_s and esm_z, if present, to the precision used by the trunk and
    # the structure module. These tensors may be a lower precision if, for example,
    # we're running the language model in fp16 precision.
    esm_s = esm_s.to(self.esm_s_combine.dtype)

    if cfg.esm_ablate_sequence:
        esm_s = esm_s * 0

    esm_s = esm_s.detach()

    # === preprocessing ===
    if esm_s.device.type == "hpu":
        # Same result as the branch below, but the matmul runs on a 3-D tensor:
        # fold the leading dims together, combine layers, then restore them.
        dims = esm_s.shape
        esm_s = esm_s.reshape(-1, dims[-2], dims[-1])  # combine first 2 dims
        esm_s = self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s
        esm_s = esm_s.reshape(dims[0], dims[1], esm_s.shape[-2], esm_s.shape[-1])  # split back 1st dim
        esm_s = esm_s.squeeze(2)
    else:
        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)

    s_s_0 = self.esm_s_mlp(esm_s)
    s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)

    if self.config.esmfold_config.embed_aa:
        s_s_0 += self.embedding(masked_aa)

    structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
    # Documenting what we expect:
    structure = {
        k: v
        for k, v in structure.items()
        if k
        in [
            "s_z",
            "s_s",
            "frames",
            "sidechain_frames",
            "unnormalized_angles",
            "angles",
            "positions",
            "states",
        ]
    }

    # Add BERT mask for the loss to use, if available.
    # NOTE(review): mlm_targets comes from self.bert_mask and is presumably a tensor; the truth
    # test below then raises for multi-element tensors. Kept identical to upstream because
    # EsmForProteinFoldingOutput may not accept an "mlm_targets" field either -- confirm before
    # switching to `is not None`.
    if mlm_targets:
        structure["mlm_targets"] = mlm_targets

    disto_logits = self.distogram_head(structure["s_z"])
    # Symmetrize the pairwise distogram logits over the two residue axes.
    disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
    structure["distogram_logits"] = disto_logits

    lm_logits = self.lm_head(structure["s_s"])
    structure["lm_logits"] = lm_logits

    structure["aatype"] = aa
    make_atom14_masks(structure)
    # Of course, this doesn't respect the true mask because it doesn't know about it...
    # We're not going to properly mask change of index tensors:
    #    "residx_atom14_to_atom37",
    #    "residx_atom37_to_atom14",
    for k in [
        "atom14_atom_exists",
        "atom37_atom_exists",
    ]:
        structure[k] *= attention_mask.unsqueeze(-1)
    structure["residue_index"] = position_ids

    lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
    structure["lddt_head"] = lddt_head
    plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
    structure["plddt"] = plddt

    ptm_logits = self.ptm_head(structure["s_z"])
    structure["ptm_logits"] = ptm_logits
    structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
    structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))

    return EsmForProteinFoldingOutput(**structure)


def _can_merge_dims_2_and_3(shape_r, shape_t) -> bool:
    """Return True if dims 2 and 3 of ``r``/``t`` can be flattened without changing broadcasting.

    Flattening is only equivalent to broadcasting the two dims separately when the
    (dim2, dim3) pairs are identical, or when one side is (1, 1) and broadcasts entirely.
    """
    pair_r = tuple(shape_r[2:4])
    pair_t = tuple(shape_t[2:4])
    return pair_r == pair_t or pair_r == (1, 1) or pair_t == (1, 1)


def gaudi_rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Applies a rotation to a vector. Written out by hand to avoid AMP downcasting.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates

    On HPU the product is computed with matmul when the shapes allow it:
    - t.dim() < 5: directly, as (t^T @ r^T)^T.
    - t.dim() == 5 with a 6-D ``r`` whose dims 2/3 are merge-compatible with ``t``'s: dims 2 and 3
      are flattened before the matmul and restored afterwards.
    Every other case uses the element-wise formula, which supports general broadcasting.

    Copied from rot_vec_mul:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/openfold_utils/rigid_utils.py
    The change is:
    - Using matmul when possible on HPU to get better performance.
    """
    # Do matmul on HPU directly when possible to get better performance.
    if r.device.type == "hpu":
        if t.dim() < 5:
            # Row-vector form: (t^T @ r^T)^T == r @ t.
            out = t.unsqueeze(-2) @ r.transpose(-2, -1)
            return out.squeeze(-2)
        if t.dim() == 5 and r.dim() == 6 and _can_merge_dims_2_and_3(r.shape, t.shape):
            # Combine shape[2] and shape[3] on HPU so the matmul runs on lower-rank operands.
            shape_t = t.shape
            shape_r = r.shape
            t = t.reshape(shape_t[0], shape_t[1], shape_t[2] * shape_t[3], shape_t[4])
            r = r.reshape(shape_r[0], shape_r[1], shape_r[2] * shape_r[3], shape_r[4], shape_r[5])
            out = t.unsqueeze(-2) @ r.transpose(-2, -1)
            shape_out = out.shape
            # Split the merged dim back into its (broadcast) dim-2 / dim-3 sizes.
            out = out.reshape(
                shape_out[0],
                shape_out[1],
                max(shape_r[2], shape_t[2]),
                max(shape_r[3], shape_t[3]),
                shape_out[3],
                shape_out[4],
            )
            return out.squeeze(-2)
        # t.dim() > 5, or 5-D shapes that cannot be merged safely: use the generic path below.

    x, y, z = torch.unbind(t, dim=-1)
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )


def gaudi_rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Performs matrix multiplication of two rotation matrix tensors. Written out by hand to avoid AMP downcasting.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        The product ab

    On HPU, matmul is used when the shapes match or ``a`` has fewer than 5 dims, and for
    5-D ``a``/``b`` where ``a`` has size 1 in dim 2 (computed as (b^T @ a^T)^T). Every other
    case uses the element-wise formula.

    Copied from rot_matmul:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/openfold_utils/rigid_utils.py
    The change is:
    - Using matmul when possible on HPU to get better performance.
    """
    # Do matmul on HPU directly when possible to get better performance.
    if a.device.type == "hpu":
        if a.shape == b.shape or a.dim() < 5:
            return a @ b
        if a.dim() == 5 and b.dim() == 5 and a.shape[2] == 1:
            # HPU does not handle dim==5 `a @ b` with this broadcast correctly, e.g.
            # a.shape = torch.Size([1, 512, 1, 3, 3]), b.shape = torch.Size([1, 512, 8, 3, 3]).
            # Compute the same product as (b^T @ a^T)^T instead.
            a_t = a.permute(0, 1, 2, 4, 3)
            b_t = b.permute(0, 1, 2, 4, 3)
            out = b_t @ a_t
            return out.permute(0, 1, 2, 4, 3)
        # Otherwise fall through to the generic element-wise formula below.

    def row_mul(i: int) -> torch.Tensor:
        return torch.stack(
            [
                a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
                a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1],
                a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2],
            ],
            dim=-1,
        )

    return torch.stack(
        [
            row_mul(0),
            row_mul(1),
            row_mul(2),
        ],
        dim=-2,
    )