# PipelineParallel.__init__ — excerpt from picotron/pipeline_parallel/pipeline_parallel.py

    def __init__(self, model, config):
        """Build this rank's shard of the model for pipeline parallelism.

        Keeps only the decoder layers assigned to this pipeline stage.
        Stage-boundary modules (embedding on the first stage; final norm and
        output projection on the last stage) are kept where they belong and
        replaced by ``nn.Identity`` everywhere else, so the forward pass can
        call them unconditionally on every rank.

        Args:
            model: full model providing ``embedding``, ``decoder_layers``,
                ``final_norm`` and ``final_proj`` (project type — structure
                assumed from attribute access here).
            config: configuration exposing ``num_hidden_layers``.
        """
        super().__init__()

        is_first = pgm.process_group_manager.pp_is_first_stage
        is_last = pgm.process_group_manager.pp_is_last_stage

        # Global indices of the decoder layers this stage owns.
        self.layer_distribution = self.distribute_layers(config.num_hidden_layers)

        # Embedding lives only on the first stage; a no-op elsewhere.
        self.embedding = model.embedding if is_first else nn.Identity()

        # Keep just the locally-owned decoder layers, keyed by global index
        # (ModuleDict requires string keys).
        self.decoder_layers = nn.ModuleDict(
            {str(idx): model.decoder_layers[idx] for idx in self.layer_distribution}
        )

        # Norm and output projection live only on the last stage.
        self.final_norm = model.final_norm if is_last else nn.Identity()
        self.final_proj = model.final_proj if is_last else nn.Identity()

        self.reset_parameters()