source/backend/cpu/CPUSoftmax.cpp

//
//  CPUSoftmax.cpp
//  MNN
//
//  Created by MNN on 2018/07/16.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <math.h>
#include "backend/cpu/CPUSoftmax.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "core/Concurrency.h"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#include "CPUTensorConvert.hpp"
#include "CPUCast.hpp"

namespace MNN {

static void ___MNNSoftmax(float* dest, const float* source, size_t size, MNNBinaryExecute mulfunction) {
    // exprOffset[2] passes -max into MNNExp (max subtraction for numerical stability);
    // exprOffset[3] receives the sum of the exponentials, used below for normalization.
    float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f};
    // Compute Max
    {
        int32_t inputCountUnit = size / (4 * 2);
        int32_t remain = size - (inputCountUnit * 4 * 2);
        float Max = source[0];
        if (inputCountUnit > 0) {
            float maxArray[4] = {Max, Max, Max, Max};
            MNNMaxFloat((float*)source, maxArray, inputCountUnit);
            for (int i = 0; i < 4; i++) {
                Max = ALIMAX(Max, maxArray[i]);
            }
        }
        if (remain > 0) {
            int currentIndex = inputCountUnit * 4 * 2;
            for (int i = 0; i < remain; i++) {
                float currentInputData = source[currentIndex + i];
                Max = ALIMAX(Max, currentInputData);
            }
        }
        exprOffset[2] = -Max;
    }
    MNNExp(dest, source, exprOffset, size);
    float sumDiv = 1.0f / exprOffset[3];
    mulfunction(dest, dest, &sumDiv, size, 1);
}

int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) {
    auto cpuBn = static_cast<CPUBackend*>(backend());
    auto core = cpuBn->functions();
    auto fp32Core = core;
    if (core->bytes != 4) {
        fp32Core = MNNGetCoreFunctions();
    }
    MNNBinaryExecute addFunction;
    MNNUnaryExecute recFunction;
    MNNBinaryExecute mulFunction;
    mulFunction = fp32Core->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL);
    auto bytes = core->bytes;
    int threadNumber = ALIMIN(cpuBn->threadNumber(), mOutside);
    int outsideStride = mChannel * mInside;
    // When inside is wide and channel is narrow relative to the pack size, reduce across the
    // channel axis with vectorized binary ops over contiguous inside-runs instead of transposing.
    if (mInside > core->pack && mChannel < core->pack) {
        auto maxFunction = core->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MAXIMUM);
        auto subFunction = core->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB);
        addFunction = fp32Core->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD);
        recFunction = fp32Core->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_RECIPROCAL, 1); // Use high precision
        MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
            float* tempOutput = nullptr;
            float* tempInput = nullptr;
            if (mTmpInput.ptr()) {
                tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float));
            }
            if (mTmpOutput.ptr()) {
                tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float));
            }
            for (int o = tId; o < mOutside; o += threadNumber) {
                auto srcO = srcData + o * outsideStride * mLowOrInt8;
                auto dstO = dstData + o * outsideStride * mLowOrInt8;
                // Max
                if (mLowOrInt8 == 1) {
                    CPUCastCreator::cast(srcO, tempInput, CPUCastCreator::INT8_TO_FlOAT, outsideStride, mInQuantAttr->scale, mInQuantAttr->zero, mInQuantAttr->min, mInQuantAttr->max, cpuBn);
                    ::memcpy(tempOutput, tempInput, mInside * 4);
                    for (int z = 1; z < mChannel; ++z) {
                        maxFunction(tempOutput, tempOutput, tempInput + z * mInside, mInside, -1);
                    }
                } else {
                    ::memcpy(tempInput, srcO, mInside * mLowOrInt8);
                    for (int z = 1; z < mChannel; ++z) {
                        maxFunction(tempInput, tempInput, srcO + z * mInside * mLowOrInt8, mInside, -1);
                    }
                }
                // Sub Max
                for (int z = 0; z < mChannel; ++z) {
                    if (mLowOrInt8 == 1) {
                        subFunction(tempInput + z * mInside, tempInput + z * mInside, tempOutput, mInside, -1);
                    } else {
                        subFunction(dstO + z * mInside * mLowOrInt8, srcO + z * mInside * mLowOrInt8, tempInput, mInside, -1);
                    }
                }
                // Exp
                float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f};
                auto workSrc = (float*)srcO;
                auto workDst = (float*)dstO;
                if (mLowOrInt8 != 4) {
                    workSrc = tempInput;
                    workDst = tempOutput;
                    if (mLowOrInt8 == 2) {
                        core->MNNLowpToFp32((int16_t*)(dstO), workSrc, outsideStride);
                    }
                }
                // Use Fp32 to compute Begin
                MNNExp(workDst, workSrc, exprOffset, outsideStride);
                // Sum to tempInput
                ::memcpy(tempInput, workDst, mInside * sizeof(float));
                for (int z = 1; z < mChannel; ++z) {
                    addFunction(tempInput, tempInput, workDst + z * mInside, mInside, -1);
                }
                recFunction(tempInput, tempInput, mInside);
                for (int z = 0; z < mChannel; ++z) {
                    mulFunction(workDst + z * mInside, workDst + z * mInside, tempInput, mInside, -1);
                }
                // Use Fp32 Compute end
                if (mLowOrInt8 == 2) {
                    core->MNNFp32ToLowp(workDst, (int16_t*)(dstO), outsideStride);
                } else if (mLowOrInt8 == 1) {
                    CPUCastCreator::cast(workDst, dstO, CPUCastCreator::FlOAT_TO_INT8, outsideStride, mOutQuantAttr->scale, mOutQuantAttr->zero, mOutQuantAttr->min, mOutQuantAttr->max, cpuBn);
                } else {
                    // do nothing.
                }
            }
        };
        MNN_CONCURRENCY_END();
        return 0;
    }
    // General path: when inside > 1, transpose so the channel (softmax) axis becomes
    // contiguous, then run softmax row by row along the channel axis.
    MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
        float* tempInput;
        float* tempOutput;
        if (mTmpInput.ptr()) {
            tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float));
        }
        if (mTmpOutput.ptr()) {
            tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float));
        }
        for (int o = tId; o < mOutside; o += threadNumber) {
            auto srcO = srcData + o * outsideStride * mLowOrInt8;
            auto dstO = dstData + o * outsideStride * mLowOrInt8;
            auto workSrc = (float*)srcO;
            auto workDst = (float*)dstO;
            // Pretreat
            if (1 == mInside) {
                if (mLowOrInt8 == 2) {
                    core->MNNLowpToFp32((int16_t*)(srcO), tempInput, outsideStride);
                    workDst = tempOutput;
                    workSrc = tempInput;
                } else if (mLowOrInt8 == 1) {
                    CPUCastCreator::cast(srcO, tempInput, CPUCastCreator::INT8_TO_FlOAT, outsideStride, mInQuantAttr->scale, mInQuantAttr->zero, mInQuantAttr->min, mInQuantAttr->max, cpuBn);
                    workDst = tempOutput;
                    workSrc = tempInput;
                }
            } else {
                int dims[] = {mChannel, mInside, mInside, mChannel};
                if (mLowOrInt8 == 2) {
                    MNN_ASSERT(bytes == 2);
                    MNNTranspose16Bit((int16_t*)tempOutput, (int16_t*)(srcO), dims);
                    core->MNNLowpToFp32((int16_t*)tempOutput, tempInput, outsideStride);
                    workDst = tempOutput;
                    workSrc = tempInput;
                } else if (mLowOrInt8 == 1) {
                    CPUCastCreator::cast(srcO, tempOutput, CPUCastCreator::INT8_TO_FlOAT, outsideStride, mInQuantAttr->scale, mInQuantAttr->zero, mInQuantAttr->min, mInQuantAttr->max, cpuBn);
                    MNNTranspose32Bit((int32_t*)tempInput, (int32_t*)tempOutput, dims);
                    workDst = tempOutput;
                    workSrc = tempInput;
                } else {
                    // Use output to cache transpose result
                    MNNTranspose32Bit((int32_t*)dstO, (int32_t*)(srcO), dims);
                    workDst = tempInput;
                    workSrc = (float*)dstO;
                }
            }
            for (int v = 0; v < mInside; ++v) {
                // TODO: Fix x86 compute error and use the same function
#ifdef MNN_USE_SSE
                MNNSoftmax(workDst + v * mChannel, workSrc + v * mChannel, mChannel);
#else
                ___MNNSoftmax(workDst + v * mChannel, workSrc + v * mChannel, mChannel, mulFunction);
#endif
            }
            // PostTreat
            if (1 == mInside) {
                if (mLowOrInt8 == 2) {
                    core->MNNFp32ToLowp(tempOutput, (int16_t*)(dstO), outsideStride);
                } else if (mLowOrInt8 == 1) {
                    CPUCastCreator::cast(tempOutput, dstO, CPUCastCreator::FlOAT_TO_INT8, outsideStride, mOutQuantAttr->scale, mOutQuantAttr->zero, mOutQuantAttr->min, mOutQuantAttr->max, cpuBn);
                }
            } else {
                int dims[] = {mInside, mChannel, mChannel, mInside};
                if (mLowOrInt8 == 2) {
                    MNN_ASSERT(bytes == 2);
                    core->MNNFp32ToLowp((float*)tempOutput, (int16_t*)tempInput, outsideStride);
                    MNNTranspose16Bit((int16_t*)dstO, (int16_t*)(tempInput), dims);
                } else if (mLowOrInt8 == 1) {
                    MNNTranspose32Bit((int32_t*)tempInput, (int32_t*)tempOutput, dims);
                    CPUCastCreator::cast(tempInput, dstO, CPUCastCreator::FlOAT_TO_INT8, outsideStride, mOutQuantAttr->scale, mOutQuantAttr->zero, mOutQuantAttr->min, mOutQuantAttr->max, cpuBn);
                } else {
                    MNNTranspose32Bit((int32_t*)dstO, (int32_t*)(tempInput), dims);
                }
            }
        }
    }
    MNN_CONCURRENCY_END();
    return 0;
}

ErrorCode CPUSoftmax::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto input = inputs[0];
    const int dimensions = input->buffer().dimensions;
    int axis = mAxis;
    if (axis < 0) {
        axis += dimensions;
    }
    const auto layout = TensorUtils::getDescribe(input)->dimensionFormat;
    mNeedUnpackC4 = layout == MNN_DATA_FORMAT_NC4HW4;
    if (mNeedUnpackC4) {
        int totalSize = 1;
        for (int i = 1; i < dimensions; ++i) {
            totalSize *= input->length(i);
        }
        mStorage.buffer().dim[0].extent = input->length(0);
        mStorage.buffer().dim[1].extent = totalSize;
        TensorUtils::getDescribe(&mStorage)->dimensionFormat = MNN_DATA_FORMAT_NHWC;
        mStorage.buffer().dimensions = 2;
        mStorage.buffer().type = input->getType();
        backend()->onAcquireBuffer(&mStorage, Backend::DYNAMIC);
    }
    int inside = 1;
    int outside = 1;
    int channel = 1;
    for (int i = 0; i < axis; ++i) {
        outside *= input->length(i);
    }
    channel = input->length(axis);
    for (int i = axis + 1; i < dimensions; ++i) {
        inside *= input->length(i);
    }
    mInside = inside;
    mOutside = outside;
    mChannel = channel;
    // mLowOrInt8 records the element size of the compute path: 4 = fp32, 2 = low-precision float (core->bytes != 4), 1 = int8.
    mLowOrInt8 = 4;
    if (static_cast<CPUBackend*>(backend())->functions()->bytes != 4) {
        mLowOrInt8 = 2;
    }
    if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) {
        mLowOrInt8 = 1;
    }
    mInQuantAttr = TensorUtils::getDescribe(inputs[0])->quantAttr;
    mOutQuantAttr = TensorUtils::getDescribe(outputs[0])->quantAttr;
    auto cpuBn = static_cast<CPUBackend*>(backend());
    if (inside != 1 || mLowOrInt8 != 4) { // not running _softmax1: we need the maxValue and sumValue tensors.
        int threadNum = cpuBn->threadNumber();
        auto buf = cpuBn->getBufferAllocator();
        threadNum = ALIMIN(threadNum, outside);
        mTmpInput = buf->alloc(threadNum * inside * channel * sizeof(float));
        if (mLowOrInt8 != 4) {
            mTmpOutput = buf->alloc(threadNum * inside * channel * sizeof(float));
            buf->free(mTmpOutput);
        }
        buf->free(mTmpInput);
    }
    if (mNeedUnpackC4) {
        backend()->onReleaseBuffer(&mStorage, Backend::DYNAMIC);
    }
    return NO_ERROR;
}

ErrorCode CPUSoftmax::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    MNN_ASSERT(1 == inputs.size());
    MNN_ASSERT(1 == outputs.size());
    auto inputTensor = inputs[0];
    auto outputTensor = outputs[0];
    const auto inputDataPtr = inputTensor->host<float>();
    auto outputDataPtr = outputTensor->host<float>();
    const int batch = inputTensor->batch();
    const auto dims = inputTensor->buffer().dimensions;

    float *tempData = nullptr;
    if (mNeedUnpackC4) {
        tempData = mStorage.host<float>();
    }
    int areaInput = 1;
    for (int i = 2; i < dims; ++i) {
        areaInput *= inputTensor->length(i);
    }
    int threadNum = ((CPUBackend *)backend())->threadNumber();
    if (!mNeedUnpackC4) {
        _softmaxCommon((uint8_t*)inputDataPtr, (uint8_t*)outputDataPtr);
        return NO_ERROR;
    }
    auto functions = static_cast<CPUBackend*>(backend())->functions();
    CPUTensorConverter::convert(inputDataPtr, outputDataPtr, MNN_DATA_FORMAT_NC4HW4, MNN_DATA_FORMAT_NCHW, batch, areaInput, inputTensor->channel(), mLowOrInt8, functions);
    _softmaxCommon((uint8_t*)outputDataPtr, (uint8_t*)tempData);
    CPUTensorConverter::convert(tempData, outputDataPtr, MNN_DATA_FORMAT_NCHW, MNN_DATA_FORMAT_NC4HW4, batch, areaInput, inputTensor->channel(), mLowOrInt8, functions);
    return NO_ERROR;
}

CPUSoftmax::CPUSoftmax(Backend *b, int axis) : MNN::Execution(b), mAxis(axis), mStorage(2), mNeedUnpackC4(false) {
    // nothing to do
}

Execution* CPUSoftmax::create(const MNN::Op *op, Backend *backend) {
    auto axis = op->main_as_Axis()->axis();
    return new CPUSoftmax(backend, axis);
}

class CPUSoftmaxCreator : public CPUBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        return CPUSoftmax::create(op, backend);
    }
};

REGISTER_CPU_OP_CREATOR(CPUSoftmaxCreator, OpType_Softmax);
} // namespace MNN