def regroup_params_to_sync()

in chatlearn/synchronizer/base.py [0:0]


    def regroup_params_to_sync(self, name, param_data, tp_division, regroup_routed_experts=False):
        """
        :meta private:
        """
        param_data_shape = param_data.shape
        # Regroup these tensors into different tp slices.
        # Output: [tp_slice_0, tp_slice_1, ...]
        # Comment:
        #   src -> dst: [w, h * tp_size] -> tp_size * [w, h]
        #       'self_attention.dense' in QWen and LLama2 legacy models
        #       'mlp.dense_4h_to_h' in QWen and LLama2 legacy models
        #       'mlp.linear_fc2' in LLama2 mcore model
        #       'mlp.shared_experts.dense_4h_to_h' in QWen-MoE model
        #   src -> dst: [w * tp_size, h] -> tp_size * [w, h]
        #       'mlp.dense_h_to_4h' in QWen and LLama2 legacy models
        #       'mlp.linear_fc1' in LLama2 mcore model
        #       'mlp.w1' in QWen model (vLLM backend only)
        #       'mlp.shared_experts.dense_h_to_4h' in QWen-MoE model
        if (
            "self_attention.dense" in name
            or "mlp.dense_4h_to_h" in name
            or "mlp.linear_fc2" in name
            or "mlp.shared_experts.dense_4h_to_h" in name
        ):
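            # Slice dim 1 into tp_division column blocks and lay them out back-to-back in
            # memory, so the flat buffer reads as [tp_slice_0, tp_slice_1, ...].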
            param_data_list = []
            col_offset = param_data_shape[1] // tp_division
            for idx in range(tp_division):
                start = idx * col_offset
                end = start + col_offset
                param_data_list.append(param_data[:, start:end])
            param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
            del param_data_list
        if (
            "mlp.dense_h_to_4h" in name
            or "mlp.linear_fc1" in name
            or ("mlp.w1" in name and self.concat_params_dict is not None)
            or "mlp.shared_experts.dense_h_to_4h" in name
        ):
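            # The source stacks all ranks' w1 rows first, then all ranks' w2 rows; pair each
            # rank's w1 and w2 slices so the flat buffer reads as [w1_0;w2_0, w1_1;w2_1, ...].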
            param_data_list = []
            row_offset = param_data_shape[0] // tp_division // 2
            for idx in range(tp_division):
                w1_start = idx * row_offset
                w1_end = w1_start + row_offset
                w2_start = (idx + tp_division) * row_offset
                w2_end = w2_start + row_offset
                param_data_list.append(
                    torch.concat([param_data[w1_start:w1_end, :], param_data[w2_start:w2_end, :]], dim=0))
            param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
            del param_data_list

        #   src -> dst: src_tp_size * [e, h, w] -> dst_tp_size * [e, h, w // tp_division]
        #       'mlp.experts.dense_4h_to_h' in QWen-MoE model when training with QWen+vLLM
        #   src -> dst: src_tp_size * [e, w, h] -> dst_tp_size * [e, w // tp_division, h]
        #       'mlp.experts.dense_h_to_4h' in QWen-MoE model when training with QWen+vLLM
        if regroup_routed_experts:
            if "mlp.experts.dense_4h_to_h" in name:
                param_data_list = []
                height_offset = param_data_shape[2] // tp_division
                for height_idx in range(tp_division):
                    height_start = height_idx * height_offset
                    height_end = height_start + height_offset
                    param_data_list.append(param_data[:, :, height_start:height_end])
                param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
                del param_data_list
            elif "mlp.experts.dense_h_to_4h" in name:
                param_data_list = []
                param_data = param_data.reshape(param_data_shape[0] * 2, -1, param_data_shape[2])
                col_offset = param_data.shape[1] // tp_division
                for idx in range(tp_division):
                    start = idx * col_offset
                    end = start + col_offset
                    param_data_list.append(param_data[:, start:end, :])
                param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
                del param_data_list

        return param_data
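
A quick way to sanity-check the relayout is to reproduce it on toy tensors outside chatlearn. The sketch below is not part of the repository; tp, w, h, gate and up are illustrative stand-ins. It covers the two dense patterns above: the column-slice case ('self_attention.dense', 'mlp.dense_4h_to_h', ...) and the w1/w2 interleave case ('mlp.dense_h_to_4h', ...), and asserts that a flat split of the regrouped buffer hands each destination rank its expected slice.

import torch

tp = 2        # hypothetical tp_division
w, h = 2, 3   # hypothetical per-slice sizes

# Column-slice case: [w, h * tp] -> contiguous layout of tp * [w, h]
x = torch.arange(w * h * tp, dtype=torch.float32).reshape(w, h * tp)
col_slices = [x[:, i * h:(i + 1) * h] for i in range(tp)]
y = torch.cat(col_slices, dim=0).contiguous().view(x.shape)
for i in range(tp):
    # destination rank i's flat chunk equals its column slice of the source
    assert torch.equal(y.view(tp, w, h)[i], x[:, i * h:(i + 1) * h])

# w1/w2 interleave case: source rows are [w1_0 .. w1_{tp-1}, w2_0 .. w2_{tp-1}]
gate = torch.arange(tp * w * h, dtype=torch.float32).reshape(tp * w, h)  # stands in for w1
up = gate + 100.0                                                        # stands in for w2
src = torch.cat([gate, up], dim=0)
row_offset = src.shape[0] // tp // 2
pairs = []
for i in range(tp):
    w1 = src[i * row_offset:(i + 1) * row_offset]
    w2 = src[(i + tp) * row_offset:(i + tp + 1) * row_offset]
    pairs.append(torch.cat([w1, w2], dim=0))
z = torch.cat(pairs, dim=0).contiguous().view(src.shape)
for i in range(tp):
    # destination rank i's flat chunk is its own [w1_i; w2_i] pair
    expected = torch.cat([gate[i * w:(i + 1) * w], up[i * w:(i + 1) * w]], dim=0)
    assert torch.equal(z.view(tp, 2 * w, h)[i], expected)

The same idea carries over to the routed-expert branches, with the slicing applied to the per-expert dims instead of the whole matrix.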