in optimum/exporters/openvino/model_patcher.py [0:0]
def __enter__(self):
    """Monkey-patch the DBRX model in place before export.

    Each patched callable is preserved on an ``_orig_*`` attribute next to its
    replacement:

    * ``transformer._update_causal_mask`` -> ``_dbrx_update_causal_mask``, which
      fills the causal mask in a different way to avoid overflow (bf16 accuracy
      issues observed with transformers >= 4.40).
    * With transformers >= 4.41, every block's ``rotary_emb.forward`` ->
      ``llama_gemma_rotary_emb_forward``. It is given an ``embed_positions``
      buffer built once by ``create_sinusoidal_positions`` and shared by all
      blocks.
    * Every block's ``ffn.experts.forward`` -> ``_dbrx_experts_forward``.

    Rotary embeddings whose ``inv_freq`` is ``None`` (or absent) get it computed
    eagerly from their ``base``/``dim`` so TorchScript tracing sees a tensor.

    Raises:
        IndexError: if ``transformer.blocks`` is empty (the first block's rotary
            embedding parameterises the shared positions table).
    """
    super().__enter__()

    transformer = self._model.transformer

    # dbrx has some accuracy issues with bf16 with transformers >= 4.40:
    # fill the causal mask in a slightly different way to avoid overflow on some platforms
    transformer._orig_update_causal_mask = transformer._update_causal_mask
    transformer._update_causal_mask = types.MethodType(_dbrx_update_causal_mask, transformer)

    # starting from transformers 4.41 the issue is also observable when computing sin/cos for rotary_emb
    patch_rope_sin_cos = is_transformers_version(">=", "4.41.0")

    first_rotary_emb = transformer.blocks[0].norm_attn_norm.attn.rotary_emb
    # Default to None so a module lacking ``inv_freq`` is handled like a
    # not-yet-initialised one instead of raising AttributeError here.
    inv_freq = getattr(first_rotary_emb, "inv_freq", None)
    dim, base = None, None
    if inv_freq is None:
        # presumably create_sinusoidal_positions derives inv_freq from dim/base
        # when inv_freq is None — confirm against its definition
        dim = first_rotary_emb.dim
        base = first_rotary_emb.base
    max_positions = self._model.config.max_seq_len
    if patch_rope_sin_cos:
        embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)

    for block in transformer.blocks:
        rotary_emb = block.norm_attn_norm.attn.rotary_emb
        # initialize inv_freq eagerly for torchscript tracing
        if getattr(rotary_emb, "inv_freq", None) is None:
            rotary_emb.inv_freq = 1.0 / (
                rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
            )

        if patch_rope_sin_cos:
            # the same precomputed tensor is registered on every block
            rotary_emb.register_buffer("embed_positions", embed_positions)
            rotary_emb._orig_forward = rotary_emb.forward
            rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb)

        # remove continue-operator from iteration loop over experts
        block.ffn.experts._orig_forward = block.ffn.experts.forward
        block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)