def reset_parameters()

in picotron/pipeline_parallel/pipeline_parallel.py [0:0]


    def reset_parameters(self):
        """Re-run parameter initialization for the modules held by this pipeline stage.

        Order of resets:
          1. ``self.embedding`` -- only when this rank is the first pipeline stage.
          2. For every entry of ``self.decoder_layers`` (in iteration order), its
             ``input_layernorm``, ``attention``, ``post_attention_layernorm`` and
             ``mlp`` sub-modules, in that order.
          3. ``self.final_norm`` then ``self.final_proj`` -- only when this rank is
             the last pipeline stage.

        NOTE(review): the stage guards presumably exist because the embedding /
        output head are only real modules on the first / last stage (other stages
        may hold placeholders without ``reset_parameters``) -- confirm in __init__.
        """
        # Sub-modules of a decoder layer, listed in the order they are reset.
        layer_parts = ("input_layernorm", "attention", "post_attention_layernorm", "mlp")

        if pgm.process_group_manager.pp_is_first_stage:
            self.embedding.reset_parameters()

        for decoder_layer in self.decoder_layers.values():
            for part_name in layer_parts:
                getattr(decoder_layer, part_name).reset_parameters()

        # Only the last stage owns the final norm and output projection.
        if not pgm.process_group_manager.pp_is_last_stage:
            return
        self.final_norm.reset_parameters()
        self.final_proj.reset_parameters()