def allreduce_word_embedding_grads()

in chatlearn/models/megatron/lora/layers.py [0:0]


    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.
        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        """
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
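            # self.models holds the (possibly virtual-pipeline) model chunks:
            # the input embedding lives in the first chunk on the first stage,
            # while the tied output weight lives in the last chunk on the last stage.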
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]

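            # Older Megatron-LM versions expose share_word_embeddings /
            # word_embeddings_weight(); newer versions expose
            # share_embeddings_and_output_weights /
            # shared_embedding_or_output_weight(). Handle both layouts.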
            if hasattr(unwrapped_model, "share_word_embeddings"):
                from chatlearn.utils.megatron_import_helper import DistributedDataParallel as LocalDDP # pylint: disable=import-outside-toplevel
                unwrapped_model = unwrap_model(
                    unwrapped_model, (torchDDP, LocalDDP, Float16Module))
                if unwrapped_model.share_word_embeddings:
                    word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                    if word_embeddings_weight.requires_grad:
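                        # Megatron's local DDP accumulates gradients in the
                        # fp32 .main_grad buffer; torch DDP keeps them in .grad.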
                        if args.DDP_impl == 'local':
                            grad = word_embeddings_weight.main_grad
                        else:
                            grad = word_embeddings_weight.grad
                        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
            elif hasattr(unwrapped_model, "share_embeddings_and_output_weights"):
                unwrapped_model = unwrap_model(unwrapped_model)
                if unwrapped_model.share_embeddings_and_output_weights:
                    weight = unwrapped_model.shared_embedding_or_output_weight()
                    if weight.requires_grad:
                        grad = weight.main_grad
                        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
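
For context, a minimal sketch (not part of the repository) of where such a hook is typically invoked: once per training iteration, after the backward passes of all micro-batches have completed and before the optimizer step. The names `trainer`, `optimizer`, `data_iterator`, and `forward_backward` are hypothetical stand-ins; only `allreduce_word_embedding_grads(args)` comes from the code above.

    def train_step(trainer, optimizer, args, data_iterator):
        # Run forward/backward over the micro-batches (hypothetical helper).
        losses = trainer.forward_backward(data_iterator)

        # Keep the tied input/output embedding weight in sync: without this
        # all-reduce, the first and last pipeline stages would apply different
        # gradients to the shared parameter.
        trainer.allreduce_word_embedding_grads(args)

        optimizer.step()
        return losses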