in modules/SwissArmyTransformer/sat/mpu/layers.py [0:0]
def partition(self, new_model_parallel_size=None, full_weight=None):
    """Split the layer's weight (and bias) into per-rank shards along dim 0.

    The rows of the weight are treated as consecutive groups whose sizes are
    proportional to ``self.stride``. Every rank receives the slice at its own
    index from each group, and those slices are concatenated in group order.

    Args:
        new_model_parallel_size: Number of shards to produce. When ``None``,
            the result of ``get_model_parallel_world_size()`` is used.
        full_weight: Weight tensor to split. When ``None``, ``self.weight``
            is split and ``self.bias`` is split alongside it.

    Returns:
        A list with one weight tensor per rank when ``full_weight`` was
        supplied; otherwise the tuple ``(new_weights, new_biases)``.
        ``new_biases`` is an empty list when ``self.bias`` is ``None`` or has
        no elements.
    """
    assert self.output_size_per_partition == self.output_size or full_weight is not None
    splitting_own_weight = full_weight is None
    if splitting_own_weight:
        full_weight = self.weight
    if new_model_parallel_size is None:
        new_model_parallel_size = get_model_parallel_world_size()
    rows_per_rank = divide(self.output_size, new_model_parallel_size)

    # An int stride means that many equally sized groups (e.g. q, k, v);
    # otherwise the entries are the relative sizes of the groups.
    ratios = self.stride if not isinstance(self.stride, int) else [1] * self.stride
    ratio_total = sum(ratios)
    assert full_weight.shape[0] % ratio_total == 0, 'cannot divide weight evenly'
    unit = full_weight.shape[0] // ratio_total

    # Row boundaries (start, stop) of each group; shared by weight and bias.
    bounds, start = [], 0
    for ratio in ratios:
        stop = start + unit * ratio
        bounds.append((start, stop))
        start = stop

    weight_groups = [full_weight[lo:hi, :].detach() for lo, hi in bounds]
    # The bias is only split when the layer's own weight is being split.
    split_bias = (
        splitting_own_weight
        and self.bias is not None
        and self.bias.numel() != 0
    )
    if split_bias:
        bias_groups = [self.bias[lo:hi].detach() for lo, hi in bounds]

    def _gather_rank_slices(groups, rank):
        # Take this rank's slice from every group and join them along dim 0.
        pieces = []
        for group in groups:
            step = group.shape[0] // new_model_parallel_size
            pieces.append(group[step * rank:step * (rank + 1)])
        return torch.cat(pieces, dim=0).contiguous()

    new_weights, new_biases = [], []
    for rank in range(new_model_parallel_size):
        weight_shard = _gather_rank_slices(weight_groups, rank).view(
            rows_per_rank, self.input_size
        )
        new_weights.append(torch.clone(weight_shard).detach())
        if split_bias:
            bias_shard = _gather_rank_slices(bias_groups, rank).view(rows_per_rank)
            new_biases.append(torch.clone(bias_shard).detach())

    if splitting_own_weight:
        return new_weights, new_biases
    return new_weights