# backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
def __init__(self, prefix, config, weights):
    super().__init__()
    # Tensor-parallel rank and world size come from the weights' process group.
    process_group = weights.process_group
    self.tp_rank = process_group.rank()
    self.tp_world_size = process_group.size()

    # Skip fp8 quant for first and last layers
    self.layers = nn.ModuleList()
    self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
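
    # One rotary embedding module is built up front and shared by all FlashLlamaLayer instances.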
    rotary_emb = PositionRotaryEmbedding.static(
        config=config,
        dim=config.hidden_size // config.num_attention_heads,
        base=config.rope_theta,
        device=weights.device,
    )
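
    # First layer: always loaded with fp8 quantization disabled (no_fp8).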
    with no_fp8(weights):
        self.layers.append(
            FlashLlamaLayer(
                index=0,
                prefix=f"{prefix}.layers.0",
                config=config,
                weights=weights,
                rotary_emb=rotary_emb,
            )
        )

    # Skip first and last layers: they are appended outside this loop with fp8 disabled.
    for layer_id in range(1, config.num_hidden_layers - 1):
        if layer_id in self.cross_attention_layers:
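            # Deferred import: only executed when the config actually lists cross-attention layers (mllama).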
            from text_generation_server.models.custom_modeling.flash_mllama import (
                FlashLlamaCrossLayer,
            )

            self.layers.append(
                FlashLlamaCrossLayer(
                    index=layer_id,
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                )
            )
        else:
            self.layers.append(
                FlashLlamaLayer(
                    index=layer_id,
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    rotary_emb=rotary_emb,
                )
            )
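
    # Last layer: likewise loaded with fp8 quantization disabled (no_fp8).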
    with no_fp8(weights):
        last_layer_id = config.num_hidden_layers - 1
        self.layers.append(
            FlashLlamaLayer(
                index=last_layer_id,
                prefix=f"{prefix}.layers.{last_layer_id}",
                config=config,
                weights=weights,
                rotary_emb=rotary_emb,
            )
        )

    self.norm = FastRMSNorm.load(
        prefix=f"{prefix}.norm",
        weights=weights,
        eps=config.rms_norm_eps,
    )
    self.gradient_checkpointing = False

    self.head_size = self.layers[0].self_attn.head_size
    self.num_heads = self.layers[0].self_attn.num_heads
    self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads