in server/text_generation_server/models/custom_modeling/mpt_modeling.py [0:0]
    def __init__(self, prefix: str, config, weights):
        # config._validate_config()
        super().__init__(config)
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.n_heads = config.n_heads
        self.attn_impl = config.attn_config.attn_impl
        self.prefix_lm = config.attn_config.prefix_lm
        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
        self.alibi = config.attn_config.alibi
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        if config.init_device == "mixed":
            # TODO: reimplement mixed device initialization
            # (originally "cpu" when dist.get_local_rank() == 0 and "meta" on
            # the other ranks; for now weights are always initialized on CPU).
            config.init_device = "cpu"
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
            )
        if config.norm_type.lower() != "low_precision_layernorm":
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not supported here; only low_precision_layernorm is implemented."
            )
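        # Embeddings (and the per-layer blocks below) are loaded with
        # tensor-parallel sharding under the given checkpoint prefix.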
        self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
        if not self.alibi:
            # Learned positional embeddings are only present when ALiBi is disabled.
            self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
        self.blocks = nn.ModuleList(
            [
                MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
                for i in range(config.n_layers)
            ]
        )
        if config.no_bias:
            self.norm_f = nn.LayerNorm.load_no_bias(
                prefix=f"{prefix}.norm_f", weights=weights, eps=EPS
            )
        else:
            self.norm_f = nn.LayerNorm.load(
                prefix=f"{prefix}.norm_f", weights=weights, eps=EPS
            )
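        # Causality and the attention-bias shape depend on the attention
        # config; the bias tensor itself (e.g. ALiBi slopes or the prefix-LM
        # mask) is only materialized lazily on the first forward pass.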
        self.is_causal = not self.prefix_lm
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
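        # Checkpoints trained with no_bias carry no bias tensors, so any bias
        # parameters registered by the stock modules are removed to match.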
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                    if config.verbose:
                        warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                    module.register_parameter("bias", None)
        if hasattr(self.config, "verbose"):
            if config.verbose and config.verbose > 2:
                print(self)
        if "verbose" not in self.config.init_config:
            self.config.init_config["verbose"] = self.config.verbose
        if self.config.init_config["verbose"] > 1:
            init_fn_name = self.config.init_config["name"]
            warnings.warn(f"Using {init_fn_name} initialization.")