def __init__()

in server/text_generation_server/models/custom_modeling/mpt_modeling.py
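
Constructor of the MPT transformer backbone in text-generation-inference. It records the tensor-parallel world size and rank, copies the attention settings (attention implementation, prefix-LM mode, sequence-id masking, ALiBi) from the config, loads the sharded token embedding (plus the positional embedding when ALiBi is off), builds the stack of MPTBlocks and the final layer norm, precomputes the attention-bias shape, and strips bias parameters when the checkpoint has none.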


    def __init__(self, prefix: str, config, weights):
        # config._validate_config()
        super().__init__(config)
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
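        # Attention-related settings copied from the checkpoint config.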
        self.n_heads = config.n_heads
        self.attn_impl = config.attn_config.attn_impl
        self.prefix_lm = config.attn_config.prefix_lm
        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
        self.alibi = config.attn_config.alibi
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        if config.init_device == "mixed":
            # TODO: reimplement mixed device initialization
            # Originally "cpu" on local rank 0 (dist.get_local_rank() == 0) and
            # "meta" on the remaining ranks; for now every rank initializes on CPU.
            config.init_device = "cpu"
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
            )
        if config.norm_type.lower() != "low_precision_layernorm":
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo; only low_precision_layernorm is supported."
            )

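        # Token embedding, sharded across the tensor-parallel group.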
        self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)

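        # With ALiBi the positional signal comes from the attention bias, so the
        # learned positional embedding (wpe) is only loaded when ALiBi is disabled.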
        if not self.alibi:
            self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
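        # Transformer block stack; each MPTBlock loads its weights under its own prefix.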
        self.blocks = nn.ModuleList(
            [
                MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
                for i in range(config.n_layers)
            ]
        )
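        # Final layer norm; note the hard-coded "transformer" prefix rather than the
        # `prefix` argument, and the bias-free loader for no_bias checkpoints.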
        if config.no_bias:
            self.norm_f = nn.LayerNorm.load_no_bias(
                prefix="transformer.norm_f", weights=weights, eps=EPS
            )
        else:
            self.norm_f = nn.LayerNorm.load(
                prefix="transformer.norm_f", weights=weights, eps=EPS
            )
        self.is_causal = not self.prefix_lm
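        # The attention bias itself is built lazily; only its shape is derived here
        # from the attention implementation and the ALiBi/prefix-LM/sequence-id settings.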
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
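        # no_bias checkpoints ship no bias tensors, so drop every bias parameter.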
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                    if config.verbose:
                        warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                    module.register_parameter("bias", None)
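        # Verbosity bookkeeping: optionally print the module tree and report the
        # parameter-initialization scheme named in init_config.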
        if hasattr(self.config, "verbose"):
            if config.verbose and config.verbose > 2:
                print(self)
        if "verbose" not in self.config.init_config:
            self.config.init_config["verbose"] = self.config.verbose
        if self.config.init_config["verbose"] > 1:
            init_fn_name = self.config.init_config["name"]
            warnings.warn(f"Using {init_fn_name} initialization.")