# Method: partition()
#
# Extracted from modules/SwissArmyTransformer/sat/mpu/layers.py


    def partition(self, new_model_parallel_size=None, full_weight=None):
        """Split the full (unpartitioned) weight along the output dimension
        into per-rank shards for model parallelism.

        The weight rows are laid out as strided groups
        ``[stride0... | stride1... | stride2...]`` (e.g. fused q/k/v), so each
        rank must take its slice from *every* group rather than one contiguous
        chunk.

        Args:
            new_model_parallel_size: target number of model-parallel ranks.
                Defaults to the current model-parallel world size.
            full_weight: the unpartitioned weight, shape
                ``(output_size, input_size)``. If ``None``, ``self.weight`` is
                used (it must then hold the full weight, which the leading
                assert checks) and the bias is partitioned and returned too.

        Returns:
            ``new_weights`` (list of per-rank weight tensors) when
            ``full_weight`` was supplied by the caller; otherwise the tuple
            ``(new_weights, new_biases)``.
        """
        assert self.output_size_per_partition == self.output_size or full_weight is not None
        # Partitioning our own weight implies we also partition (and return) the bias.
        include_bias = full_weight is None
        if full_weight is None:
            full_weight = self.weight
        if new_model_parallel_size is None:
            new_model_parallel_size = get_model_parallel_world_size()
        mp_size = new_model_parallel_size
        output_size_per_partition = divide(self.output_size, mp_size)

        # int means `stride` equal-sized groups (e.g. q/k/v); a list gives ratios.
        strides = [1] * self.stride if isinstance(self.stride, int) else self.stride

        def _split_by_strides(tensor):
            # Cut `tensor` along dim 0 into the strided groups [stride0 | stride1 | ...].
            assert tensor.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
            unit = tensor.shape[0] // sum(strides)
            parts, offset = [], 0
            for ratio in strides:
                parts.append(tensor[offset:offset + unit * ratio].detach())
                offset += unit * ratio
            return parts

        strided_weights = _split_by_strides(full_weight)
        has_bias = include_bias and self.bias is not None and self.bias.numel() != 0
        strided_biases = _split_by_strides(self.bias) if has_bias else None

        new_weights, new_biases = [], []
        for mp_rank in range(mp_size):
            rank_parts = []
            for part in strided_weights:
                # Each strided group must itself split evenly across ranks;
                # otherwise the floor division below would silently drop rows
                # and the final view() would fail with an obscure size error.
                assert part.shape[0] % mp_size == 0, \
                    'cannot divide strided group evenly across model-parallel ranks'
                chunk = part.shape[0] // mp_size
                rank_parts.append(part[chunk * mp_rank:chunk * (mp_rank + 1)])
            new_weight = torch.cat(rank_parts, dim=0).contiguous().view(
                output_size_per_partition, self.input_size)
            new_weights.append(new_weight.detach().clone())
            if has_bias:
                bias_parts = []
                for part in strided_biases:
                    chunk = part.shape[0] // mp_size
                    bias_parts.append(part[chunk * mp_rank:chunk * (mp_rank + 1)])
                new_bias = torch.cat(bias_parts, dim=0).contiguous().view(
                    output_size_per_partition)
                new_biases.append(new_bias.detach().clone())
        if include_bias:
            return new_weights, new_biases
        return new_weights