in megatron_patch/model/deepseek_v2/moe/experts.py [0:0]
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
Maps local experts to global experts.
The sharded_state_dict for the weight parts is compatible with the SequentialMLP,
whereas the optimizer states are not, due to the limitation from weight transposing.
That is, for the finetuning scenario, the checkpoint is compatible with the SequentialMLP.
"""
if self.moe_extended_tp:
raise NotImplementedError(
'Currently distributed checkpointing is not supported for moe_extended_tp'
)
sharded_state_dict = {}
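# Map this rank's local experts to their global indices: experts are laid out
# contiguously across expert-model-parallel ranks, so the global index of local
# expert i is local_expert_indices_offset + i.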
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_rank = parallel_state.get_tensor_model_parallel_rank()
prepend_axis_num = len(sharded_offsets)
replica_id = (
0,
0,
parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
)
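# Per-TP-rank FFN hidden size, inferred from weight2, whose flat buffer holds
# num_local_experts * local_ffn_dim_size * hidden_size elements.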
local_ffn_dim_size = (
self.weight2.numel() // self.num_local_experts // self.config.hidden_size
)
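# Build function for the ShardedTensorFactory below. It is called once per raw
# GroupedMLP parameter (weight1 / weight2): when `flattened_range` is None, the
# full weight is reshaped and transposed into the SequentialMLP-style layout;
# otherwise `t` is a 1-D shard of a flattened optimizer state and is split per
# expert (and, for GLU, per gate/up half).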
@torch.no_grad()
def sh_ten_build_fn(
key: str,
t: torch.Tensor,
replica_id: ReplicaId,
flattened_range: Optional[slice],
tp_axis: int,
with_glu: bool,
):
# TODO: write a generic implementation to cover both cases with and without GLU
if tp_axis == 1:
# weight1
if with_glu:
last_dim_size = local_ffn_dim_size * 2
else:
last_dim_size = local_ffn_dim_size
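# Logical (non-flat) view of weight1: [num_local_experts, hidden_size, fc1_output],
# where the fc1 output dimension doubles when GLU gating is enabled.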
real_shape = (self.num_local_experts, self.config.hidden_size, last_dim_size)
elif tp_axis == 0:
# weight2
real_shape = (self.num_local_experts, local_ffn_dim_size, self.config.hidden_size)
assert with_glu == False
else:
raise ValueError("tp_axis should be 0 or 1.")
if flattened_range is None:
# weights
t = t.view(real_shape).transpose(-1, -2)
# change tp_axis due to the transposing
tp_axis = 1 - tp_axis
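# With GLU, the two halves of fc1 are written as two shards in a virtual TP grid
# of size 2 * tp_size (offsets tp_rank and tp_size + tp_rank); per the docstring,
# this keeps the weight checkpoint compatible with SequentialMLP.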
if with_glu:
local_tensors = torch.chunk(t, 2, -2)
sub_states = [
ShardedTensor.from_rank_offsets(
key,
local_tensors[0].contiguous(),
*sharded_offsets,
(
prepend_axis_num,
parallel_state.get_expert_model_parallel_rank(),
parallel_state.get_expert_model_parallel_world_size(),
),
(prepend_axis_num + 1, tp_rank, tp_size * 2),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
ShardedTensor.from_rank_offsets(
key,
local_tensors[1].contiguous(),
*sharded_offsets,
(
prepend_axis_num,
parallel_state.get_expert_model_parallel_rank(),
parallel_state.get_expert_model_parallel_world_size(),
),
(prepend_axis_num + 1, tp_size + tp_rank, tp_size * 2),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
]
else:
sub_states = ShardedTensor.from_rank_offsets(
key,
t.contiguous(),
*sharded_offsets,
(
prepend_axis_num,
parallel_state.get_expert_model_parallel_rank(),
parallel_state.get_expert_model_parallel_world_size(),
),
(prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
)
else:
# flattened optimizer states
# the non-flattened weight shape is [local_expert_num, hidden_size, ffn_size]
#
# For the case without GLU, it is straightforward: we just need to split each
# expert along dim-0.
#
# For the case with GLU, we need to split the experts along dim-0 and split the
# two tensors for GLU along dim-2.
# To split along a non-first dim, we need to chunk the tensor into small pieces,
# since they belong to different tensors and are interleaved in the flattened space.
# Refer to the sketch below.
# |................| |........|........|
# |............FFFF| |........|....BBBB|
# |FFFFFFFFFFFFFFFF| -> |AAAAAAAA|BBBBBBBB|
# |FFFFFFFFFFFFFFFF| |AAAAAAAA|BBBBBBBB|
# |FF..............| |AA......|........|
# |................| |........|........|
#
# However, too many chunks cause severe performance issues. We therefore merge these
# chunks during the save process, along with some length information, and recover
# them during the load process.
assert t.ndim == 1, (key, t.shape)
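# GLU case: in the flattened space, weight1 is a sequence of chunks of length
# local_ffn_dim_size ordered as (expert, hidden row, GLU half). Collect the chunks
# that intersect `flattened_range`, keeping the two GLU halves ("w" and "v") in
# separate lists so each can become its own flat ShardedTensor.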
if with_glu:
non_flat_local_shape = (1, self.config.hidden_size, local_ffn_dim_size)
chunk_numel = local_ffn_dim_size
sub_states = []
start_pos = 0
for local_expert_idx in range(self.num_local_experts):
first_glu_idx = -1
w_start_range = -1
v_start_range = -1
w_tensors = []
v_tensors = []
w_lens = []
v_lens = []
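# local_idx enumerates the fixed-size chunks in flattened order; a chunk is kept
# only if it overlaps `flattened_range`.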
for input_dim_idx in range(self.config.hidden_size):
for glu_idx in range(2):
local_idx = (
local_expert_idx * self.config.hidden_size * 2
+ input_dim_idx * 2
+ glu_idx
)
if (
flattened_range.start < chunk_numel * (local_idx + 1)
and flattened_range.stop > chunk_numel * local_idx
):
if first_glu_idx == -1:
first_glu_idx = glu_idx
end_pos = min(
flattened_range.stop,
chunk_numel * (local_idx + 1) - flattened_range.start,
)
local_tensor = t[start_pos:end_pos]
local_flattened_range = slice(
max(0, flattened_range.start - chunk_numel * local_idx),
min(
chunk_numel,
flattened_range.stop - chunk_numel * local_idx,
),
)
assert (
len(local_tensor)
== local_flattened_range.stop - local_flattened_range.start
)
start_pos += len(local_tensor)
expert_global_idx = (
local_expert_indices_offset + local_expert_idx
)
if glu_idx == 0:
w_tensors.append(local_tensor)
w_lens.append(len(local_tensor))
if w_start_range == -1:
w_start_range = max(
0, flattened_range.start - chunk_numel * local_idx
)
else:
v_tensors.append(local_tensor)
v_lens.append(len(local_tensor))
if v_start_range == -1:
v_start_range = max(
0, flattened_range.start - chunk_numel * local_idx
)
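# Pack this expert's pieces into two flat ShardedTensors, together with the
# per-chunk lengths (and which GLU half came first) needed by sh_ten_merge_fn
# to re-interleave them at load time.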
sub_states.append(
{
'w_tensors': ShardedTensor.from_rank_offsets_flat(
key,
(
torch.cat(w_tensors, -1)
if len(w_tensors) > 0
else torch.Tensor()
),
non_flat_local_shape,
*sharded_offsets,
(prepend_axis_num, expert_global_idx, num_global_experts),
(prepend_axis_num + 1 + tp_axis, tp_rank, tp_size * 2),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=slice(
w_start_range, w_start_range + sum(w_lens)
),
),
'w_lens': LocalNonpersistentObject(w_lens),
'v_tensors': ShardedTensor.from_rank_offsets_flat(
key,
(
torch.cat(v_tensors, -1)
if len(v_tensors) > 0
else torch.Tensor()
),
non_flat_local_shape,
*sharded_offsets,
(prepend_axis_num, expert_global_idx, num_global_experts),
(
prepend_axis_num + 1 + tp_axis,
tp_rank + tp_size,
tp_size * 2,
),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=slice(
v_start_range, v_start_range + sum(v_lens)
),
),
'v_lens': LocalNonpersistentObject(v_lens),
'first_glu_idx': LocalNonpersistentObject(first_glu_idx),
}
)
else:
non_flat_local_shape = (
real_shape[0] // self.num_local_experts,
*real_shape[1:],
)
chunk_numel = local_ffn_dim_size * self.config.hidden_size
sub_states = []
start_pos = 0
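# Without GLU, each local expert owns one contiguous chunk of the flattened
# buffer, so every expert that overlaps `flattened_range` becomes a single flat
# ShardedTensor.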
for local_expert_idx in range(self.num_local_experts):
if (
flattened_range.start < chunk_numel * (local_expert_idx + 1)
and flattened_range.stop > chunk_numel * local_expert_idx
):
end_pos = min(
flattened_range.stop,
chunk_numel * (local_expert_idx + 1) - flattened_range.start,
)
local_tensor = t[start_pos:end_pos]
local_flattened_range = slice(
max(0, flattened_range.start - chunk_numel * local_expert_idx),
min(
chunk_numel,
flattened_range.stop - chunk_numel * local_expert_idx,
),
)
assert (
len(local_tensor)
== local_flattened_range.stop - local_flattened_range.start
)
start_pos += len(local_tensor)
expert_global_idx = local_expert_indices_offset + local_expert_idx
sub_states.append(
ShardedTensor.from_rank_offsets_flat(
key,
local_tensor,
non_flat_local_shape,
*sharded_offsets,
(prepend_axis_num, expert_global_idx, num_global_experts),
(prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=local_flattened_range,
)
)
return sub_states
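# Merge function for the ShardedTensorFactory: the inverse of sh_ten_build_fn.
# For flattened GLU states it re-interleaves the loaded "w"/"v" pieces using the
# stored lengths; for flattened non-GLU states it concatenates the shards; for
# full weights it concatenates the GLU halves (if any), undoes the transpose, and
# reshapes back to the 2-D GroupedMLP weight layout.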
@torch.no_grad()
def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool):
if tp_axis == 1:
# weight1
weight_shape = (self.config.hidden_size, -1)
elif tp_axis == 0:
# weight2
weight_shape = (-1, self.config.hidden_size)
assert with_glu == False
else:
raise ValueError("tp_axis should be 0 or 1.")
if isinstance(sub_state_dict, list) and isinstance(sub_state_dict[0], dict):
# flattened tensor with glu
res = []
for local_expert_dict in sub_state_dict:
w_tensors = torch.split(
local_expert_dict['w_tensors'], local_expert_dict['w_lens']
)
v_tensors = torch.split(
local_expert_dict['v_tensors'], local_expert_dict['v_lens']
)
first_glu_idx = local_expert_dict['first_glu_idx']
if first_glu_idx == 0:
res += [
x for x in itertools.chain(*itertools.zip_longest(w_tensors, v_tensors))
]
else:
res += [
x for x in itertools.chain(*itertools.zip_longest(v_tensors, w_tensors))
]
return torch.cat(res)
elif isinstance(sub_state_dict, list) and sub_state_dict[0].ndim == 1:
# flattened tensor without glu
return torch.cat(sub_state_dict)
else:
if with_glu:
sub_state_dict = torch.cat(sub_state_dict, -2)
return sub_state_dict.transpose(-1, -2).reshape(weight_shape)
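# Wrap each raw GroupedMLP parameter in a ShardedTensorFactory so that it is
# checkpointed under the SequentialMLP-style key (experts.linear_fc1.weight or
# experts.linear_fc2.weight) using the build/merge functions above.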
state_dict = self.state_dict(prefix='', keep_vars=True)
for name, tensor in state_dict.items():
if name == 'weight1':
tp_axis = 1
with_glu = self.config.gated_linear_unit
wkey = f'{prefix}experts.linear_fc1.weight'
else:
tp_axis = 0
with_glu = False
wkey = f'{prefix}experts.linear_fc2.weight'
sharded_state_dict[f'{prefix}{name}'] = ShardedTensorFactory(
wkey,
tensor,
partial(sh_ten_build_fn, tp_axis=tp_axis, with_glu=with_glu),
partial(sh_ten_merge_fn, tp_axis=tp_axis, with_glu=with_glu),
replica_id,
)
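# For the _extra_state placeholders below, include the TP rank in replica_id,
# marking the object as replicated across tensor-parallel (and expert-data-parallel)
# ranks so it is deduplicated at save time.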
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
)
# Add fake _extra_state to be compatible with SequentialMLP
for expert_local_idx in range(self.num_local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
for mod in ['linear_fc1', 'linear_fc2']:
sharded_state_dict[f'{prefix}expert{expert_global_idx}.{mod}._extra_state'] = (
make_sharded_object_for_checkpoint(
None,
f'{prefix}experts.{mod}._extra_state',
expert_sharded_offsets,
replica_id,
)
)
return sharded_state_dict