def matmul()

in crypten/cuda/cuda_tensor.py [0:0]


    def matmul(x, y, *args, **kwargs):
        """Multiply ``x`` by ``y`` via a blockwise float64 ``torch.matmul``.

        Both operands are passed through ``CUDALongTensor.__encode_as_fp64``,
        expanded along dim 0 so their slices can be cross-multiplied, multiplied
        with a single ``torch.matmul`` call, and the result is handed to
        ``CUDALongTensor.__decode_as_int64``.

        Args:
            x: Left operand. A 1-D input is viewed as a single row and the
                temporary row axis is squeezed out of the result.
            y: Right operand. A 1-D input is viewed as a single column and the
                temporary column axis is squeezed out of the result.
            *args: Extra positional arguments forwarded to ``torch.matmul``.
            **kwargs: Extra keyword arguments forwarded to ``torch.matmul``.

        Returns:
            The value produced by ``CUDALongTensor.__decode_as_int64`` for the
            encoded product.
        """
        # Dot products of 256 or more elements use a fourth block so that the
        # accumulated sum does not overflow; shorter ones use three.
        num_blocks = 4 if x.size(-1) >= 256 else 3

        # Promote vectors to matrices: x becomes a row (axis prepended), y
        # becomes a column (axis appended). The extra axis is removed below.
        x_was_vector = x.dim() == 1
        if x_was_vector:
            x = x.view(1, x.shape[0])
        y_was_vector = y.dim() == 1
        if y_was_vector:
            y = y.view(y.shape[0], 1)

        lhs = CUDALongTensor.__encode_as_fp64(x, num_blocks).data
        rhs = CUDALongTensor.__encode_as_fp64(y, num_blocks).data

        # Tile lhs as a whole along dim 0 and repeat each rhs slice along
        # dim 0, so slices line up for cross multiplication.
        # NOTE(review): presumably dim 0 of the encoded tensors indexes the
        # blocks -- confirm against __encode_as_fp64.
        trailing_ones = (1,) * (lhs.dim() - 1)
        lhs_span = lhs.repeat(num_blocks, *trailing_ones)
        rhs_span = torch.repeat_interleave(rhs, repeats=num_blocks, dim=0)

        # Broadcasting: insert singleton axes at position 1 of the lower-rank
        # operand until both operands have the same number of dimensions.
        rank_gap = lhs_span.ndim - rhs_span.ndim
        lower_rank = rhs_span if rank_gap > 0 else lhs_span
        for _ in range(abs(rank_gap)):
            lower_rank.unsqueeze_(1)

        product = torch.matmul(lhs_span, rhs_span, *args, **kwargs)

        # Drop the axes that were added for 1-D operands.
        if x_was_vector:
            product.squeeze_(-2)
        if y_was_vector:
            product.squeeze_(-1)

        return CUDALongTensor.__decode_as_int64(product, num_blocks)