def _quantize_nn_spec()

in coremltools/models/neural_network/quantization_utils.py
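
Internal helper that walks each layer of a NeuralNetwork spec and quantizes its weight fields in place, recursing into the sub-networks of branch and loop layers. A usage sketch follows the listing.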


def _quantize_nn_spec(nn_spec, nbits, qm, **kwargs):
    """ Quantize weights in NeuralNetwork type mlmodel specifications.
    """
    selector = kwargs.get("selector", QuantizedLayerSelector())

    if qm not in _SUPPORTED_QUANTIZATION_MODES:
        raise Exception("Quantization mode {} not supported".format(qm))

    if qm != _QUANTIZATION_MODE_DEQUANTIZE:
        if nbits is None:
            raise Exception('Missing argument "nbits"')
        if not (0 < nbits <= 8 or nbits == 16):
            raise Exception(
                "Only half-precision (16-bit) and 1 to 8-bit quantization is supported"
            )

    if qm == _QUANTIZATION_MODE_LINEAR_SYMMETRIC and nbits != 8:
        raise Exception("Symmetric quantization is only applicable for 8-bit linear")

    layers = nn_spec.layers

    # Perform optimization step
    if nbits is not None and nbits < 16 and qm != _QUANTIZATION_MODE_DEQUANTIZE:
        print("Optimizing Neural Network before Quantization:")
        _optimize_nn(layers)
        print("Finished optimizing network. Quantizing neural network..")

    # Quantize each layer
    for layer in layers:
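        # Each layer is a protobuf oneof; WhichOneof("layer") returns the name
        # of the populated layer-type field.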
        layer_type = layer.WhichOneof("layer")
        if not selector.do_quantize(layer):
            continue
        print("Quantizing layer {} of type {}".format(layer.name, layer_type))

        # Convolution
        if layer_type == "convolution":
            output_channels = layer.convolution.outputChannels
            kernel_channels = layer.convolution.kernelChannels
            kernel_height = layer.convolution.kernelSize[0]
            kernel_width = layer.convolution.kernelSize[1]
            groups = layer.convolution.nGroups
            counts = output_channels * kernel_channels * kernel_height * kernel_width
            has_bias = layer.convolution.hasBias
            if layer.convolution.isDeconvolution:
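                # Deconvolution weights are laid out as
                # (kernelChannels, outputChannels // nGroups, H, W), so the
                # per-output-channel axis passed below is 1 instead of the
                # default 0 used for regular convolution.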
                shape = (
                    kernel_channels,
                    output_channels // groups,
                    kernel_height,
                    kernel_width,
                )
                _quantize_wp_field(
                    layer.convolution.weights, nbits, qm, shape, axis=1, **kwargs
                )
            else:
                shape = (output_channels, kernel_channels, kernel_height, kernel_width)
                _quantize_wp_field(
                    layer.convolution.weights, nbits, qm, shape, **kwargs
                )

            if has_bias and selector.do_quantize(layer, weight_param="bias"):
                _quantize_wp_field(
                    layer.convolution.bias,
                    nbits,
                    qm,
                    shape=(output_channels,),
                    **kwargs
                )

        # Batchnorm
        elif layer_type == "batchnorm":
            nw = layer.batchnorm.channels
            _quantize_wp_field(layer.batchnorm.gamma, nbits, qm, shape=(nw,), **kwargs)
            _quantize_wp_field(layer.batchnorm.beta, nbits, qm, shape=(nw,), **kwargs)
            _quantize_wp_field(layer.batchnorm.mean, nbits, qm, shape=(nw,), **kwargs)
            _quantize_wp_field(
                layer.batchnorm.variance, nbits, qm, shape=(nw,), **kwargs
            )

        # InnerProduct
        elif layer_type == "innerProduct":
            output_channels = layer.innerProduct.outputChannels
            input_channels = layer.innerProduct.inputChannels
            _quantize_wp_field(
                layer.innerProduct.weights,
                nbits,
                qm,
                shape=(output_channels, input_channels),
                **kwargs
            )
            has_bias = layer.innerProduct.hasBias
            if has_bias and selector.do_quantize(layer, weight_param="bias"):
                _quantize_wp_field(
                    layer.innerProduct.bias,
                    nbits,
                    qm,
                    shape=(output_channels,),
                    **kwargs
                )

        # BatchedMatmul
        elif layer_type == "batchedMatmul":
            x1 = layer.batchedMatmul.weightMatrixFirstDimension
            x2 = layer.batchedMatmul.weightMatrixSecondDimension
            _quantize_wp_field(
                layer.batchedMatmul.weights, nbits, qm, shape=(x2, x1), **kwargs
            )
            has_bias = layer.batchedMatmul.hasBias
            if has_bias and selector.do_quantize(layer, weight_param="bias"):
                _quantize_wp_field(
                    layer.batchedMatmul.bias, nbits, qm, shape=(x2,), **kwargs
                )

        # Embedding layer
        elif layer_type == "embedding":
            output_channels = layer.embedding.outputChannels
            input_channels = layer.embedding.inputDim
            _quantize_wp_field(
                layer.embedding.weights,
                nbits,
                qm,
                shape=(output_channels, input_channels),
                **kwargs
            )
            if layer.embedding.hasBias:
                _quantize_wp_field(
                    layer.embedding.bias, nbits, qm, shape=(output_channels,), **kwargs
                )

        # Embedding ND layer
        elif layer_type == "embeddingND":
            output_channels = layer.embeddingND.embeddingSize
            input_channels = layer.embeddingND.vocabSize
            _quantize_wp_field(
                layer.embeddingND.weights,
                nbits,
                qm,
                shape=(output_channels, input_channels),
                **kwargs
            )
            if layer.embeddingND.hasBias:
                _quantize_wp_field(
                    layer.embeddingND.bias,
                    nbits,
                    qm,
                    shape=(output_channels,),
                    **kwargs
                )

        # Scale layer
        elif layer_type == "scale":
            nw = _np.prod(layer.scale.shapeScale)
            _quantize_wp_field(layer.scale.scale, nbits, qm, shape=(nw,), **kwargs)
            if layer.scale.hasBias:
                nw = _np.prod(layer.scale.shapeBias)
                _quantize_wp_field(layer.scale.bias, nbits, qm, shape=(nw,), **kwargs)

        # Bias layer
        elif layer_type == "bias":
            nw = _np.prod(layer.bias.shape)
            _quantize_wp_field(layer.bias.bias, nbits, qm, shape=(nw,), **kwargs)

        # LoadConstant layer
        elif layer_type == "loadConstant":
            nw = _np.prod(layer.loadConstant.shape)
            _quantize_wp_field(
                layer.loadConstant.data, nbits, qm, shape=(nw,), **kwargs
            )

        # Simple Recurrent
        elif layer_type == "simpleRecurrent":
            i_size = layer.simpleRecurrent.inputVectorSize
            o_size = layer.simpleRecurrent.outputVectorSize
            _quantize_wp_field(
                layer.simpleRecurrent.weightMatrix,
                nbits,
                qm,
                shape=(o_size, i_size),
                **kwargs
            )
            _quantize_wp_field(
                layer.simpleRecurrent.recursionMatrix,
                nbits,
                qm,
                shape=(o_size, o_size),
                **kwargs
            )
            if layer.simpleRecurrent.hasBiasVector:
                _quantize_wp_field(
                    layer.simpleRecurrent.biasVector,
                    nbits,
                    qm,
                    shape=(o_size,),
                    **kwargs
                )

        # GRU
        elif layer_type == "gru":
            i_size = layer.gru.inputVectorSize
            o_size = layer.gru.outputVectorSize
            # Weight Matrix
            _quantize_wp_field(
                layer.gru.updateGateWeightMatrix,
                nbits,
                qm,
                shape=(o_size, i_size),
                **kwargs
            )
            _quantize_wp_field(
                layer.gru.resetGateWeightMatrix,
                nbits,
                qm,
                shape=(o_size, i_size),
                **kwargs
            )
            _quantize_wp_field(
                layer.gru.outputGateWeightMatrix,
                nbits,
                qm,
                shape=(o_size, i_size),
                **kwargs
            )
            # Recursion Weights
            _quantize_wp_field(
                layer.gru.updateGateRecursionMatrix,
                nbits,
                qm,
                shape=(o_size, o_size),
                **kwargs
            )
            _quantize_wp_field(
                layer.gru.resetGateRecursionMatrix,
                nbits,
                qm,
                shape=(o_size, o_size),
                **kwargs
            )
            _quantize_wp_field(
                layer.gru.outputGateRecursionMatrix,
                nbits,
                qm,
                shape=(o_size, o_size),
                **kwargs
            )
            # Bias
            if layer.gru.hasBiasVectors:
                _quantize_wp_field(
                    layer.gru.updateGateBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )
                _quantize_wp_field(
                    layer.gru.resetGateBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )
                _quantize_wp_field(
                    layer.gru.outputGateBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )

        # LSTM Layers
        elif layer_type in ["uniDirectionalLSTM", "biDirectionalLSTM"]:

            def _lstmwp_to_fp16_lstmwp(
                lstm_wp, nbits, qm, i_size, o_size, has_peephole=True
            ):
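                # Despite the fp16-specific name, this helper handles any
                # supported nbits/mode: it quantizes the four gate weight
                # matrices, four recursion matrices, four bias vectors, and
                # (optionally) the three peephole vectors.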
                assert lstm_wp
                _quantize_wp_field(
                    lstm_wp.inputGateWeightMatrix,
                    nbits,
                    qm,
                    shape=(o_size, i_size),
                    **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.forgetGateWeightMatrix,
                    nbits,
                    qm,
                    shape=(o_size, i_size),
                    **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.blockInputWeightMatrix,
                    nbits,
                    qm,
                    shape=(o_size, i_size),
                    **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.outputGateWeightMatrix,
                    nbits,
                    qm,
                    shape=(o_size, i_size),
                    **kwargs
                )

                _quantize_wp_field(
                    lstm_wp.inputGateRecursionMatrix,
                    nbits,
                    qm,
                    shape=(o_size, o_size),
                    **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.forgetGateRecursionMatrix,
                    nbits,
                    qm,
                    shape=(o_size, o_size),
                    **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.blockInputRecursionMatrix,
                    nbits,
                    qm,
                    shape=(o_size, o_size),
                    **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.outputGateRecursionMatrix,
                    nbits,
                    qm,
                    shape=(o_size, o_size),
                    **kwargs
                )

                _quantize_wp_field(
                    lstm_wp.inputGateBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.forgetGateBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.blockInputBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )
                _quantize_wp_field(
                    lstm_wp.outputGateBiasVector, nbits, qm, shape=(o_size,), **kwargs
                )

                if has_peephole:
                    _quantize_wp_field(
                        lstm_wp.inputGatePeepholeVector,
                        nbits,
                        qm,
                        shape=(o_size,),
                        **kwargs
                    )
                    _quantize_wp_field(
                        lstm_wp.forgetGatePeepholeVector,
                        nbits,
                        qm,
                        shape=(o_size,),
                        **kwargs
                    )
                    _quantize_wp_field(
                        lstm_wp.outputGatePeepholeVector,
                        nbits,
                        qm,
                        shape=(o_size,),
                        **kwargs
                    )

            if layer_type == "uniDirectionalLSTM":
                _lstmwp_to_fp16_lstmwp(
                    lstm_wp=layer.uniDirectionalLSTM.weightParams,
                    nbits=nbits,
                    qm=qm,
                    i_size=layer.uniDirectionalLSTM.inputVectorSize,
                    o_size=layer.uniDirectionalLSTM.outputVectorSize,
                    has_peephole=layer.uniDirectionalLSTM.params.hasPeepholeVectors,
                )

            elif layer_type == "biDirectionalLSTM":
                for lstm_wp in layer.biDirectionalLSTM.weightParams:
                    _lstmwp_to_fp16_lstmwp(
                        lstm_wp=lstm_wp,
                        nbits=nbits,
                        qm=qm,
                        i_size=layer.biDirectionalLSTM.inputVectorSize,
                        o_size=layer.biDirectionalLSTM.outputVectorSize,
                        has_peephole=layer.biDirectionalLSTM.params.hasPeepholeVectors,
                    )

        elif layer_type == "custom":
            print(
                "Skipping custom layer {}. Weights for this layer need to "
                "be converted manually".format(layer.name)
            )
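        # Control-flow layers carry nested NeuralNetwork specs; recurse so
        # their weights are quantized with the same settings.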
        elif layer_type == "branch":
            _quantize_nn_spec(layer.branch.ifBranch, nbits, qm, **kwargs)
            _quantize_nn_spec(layer.branch.elseBranch, nbits, qm, **kwargs)
        elif layer_type == "loop":
            _quantize_nn_spec(layer.loop.conditionNetwork, nbits, qm, **kwargs)
            _quantize_nn_spec(layer.loop.bodyNetwork, nbits, qm, **kwargs)
        else:
            raise Exception("Unknown layer " + layer_type + " to be quantized")