in step7_pipeline_parallel_afab/pipeline_parallel.py [0:0]
def __init__(self, model, config):
    """Keep only the pieces of ``model`` selected for this pipeline stage.

    Args:
        model: Source model. It must expose ``embedding``, an indexable
            ``decoder_layers``, ``final_norm`` and ``final_proj``.
        config: Configuration object; only ``num_hidden_layers`` is read.
    """
    super().__init__()

    # distribute_layers decides which decoder-layer indices are kept here
    # (presumably this stage's share of the stack -- confirm in that method).
    kept_indices = self.distribute_layers(config.num_hidden_layers)
    stage = pgm.process_group_manager

    # The embedding is kept on the first stage only; every other stage gets
    # a pass-through module in its place.
    if stage.pp_is_first_stage:
        self.embedding = model.embedding
    else:
        self.embedding = nn.Identity()

    # Kept layers are keyed by their stringified index in the source model.
    kept_layers = {}
    for layer_idx in kept_indices:
        kept_layers[str(layer_idx)] = model.decoder_layers[layer_idx]
    self.decoder_layers = nn.ModuleDict(kept_layers)

    # The final norm and projection are kept on the last stage only; other
    # stages get a separate pass-through module for each attribute.
    if stage.pp_is_last_stage:
        self.final_norm = model.final_norm
        self.final_proj = model.final_proj
    else:
        self.final_norm = nn.Identity()
        self.final_proj = nn.Identity()