def sharded_state_dict()

in megatron_patch/model/mixtral/moe/experts.py [0:0]


    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """
        Maps local expert to global experts.
        The sharded_state_dict for the weight parts are compatible with the SequentialMLP,
        whereas the optimizer states are not due to the limitation from weight transposing.
        That is, for finetuning scenario, the checkpoint is compatible with the SequentialMLP.

        Args:
            prefix (str): prefix prepended to every key of the returned dict and to the
                checkpoint keys of the produced sharded tensors/objects.
            sharded_offsets (tuple): offsets contributed by enclosing modules; they are
                forwarded to every sharded tensor and their count is used as the number
                of prepended axes.
            metadata: accepted for interface compatibility; not used in this method.

        Returns:
            dict: one ``ShardedTensorFactory`` per entry of ``self.state_dict()`` (keyed
            ``f'{prefix}{name}'``), plus one sharded ``_extra_state`` object per global
            expert for both ``linear_fc1`` and ``linear_fc2``.

        Raises:
            NotImplementedError: if ``self.moe_extended_tp`` is enabled.
        """
        if self.moe_extended_tp:
            raise NotImplementedError(
                'Currently distributed checkpointing is not supported for moe_extended_tp'
            )

        sharded_state_dict = {}
        num_global_experts = (
            parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
        )
        # Global index of this rank's first local expert.
        local_expert_indices_offset = (
            parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
        )
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        tp_rank = parallel_state.get_tensor_model_parallel_rank()

        prepend_axis_num = len(sharded_offsets)
        # Replica id for the weight shards: only the data-modulo-expert-parallel rank
        # component is non-zero (TP placement is expressed through the offsets instead).
        replica_id = (
            0,
            0,
            parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
        )

        # Per-expert, per-TP-rank FFN width, derived from weight2's element count.
        local_ffn_dim_size = (
            self.weight2.numel() // self.num_local_experts // self.config.hidden_size
        )

        @torch.no_grad()
        def sh_ten_build_fn(
            key: str,
            t: torch.Tensor,
            replica_id: ReplicaId,
            flattened_range: Optional[slice],
            tp_axis: int,
            with_glu: bool,
        ):
            """Build the sharded representation of one grouped weight (or a flat slice).

            With ``flattened_range is None``, ``t`` is the full local weight. It is viewed as
            ``real_shape`` and transposed so the saved layout matches SequentialMLP.
            Otherwise ``t`` is a 1-D slice (e.g. an optimizer state) covering
            ``flattened_range`` of the flattened ``real_shape`` tensor, and is split per
            local expert (and, with GLU, per W/V half) without transposing.
            """
            # TODO: write a generic implementation to cover both cases with and without GLU
            if tp_axis == 1:
                # weight1
                if with_glu:
                    last_dim_size = local_ffn_dim_size * 2
                else:
                    last_dim_size = local_ffn_dim_size
                real_shape = (self.num_local_experts, self.config.hidden_size, last_dim_size)
            elif tp_axis == 0:
                # weight2
                real_shape = (self.num_local_experts, local_ffn_dim_size, self.config.hidden_size)
                assert not with_glu
            else:
                raise ValueError("tp_axis should be 0 or 1.")
            if flattened_range is None:
                # weights
                t = t.view(real_shape).transpose(-1, -2)
                # change tp_axis due to the transposing
                tp_axis = 1 - tp_axis
                if with_glu:
                    # The two GLU halves are saved as TP shards `tp_rank` and
                    # `tp_size + tp_rank` out of `2 * tp_size` along the (transposed) FFN axis.
                    local_tensors = torch.chunk(t, 2, -2)
                    sub_states = [
                        ShardedTensor.from_rank_offsets(
                            key,
                            local_tensors[0].contiguous(),
                            *sharded_offsets,
                            (
                                prepend_axis_num,
                                parallel_state.get_expert_model_parallel_rank(),
                                parallel_state.get_expert_model_parallel_world_size(),
                            ),
                            (prepend_axis_num + 1, tp_rank, tp_size * 2),
                            replica_id=replica_id,
                            prepend_axis_num=prepend_axis_num,
                        ),
                        ShardedTensor.from_rank_offsets(
                            key,
                            local_tensors[1].contiguous(),
                            *sharded_offsets,
                            (
                                prepend_axis_num,
                                parallel_state.get_expert_model_parallel_rank(),
                                parallel_state.get_expert_model_parallel_world_size(),
                            ),
                            (prepend_axis_num + 1, tp_size + tp_rank, tp_size * 2),
                            replica_id=replica_id,
                            prepend_axis_num=prepend_axis_num,
                        ),
                    ]
                else:
                    sub_states = ShardedTensor.from_rank_offsets(
                        key,
                        t.contiguous(),
                        *sharded_offsets,
                        (
                            prepend_axis_num,
                            parallel_state.get_expert_model_parallel_rank(),
                            parallel_state.get_expert_model_parallel_world_size(),
                        ),
                        (prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                    )
            else:
                # flattened optimizer states
                # the non-flattened weight shape is [local_expert_num, hidden_size, ffn_size]
                #
                # For the case without GLU, it is straightforward, we just need to split each
                # expert along the dim-0.
                #
                # For the case with GLU, we need to split the experts along dim-0 and split the
                # two tensors for GLU along dim-2.
                # To split along the non-first dim, we need to chunk the tensor into small pieces,
                # since they belong to different tensors and are interleaved in the flattened space.
                # Refer to the below sketch graph.
                # |................|           |........|........|
                # |............FFFF|           |........|....BBBB|
                # |FFFFFFFFFFFFFFFF|     ->    |AAAAAAAA|BBBBBBBB|
                # |FFFFFFFFFFFFFFFF|           |AAAAAAAA|BBBBBBBB|
                # |FF..............|           |AA......|........|
                # |................|           |........|........|
                #
                # But too many chunks have severe performance issues. We merge these chunks during
                # the save process along with some length information and recover them during the
                # load process.
                assert t.ndim == 1, (key, t.shape)
                if with_glu:
                    non_flat_local_shape = (1, self.config.hidden_size, local_ffn_dim_size)
                    # One chunk == one row (fixed input_dim_idx) of either the W or the V half.
                    chunk_numel = local_ffn_dim_size
                    sub_states = []
                    # Read cursor into the 1-D slice `t`.
                    start_pos = 0
                    for local_expert_idx in range(self.num_local_experts):
                        # Computed up front: every local expert gets an entry below, even when
                        # none of its rows intersect `flattened_range`, and that entry must carry
                        # this expert's own global index (not a stale or unbound one).
                        expert_global_idx = local_expert_indices_offset + local_expert_idx
                        first_glu_idx = -1
                        w_start_range = -1
                        v_start_range = -1
                        w_tensors = []
                        v_tensors = []
                        w_lens = []
                        v_lens = []
                        for input_dim_idx in range(self.config.hidden_size):
                            for glu_idx in range(2):
                                local_idx = (
                                    local_expert_idx * self.config.hidden_size * 2
                                    + input_dim_idx * 2
                                    + glu_idx
                                )
                                if (
                                    flattened_range.start < chunk_numel * (local_idx + 1)
                                    and flattened_range.stop > chunk_numel * local_idx
                                ):
                                    if first_glu_idx == -1:
                                        first_glu_idx = glu_idx
                                    # Slicing clamps, so overshooting len(t) here is harmless.
                                    end_pos = min(
                                        flattened_range.stop,
                                        chunk_numel * (local_idx + 1) - flattened_range.start,
                                    )
                                    local_tensor = t[start_pos:end_pos]
                                    # Portion of this chunk (row) covered by `t`.
                                    local_flattened_range = slice(
                                        max(0, flattened_range.start - chunk_numel * local_idx),
                                        min(
                                            chunk_numel,
                                            flattened_range.stop - chunk_numel * local_idx,
                                        ),
                                    )
                                    assert (
                                        len(local_tensor)
                                        == local_flattened_range.stop - local_flattened_range.start
                                    )
                                    start_pos += len(local_tensor)
                                    # Offset of this piece inside the per-expert W/V tensor of
                                    # shape `non_flat_local_shape`: all full rows before
                                    # `input_dim_idx` plus the offset within the current row.
                                    piece_start = (
                                        input_dim_idx * chunk_numel + local_flattened_range.start
                                    )
                                    if glu_idx == 0:
                                        w_tensors.append(local_tensor)
                                        w_lens.append(len(local_tensor))
                                        if w_start_range == -1:
                                            w_start_range = piece_start
                                    else:
                                        v_tensors.append(local_tensor)
                                        v_lens.append(len(local_tensor))
                                        if v_start_range == -1:
                                            v_start_range = piece_start
                        # The W pieces (and the V pieces) of one expert are contiguous in their
                        # own flattened space, so each can be saved as one flat shard; the
                        # lengths and `first_glu_idx` let sh_ten_merge_fn re-interleave them.
                        sub_states.append(
                            {
                                'w_tensors': ShardedTensor.from_rank_offsets_flat(
                                    key,
                                    (
                                        torch.cat(w_tensors, -1)
                                        if len(w_tensors) > 0
                                        else torch.Tensor()
                                    ),
                                    non_flat_local_shape,
                                    *sharded_offsets,
                                    (prepend_axis_num, expert_global_idx, num_global_experts),
                                    (prepend_axis_num + 1 + tp_axis, tp_rank, tp_size * 2),
                                    replica_id=replica_id,
                                    prepend_axis_num=prepend_axis_num,
                                    flattened_range=slice(
                                        w_start_range, w_start_range + sum(w_lens)
                                    ),
                                ),
                                'w_lens': LocalNonpersistentObject(w_lens),
                                'v_tensors': ShardedTensor.from_rank_offsets_flat(
                                    key,
                                    (
                                        torch.cat(v_tensors, -1)
                                        if len(v_tensors) > 0
                                        else torch.Tensor()
                                    ),
                                    non_flat_local_shape,
                                    *sharded_offsets,
                                    (prepend_axis_num, expert_global_idx, num_global_experts),
                                    (
                                        prepend_axis_num + 1 + tp_axis,
                                        tp_rank + tp_size,
                                        tp_size * 2,
                                    ),
                                    replica_id=replica_id,
                                    prepend_axis_num=prepend_axis_num,
                                    flattened_range=slice(
                                        v_start_range, v_start_range + sum(v_lens)
                                    ),
                                ),
                                'v_lens': LocalNonpersistentObject(v_lens),
                                'first_glu_idx': LocalNonpersistentObject(first_glu_idx),
                            }
                        )
                else:
                    non_flat_local_shape = (
                        real_shape[0] // self.num_local_experts,
                        *real_shape[1:],
                    )
                    # One chunk == one whole expert.
                    chunk_numel = local_ffn_dim_size * self.config.hidden_size
                    sub_states = []
                    start_pos = 0
                    for local_expert_idx in range(self.num_local_experts):
                        if (
                            flattened_range.start < chunk_numel * (local_expert_idx + 1)
                            and flattened_range.stop > chunk_numel * local_expert_idx
                        ):
                            # Slicing clamps, so overshooting len(t) here is harmless.
                            end_pos = min(
                                flattened_range.stop,
                                chunk_numel * (local_expert_idx + 1) - flattened_range.start,
                            )
                            local_tensor = t[start_pos:end_pos]
                            local_flattened_range = slice(
                                max(0, flattened_range.start - chunk_numel * local_expert_idx),
                                min(
                                    chunk_numel,
                                    flattened_range.stop - chunk_numel * local_expert_idx,
                                ),
                            )
                            assert (
                                len(local_tensor)
                                == local_flattened_range.stop - local_flattened_range.start
                            )
                            start_pos += len(local_tensor)
                            expert_global_idx = local_expert_indices_offset + local_expert_idx
                            sub_states.append(
                                ShardedTensor.from_rank_offsets_flat(
                                    key,
                                    local_tensor,
                                    non_flat_local_shape,
                                    *sharded_offsets,
                                    (prepend_axis_num, expert_global_idx, num_global_experts),
                                    (prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
                                    replica_id=replica_id,
                                    prepend_axis_num=prepend_axis_num,
                                    flattened_range=local_flattened_range,
                                )
                            )
            return sub_states

        @torch.no_grad()
        def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool):
            """Inverse of ``sh_ten_build_fn``: rebuild the local tensor from loaded parts.

            Handles three shapes of ``sub_state_dict``: a list of per-expert dicts
            (flattened with GLU), a list of 1-D tensors (flattened without GLU), or the
            loaded weight tensor(s) (a pair of GLU halves, or one tensor).
            """
            if tp_axis == 1:
                # weight1
                weight_shape = (self.config.hidden_size, -1)
            elif tp_axis == 0:
                # weight2
                weight_shape = (-1, self.config.hidden_size)
                assert not with_glu
            else:
                raise ValueError("tp_axis should be 0 or 1.")
            if isinstance(sub_state_dict, list) and isinstance(sub_state_dict[0], dict):
                # flattened tensor with glu
                res = []
                for local_expert_dict in sub_state_dict:
                    w_tensors = torch.split(
                        local_expert_dict['w_tensors'], local_expert_dict['w_lens']
                    )
                    v_tensors = torch.split(
                        local_expert_dict['v_tensors'], local_expert_dict['v_lens']
                    )
                    # Re-interleave W/V row pieces in the order they were read at save time.
                    first_glu_idx = local_expert_dict['first_glu_idx']
                    if first_glu_idx == 0:
                        res += [
                            x for x in itertools.chain(*itertools.zip_longest(w_tensors, v_tensors))
                        ]
                    else:
                        res += [
                            x for x in itertools.chain(*itertools.zip_longest(v_tensors, w_tensors))
                        ]
                return torch.cat(res)
            elif isinstance(sub_state_dict, list) and sub_state_dict[0].ndim == 1:
                # flattened tensor without glu
                return torch.cat(sub_state_dict)
            else:
                if with_glu:
                    sub_state_dict = torch.cat(sub_state_dict, -2)
                return sub_state_dict.transpose(-1, -2).reshape(weight_shape)

        state_dict = self.state_dict(prefix='', keep_vars=True)
        for name, tensor in state_dict.items():
            # 'weight1' maps to linear_fc1; every other entry (presumably 'weight2') is
            # mapped to linear_fc2.
            if name == 'weight1':
                tp_axis = 1
                with_glu = self.config.gated_linear_unit
                wkey = f'{prefix}experts.linear_fc1.weight'
            else:
                tp_axis = 0
                with_glu = False
                wkey = f'{prefix}experts.linear_fc2.weight'
            sharded_state_dict[f'{prefix}{name}'] = ShardedTensorFactory(
                wkey,
                tensor,
                partial(sh_ten_build_fn, tp_axis=tp_axis, with_glu=with_glu),
                partial(sh_ten_merge_fn, tp_axis=tp_axis, with_glu=with_glu),
                replica_id,
            )

        # The _extra_state objects are not TP-sharded, so the TP rank is part of the
        # replica id here (unlike the weight shards above).
        replica_id = (
            0,
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
        )
        # Add fake _extra_state to be compatible with SequentialMLP
        for expert_local_idx in range(self.num_local_experts):
            expert_global_idx = local_expert_indices_offset + expert_local_idx
            expert_sharded_offsets = (
                *sharded_offsets,
                (len(sharded_offsets), expert_global_idx, num_global_experts),
            )
            for mod in ['linear_fc1', 'linear_fc2']:
                sharded_state_dict[f'{prefix}expert{expert_global_idx}.{mod}._extra_state'] = (
                    make_sharded_object_for_checkpoint(
                        None,
                        f'{prefix}experts.{mod}._extra_state',
                        expert_sharded_offsets,
                        replica_id,
                    )
                )

        return sharded_state_dict