def get_layer_names_in_sft_format()

in picotron/checkpoint.py [0:0]


    def get_layer_names_in_sft_format(self):
        """Get layer names in safetensors format based on model's layer distribution.

        Builds the list of ``<name>.weight`` parameter names for the layers held
        by ``self.model``:

        * With pipeline parallelism (``self.model`` is a ``PipelineParallel``),
          only the decoder layers listed in ``self.model.layer_distribution``
          are included. ``model.embed_tokens.weight`` is added when this rank is
          the first pipeline stage, and ``model.norm.weight`` when it is the
          last stage. A rank that is both gets both.
        * Without pipeline parallelism, all
          ``self.model_config.num_hidden_layers`` decoder layers are included,
          together with the embedding and final-norm weights.

        Returns:
            list[str]: names ordered as embedding weight (if any), then each
            decoder layer's component weights, then final norm weight (if any),
            e.g. ``"model.layers.0.input_layernorm.weight"``.
        """
        decoder_components = [
            "input_layernorm",
            "mlp.down_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "post_attention_layernorm",
            "self_attn.k_proj",
            "self_attn.o_proj",
            "self_attn.q_proj",
            "self_attn.v_proj",
        ]

        # Decide which decoder layers and which special layers this model holds.
        if isinstance(self.model, PipelineParallel):
            layer_ids = self.model.layer_distribution
            # First and last stage are checked independently (not if/elif):
            # with a single pipeline stage one rank holds both the embeddings
            # and the final norm, and neither may be dropped.
            include_embeddings = pgm.process_group_manager.pp_is_first_stage
            include_final_norm = pgm.process_group_manager.pp_is_last_stage
        else:
            layer_ids = range(self.model_config.num_hidden_layers)
            include_embeddings = True
            include_final_norm = True

        # NOTE: Safetensors may have tied embeddings, but Picotron does not support it. We always create a new lm_head.
        # Hence no "lm_head.weight" entry is produced here.
        layer_names = []
        if include_embeddings:
            layer_names.append("model.embed_tokens.weight")
        layer_names.extend(
            f"model.layers.{layer_id}.{component}.weight"
            for layer_id in layer_ids
            for component in decoder_components
        )
        if include_final_norm:
            layer_names.append("model.norm.weight")

        return layer_names