optimum/habana/transformers/models/falcon_mamba/modeling_falcon_mamba.py (111 lines of code) (raw):
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch FALCONMAMBA model."""
from typing import Optional, Tuple, Union
import habana_frameworks.torch.core as htcore
import torch
from transformers.cache_utils import MambaCache
from transformers.models.falcon_mamba.modeling_falcon_mamba import FalconMambaOutput
from transformers.utils import (
logging,
)
logger = logging.get_logger(__name__)
"""
Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py#L635
The only differences are:
- Use the mark_step function to reduce the graph compiling time.
"""
def gaudi_FalconMambaModel_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.LongTensor] = None,
    cache_params: Optional[MambaCache] = None,
    use_cache: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    lazy_mode: Optional[bool] = True,
) -> Union[Tuple, FalconMambaOutput]:
    """Run the FalconMamba layer stack, optionally calling `htcore.mark_step()` before each layer.

    Args:
        input_ids: Token ids, passed through `self.embeddings` when `inputs_embeds` is not given.
        inputs_embeds: Pre-computed embeddings; exactly one of `input_ids` / `inputs_embeds` must be set.
        cache_params: Existing cache. When `use_cache` is enabled and this is `None`, a new
            `MambaCache` is created and `cache_position` is reset to `arange(config.conv_kernel)`.
        use_cache: Falls back to `config.use_cache` (or `False` while training) when `None`;
            forced to `False` when gradient checkpointing is active during training.
        output_hidden_states: Falls back to `config.output_hidden_states` when `None`.
        return_dict: Falls back to `config.use_return_dict` when `None`.
        cache_position: Required whenever `use_cache` is enabled and `cache_params` is passed in.
        attention_mask: Forwarded unchanged to every layer.
        lazy_mode: When truthy, `htcore.mark_step()` is called before each layer.

    Returns:
        A `FalconMambaOutput` when `return_dict` is truthy, otherwise a tuple holding the
        non-`None` entries of (last hidden state, cache, all hidden states).

    Raises:
        ValueError: If not exactly one of `input_ids` / `inputs_embeds` is given, or if a cache
            is supplied with `use_cache` enabled but `cache_position` is missing.
    """
    # Resolve unset flags from the model configuration.
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = False if self.training else self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Both missing or both present are equally invalid.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        inputs_embeds = self.embeddings(input_ids)

    checkpointing = self.gradient_checkpointing and self.training
    if checkpointing and use_cache:
        use_cache = False

    if not use_cache:
        cache_params = None
    elif cache_params is None:
        cache_params = MambaCache(
            self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
    elif cache_position is None:
        # Manual forward calls (outside `model.generate`) must provide the position themselves;
        # raising is preferred over guessing the current cache position.
        raise ValueError(
            "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
            "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
            "be initialized for you automatically"
        )

    hidden_states = inputs_embeds
    collected_states = [] if output_hidden_states else None

    for mixer_block in self.layers:
        if lazy_mode:
            # Gaudi-specific addition: the file notes this reduces graph compiling time.
            htcore.mark_step()

        if checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
            )
        else:
            hidden_states = mixer_block(
                hidden_states,
                cache_params=cache_params,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )

        if output_hidden_states:
            collected_states.append(hidden_states)

    hidden_states = self.norm_f(hidden_states)

    if output_hidden_states:
        # The normalized final state is recorded in addition to the per-layer states.
        collected_states.append(hidden_states)
        all_hidden_states = tuple(collected_states)
    else:
        all_hidden_states = None

    if return_dict:
        return FalconMambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
    return tuple(item for item in (hidden_states, cache_params, all_hidden_states) if item is not None)
"""
Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py#L762
The only differences are:
- Use the torch.index_select function in place of the slicing operation `input_ids[:, -1].unsqueeze(-1)` that selects the last token
"""
def gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation(
    self,
    input_ids,
    inputs_embeds=None,
    use_cache=None,
    cache_params: Optional[MambaCache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Assemble the keyword arguments for one generation step.

    With `use_cache` enabled and `cache_position[0] > 0`, only the last column of
    `input_ids` is kept (selected with `torch.index_select` rather than slicing) and the
    attention mask is dropped. Otherwise `cache_position` is reset to
    `arange(config.conv_kernel)`.

    Returns:
        A dict containing either `inputs_embeds` (when embeddings are supplied and there is
        no cache yet) or contiguous `input_ids`, followed by `cache_params`, `use_cache`,
        `cache_position` and `attention_mask`.

    Raises:
        ValueError: If `use_cache` is truthy and `cache_position` is `None`.
    """
    if use_cache:
        # `generate` is expected to have set `cache_position` before calling this.
        if cache_position is None:
            raise ValueError(
                "`cache_position` should not be None as it should have been initialized in "
                "`model.generate`, you are responsible for passing in a valid `cache_position` if "
                "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
            )

        if cache_position[0] > 0:
            # Equivalent to `input_ids[:, -1].unsqueeze(-1)`, expressed as an index_select.
            last_column = torch.tensor([input_ids.size(1) - 1], device=input_ids.device)
            input_ids = torch.index_select(input_ids, 1, last_column)
            attention_mask = None
        else:
            # At prefill the position always spans the full `conv_states` length
            # (`config.conv_kernel`): shorter inputs get padded and longer ones truncated,
            # so this is equivalent to matching `cache_params.conv_states` exactly.
            cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)

    # Embeddings are only forwarded while no cache exists yet.
    feed_embeddings = inputs_embeds is not None and cache_params is None
    if feed_embeddings:
        primary_input = {"inputs_embeds": inputs_embeds}
    else:
        primary_input = {"input_ids": input_ids.contiguous()}

    return {
        **primary_input,
        "cache_params": cache_params,
        "use_cache": use_cache,
        "cache_position": cache_position,
        "attention_mask": attention_mask,
    }