DetectedWeightType DetectWeightType()

in libraries/passes/src/DetectLowPrecisionConvolutionTransformation.cpp [64:160]


        // Classifies the weights of a convolutional layer node by inspecting the distinct values
        // found in each filter.
        //
        // Every filter is classified on its own from the set of unique weight values it contains:
        //   - exactly two values {-1, 1}          -> Binary
        //   - exactly two values {-m, m}, m != 1  -> SignedMean (the xnor case: each filter is mean or -mean)
        //   - exactly three values {-1, 0, 1}     -> Ternary
        //   - anything else                       -> FullPrecision
        // All filters must agree on the classification, otherwise the result is FullPrecision.
        // Values are compared exactly (operator==); no tolerance is applied.
        //
        // node:        the convolutional layer node whose weights are examined.
        // transformer: not used by this function.
        //
        // Returns the weight type shared by all filters, FullPrecision when the layer has a single
        // input channel or when the filters disagree or match no low-precision pattern, and
        // Unknown when the layer reports no filters at all.
        DetectedWeightType DetectWeightType(const nodes::ConvolutionalLayerNode<ValueType>* node, model::ModelTransformer& transformer)
        {
            // Bind by reference: the layer and its weights are only read here, so copying them
            // (as a plain `auto` would) is wasted work.
            const auto& layer = node->GetLayer();
            const auto& weights = layer.GetWeights();

            // Skip convolutions that have only a single input channel e.g. in the spatial portion
            // of depthwise separable convolutions
            const size_t numChannels = weights.NumChannels();
            if (numChannels == 1)
            {
                return DetectedWeightType::FullPrecision;
            }

            const int numFilters = layer.GetOutputShape().NumChannels();

            // The receptive field is used for both the row and column extent of a filter.
            const size_t filterSize = layer.GetConvolutionalParameters().receptiveField;

            // Binary / ternary weights have at most this many distinct values.
            const size_t maxUniqueValues = 3;

            DetectedWeightType detectedWeightType{ DetectedWeightType::Unknown };

            // Perform detection based on each filter, since for xnor, each filter is mean or -mean.
            // Each filter must be classified the same way as the ones before it.
            for (int filter = 0; filter < numFilters; ++filter)
            {
                // Filters are stacked in the row dimension
                const size_t firstRow = static_cast<size_t>(filter) * filterSize;

                // Keep track of unique values for this filter
                std::set<ValueType> uniqueValues;

                // Stop scanning the whole filter (all remaining rows AND channels) as soon as it
                // is known to hold too many distinct values.
                bool tooManyValues = false;
                for (size_t channel = 0; channel < numChannels && !tooManyValues; ++channel)
                {
                    for (size_t row = 0; row < filterSize && !tooManyValues; ++row)
                    {
                        const size_t weightRow = firstRow + row;
                        for (size_t column = 0; column < filterSize; ++column)
                        {
                            uniqueValues.insert(weights(weightRow, column, channel));
                        }
                        tooManyValues = uniqueValues.size() > maxUniqueValues;
                    }
                }

                // Binary or ternary weights should only have 2 or 3 unique values respectively
                if (tooManyValues || uniqueValues.size() < 2)
                {
                    return DetectedWeightType::FullPrecision;
                }

                DetectedWeightType proposedWeightType{ DetectedWeightType::FullPrecision };

                if (uniqueValues.size() == 2)
                {
                    // Two unique values: check for binary or xnor.
                    // Binary is -1 or 1, xnor is -mean, mean.
                    // std::set iterates in ascending order, so `low` < `high`.
                    const ValueType low = *uniqueValues.begin();
                    const ValueType high = *uniqueValues.rbegin();

                    if (low == -high)
                    {
                        proposedWeightType = (std::abs(low) == 1) ? DetectedWeightType::Binary : DetectedWeightType::SignedMean;
                    }
                }
                else
                {
                    // Three unique values: check for ternary. Ternary values are -1, 0, 1.
                    bool ternary = true;
                    for (const auto value : uniqueValues)
                    {
                        if (value != 1 && value != -1 && value != 0)
                        {
                            ternary = false;
                            break;
                        }
                    }
                    if (ternary)
                    {
                        proposedWeightType = DetectedWeightType::Ternary;
                    }
                }

                // A filter that matches no low-precision pattern forces a FullPrecision result
                // no matter what the remaining filters contain, so there is no need to scan them.
                if (proposedWeightType == DetectedWeightType::FullPrecision)
                {
                    return DetectedWeightType::FullPrecision;
                }

                // Verify that the proposed weight type is the same as the previously detected weight type.
                if (detectedWeightType == DetectedWeightType::Unknown)
                {
                    detectedWeightType = proposedWeightType;
                }
                else if (proposedWeightType != detectedWeightType)
                {
                    return DetectedWeightType::FullPrecision;
                }
            }

            return detectedWeightType;
        }