bool CoreML::hasWeightOfType()

in mlmodel/src/Utils.cpp [191:262]
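
Given a NeuralNetworkLayer from a model specification, this helper reports whether any of the layer's weight parameters (weights, biases, batch-norm statistics, recurrent gate matrices, and so on) is stored as the given WeightParamType. It dispatches on the layer case, delegating to isWeightParamOfType for individual WeightParams fields and to hasLSTMWeightParamOfType for LSTM weight bundles; layers that carry no weight parameters fall through to the default case and return false.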


bool CoreML::hasWeightOfType(const Specification::NeuralNetworkLayer& layer,
                             const WeightParamType& type) {

    switch (layer.layer_case()) {
        case Specification::NeuralNetworkLayer::LayerCase::kConvolution:
            return (isWeightParamOfType(layer.convolution().weights(), type) ||
                    isWeightParamOfType(layer.convolution().bias(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kInnerProduct:
            return (isWeightParamOfType(layer.innerproduct().weights(), type) ||
                    isWeightParamOfType(layer.innerproduct().bias(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kBatchedMatmul:
            return (isWeightParamOfType(layer.batchedmatmul().weights(), type) ||
                    isWeightParamOfType(layer.batchedmatmul().bias(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kBatchnorm:
            return (isWeightParamOfType(layer.batchnorm().gamma(), type) ||
                    isWeightParamOfType(layer.batchnorm().beta(), type) ||
                    isWeightParamOfType(layer.batchnorm().mean(), type) ||
                    isWeightParamOfType(layer.batchnorm().variance(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kLoadConstant:
            return isWeightParamOfType(layer.loadconstant().data(), type);

        case Specification::NeuralNetworkLayer::LayerCase::kScale:
            return (isWeightParamOfType(layer.scale().scale(), type) ||
                    isWeightParamOfType(layer.scale().bias(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kSimpleRecurrent:
            return (isWeightParamOfType(layer.simplerecurrent().weightmatrix(), type) ||
                    isWeightParamOfType(layer.simplerecurrent().recursionmatrix(), type) ||
                    isWeightParamOfType(layer.simplerecurrent().biasvector(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kGru:
            return (isWeightParamOfType(layer.gru().updategateweightmatrix(), type) ||
                    isWeightParamOfType(layer.gru().resetgateweightmatrix(), type) ||
                    isWeightParamOfType(layer.gru().outputgateweightmatrix(), type) ||
                    isWeightParamOfType(layer.gru().updategaterecursionmatrix(), type) ||
                    isWeightParamOfType(layer.gru().resetgaterecursionmatrix(), type) ||
                    isWeightParamOfType(layer.gru().outputgaterecursionmatrix(), type) ||
                    isWeightParamOfType(layer.gru().updategatebiasvector(), type) ||
                    isWeightParamOfType(layer.gru().resetgatebiasvector(), type) ||
                    isWeightParamOfType(layer.gru().outputgatebiasvector(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kEmbedding:
            return (isWeightParamOfType(layer.embedding().weights(), type) ||
                    isWeightParamOfType(layer.embedding().bias(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kEmbeddingND:
            return (isWeightParamOfType(layer.embeddingnd().weights(), type) ||
                    isWeightParamOfType(layer.embeddingnd().bias(), type));

        case Specification::NeuralNetworkLayer::LayerCase::kUniDirectionalLSTM:
            return hasLSTMWeightParamOfType(layer.unidirectionallstm().weightparams(), type);

        case Specification::NeuralNetworkLayer::LayerCase::kBiDirectionalLSTM:
            // One LSTMWeightParams entry per direction: index 0 is forward, index 1 is backward.
            return (hasLSTMWeightParamOfType(layer.bidirectionallstm().weightparams(0), type) ||
                    hasLSTMWeightParamOfType(layer.bidirectionallstm().weightparams(1), type));

        case Specification::NeuralNetworkLayer::LayerCase::kActivation:
            // Only PReLU and parametric softplus activations carry weight parameters.
            if (layer.activation().NonlinearityType_case() == Specification::ActivationParams::NonlinearityTypeCase::kPReLU) {
                return isWeightParamOfType(layer.activation().prelu().alpha(), type);
            } else if (layer.activation().NonlinearityType_case() == Specification::ActivationParams::NonlinearityTypeCase::kParametricSoftplus) {
                return (isWeightParamOfType(layer.activation().parametricsoftplus().alpha(), type) ||
                        isWeightParamOfType(layer.activation().parametricsoftplus().beta(), type));
            }
            break;

        default:
            break;
    }
    return false;
}
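
A caller would typically loop this check over every layer of a network, for example to decide whether a model stores any half-precision weights. The sketch below is illustrative, not confirmed against the headers: it assumes WeightParamType exposes a FLOAT16 enumerator in the CoreML namespace and that Specification::NeuralNetwork exposes its repeated layer field through the generated layers() accessor.

// Minimal usage sketch (requires the mlmodel Utils and generated protobuf
// headers). CoreML::FLOAT16 and nn.layers() are taken on assumption here.
static bool networkHasFloat16Weights(const CoreML::Specification::NeuralNetwork& nn) {
    for (const auto& layer : nn.layers()) {
        if (CoreML::hasWeightOfType(layer, CoreML::FLOAT16)) {
            return true;   // at least one parameter is stored as float16
        }
    }
    return false;          // no layer stores any parameter as float16
}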