std::shared_ptr<ConvolutionCommon::Int8Common> ConvolutionCommon::load()

in source/core/ConvolutionCommon.cpp [562:806]
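Parses the quantized weight payload of a Convolution2D op into an Int8Common. It covers external weight files (including the cached-mmap path, where buffers are only sized, not read), sparse index/alpha weights, IDST-compressed buffers (types 1 and 2), fp16 (type 3), and raw int8 (type 4), then either keeps int8 weights (forceInt8) or dequantizes back to float (forceFloat, or models without scaleInt). A non-null weightPtr supplies a caller-owned destination buffer.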


std::shared_ptr<ConvolutionCommon::Int8Common> ConvolutionCommon::load(const Op* op, Backend* backend, bool forceFloat, bool forceInt8, void* weightPtr) {
    auto conv = op->main_as_Convolution2D();
    auto quan = conv->quanParameter();
    std::shared_ptr<ConvolutionCommon::Int8Common> result(new Int8Common);
    result->quan = quan;
    size_t buffer_size = 0, alpha_size = 0;
    const int8_t* buffer_ptr = nullptr;
    const float* alpha_ptr = nullptr;
    std::unique_ptr<int8_t[]> external_buffer;
    size_t weightLength = 0;
    int8_t *buffer        = nullptr;
    bool useCachedMmap = false;
    if (backend && backend->getRuntime()) {
        useCachedMmap = backend->getRuntime()->hint().useCachedMmap > 1;
    }
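    // quan->type() == 8: raw float weights stored in an external file; read them directly into weightFloat and return.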
    if (USE_EXTERNAL_DATA(conv) && op->externalPath() && quan->type() == 8) {
        std::unique_ptr<FileLoader> external(new FileLoader(op->externalPath()->c_str()));
        external->offset(conv->external()->data()[0]);
        if (weightPtr != nullptr) {
            result->weightFloat.set((float *)weightPtr, false);
        } else {
            result->weightFloat.reset((int)(conv->external()->data()[1] / sizeof(float)));
        }
        external->read((char*)(result->weightFloat.get()), conv->external()->data()[1]);
        return result;
    }
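    // Quantized weights stored externally (or sized via cached mmap): external()->data() = { file offset, weight-buffer bytes, alpha bytes }.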
    if (USE_EXTERNAL_DATA(conv) && (op->externalPath() || useCachedMmap) && quan->buffer() == nullptr) {
        auto external_info = conv->external()->data();
        buffer_size = external_info[1];
        alpha_size = external_info[2] / sizeof(float);
        result->alphaSize = alpha_size;
        
        if (useCachedMmap) {
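            // Cached mmap: the weight bytes are not read here. Instead, estimate an upper
            // bound on the int4-packed size and compare it against the stored buffer size.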
            if (alpha_size) {
                weightLength = conv->common()->inputCount() * conv->common()->outputCount() * conv->common()->kernelX() * conv->common()->kernelY();
                int upperBound = 1;
                if (conv->common()->inputCount() > 65535 || conv->common()->outputCount() > 65535) { // 65535: max(uint16_t)
                    upperBound += 8; // shape dimension saved as type:int32_t
                } else {
                    upperBound += 4; // shape dimension saved as type:int16_t
                }
                upperBound += (UP_DIV(weightLength, 2) + 17); // two 4-bit weights per byte; 17 = 16 (int4 value range -8~7) + 1
                result->canUseInt4 = false;
                if (upperBound >= buffer_size) {
                    result->canUseInt4 = true;
                }
            }
        } else {
            // external data
            std::unique_ptr<FileLoader> external_file(new FileLoader(op->externalPath()->c_str()));
            external_file->offset(external_info[0]);
            if (0 != buffer_size) {
                if (1 == quan->type() && !forceFloat) {
                    buffer = IDSTDecoder::ReadQuanData_c(external_file.get(), &weightLength, result.get(), quan, forceInt8, forceFloat, weightPtr);
                } else {
                    external_buffer.reset(new int8_t[buffer_size]);
                    buffer_ptr = external_buffer.get();
                    external_file->read((char*)buffer_ptr, buffer_size);
                }
            }
            if (0 != alpha_size) {
                result->alpha.reset((int)alpha_size);
                if (nullptr == result->alpha.get()) {
                    MNN_PRINT("Alloc memory error for extract idst int8\n");
                    return nullptr;
                }
                alpha_ptr = result->alpha.get();
                external_file->read((char*)alpha_ptr, alpha_size * sizeof(float));
            }
        }
    } else {
        if (quan->buffer()) {
            buffer_size = quan->buffer()->size();
            buffer_ptr = quan->buffer()->data();
        }
        if (quan->alpha()) {
            alpha_size = quan->alpha()->size();
            alpha_ptr = quan->alpha()->data();
            result->alphaSize = alpha_size;
            result->alpha.reset((int)alpha_size);
            if (nullptr == result->alpha.get()) {
                MNN_PRINT("Alloc memory error for extract idst int8\n");
                return nullptr;
            }
            ::memcpy(result->alpha.get(), alpha_ptr, alpha_size * sizeof(float));
        }
    }
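    // Sparse float weights: quan->index() lists the positions of the non-zero weights, whose values are stored in alpha.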
    if (quan->index() != nullptr) {
        if (forceFloat) {
            // Expand sparse to dense
            if(weightPtr != nullptr) {
                result->weightFloat.set((float *)weightPtr, false);
            } else {
                result->weightFloat.reset(quan->weightSize());
            }
            if (nullptr == result->weightFloat.get()) {
                return nullptr;
            }
            ::memset(result->weightFloat.get(), 0, quan->weightSize() * sizeof(float));
            auto index = quan->index()->data();
            auto indexSize = quan->index()->size();
            if (nullptr == alpha_ptr || alpha_size != indexSize) {
                MNN_ERROR("The model is error, don't has alpha but has index\n");
                return nullptr;
            }
            for (uint32_t i=0; i<indexSize; ++i) {
                result->weightFloat.get()[index[i]] = alpha_ptr[i];
            }
        } // Otherwise no conversion is needed; just return the result with its quant info
        return result;
    }

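    // Dense paths: decode the in-memory buffer according to quan->type(): 1 = IDST-quantized, 2 = sparse IDST-quantized, 3 = fp16, 4 = raw int8.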
    std::unique_ptr<MemoryLoader> originBuffer(new MemoryLoader((unsigned char*)buffer_ptr));
    if (1 == quan->type() && weightLength == 0) {
        buffer = IDSTDecoder::ReadQuanData_c(originBuffer.get(), &weightLength, result.get(), quan, forceInt8, forceFloat, weightPtr);
    }
    if (2 == quan->type()) {
        buffer = IDSTDecoder::ReadSparseQuanData_c(originBuffer.get(), &weightLength, alpha_ptr, alpha_size, result.get(), quan, forceInt8, forceFloat, weightPtr);
    }
    // read fp16 data
    if (3 == quan->type()) {
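        // Under cached mmap the fp16 payload is not converted; only the float buffer is sized.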
        weightLength = buffer_size / sizeof(half_float::half);
        if (useCachedMmap) {
            if (weightPtr != nullptr) {
                result->weightFloat.set((float *)weightPtr, false);
            } else {
                result->weightFloat.reset((int)weightLength);
            }
            return result;
        }
        std::vector<int8_t> tempHalfWeight(buffer_size);
        ::memcpy(tempHalfWeight.data(), buffer_ptr, buffer_size);
        auto halfWeight = reinterpret_cast<half_float::half *>(tempHalfWeight.data());
        if(weightPtr != nullptr) {
            result->weightFloat.set((float *)weightPtr, false);
        } else {
            result->weightFloat.reset((int)weightLength);
        }
        if (nullptr == result->weightFloat.get()) {
            MNN_PRINT("Alloc memory error for extract fp16 back to float\n");
            return nullptr;
        }
        std::transform(halfWeight, halfWeight + weightLength, result->weightFloat.get(),
                       [](half_float::half h) { return float(h); });
        return result;
    }

    // weight int8 only
    if (4 == quan->type()) {
        weightLength = buffer_size;
        if(weightPtr != nullptr) {
            result->weight.set((int8_t *)weightPtr, false);
        } else {
            result->weight.reset((int)weightLength);
        }
        ::memcpy(result->weight.get(), buffer_ptr, weightLength);
    }

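    // Legacy type-4 models stored aMin == 0 and quantScale == 0; these, and any model with readType() != 0, use asymmetric quantization.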
    bool oldType4 = (quan->type() == 4 && quan->aMin() == 0 && std::abs(quan->quantScale()) < 1e-6);
    if (quan->readType() != 0 || oldType4) {
        result->asymmetric = true;
    } else {
        result->asymmetric = false;
    }
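    // Adopt the decoded weights and fold quantization constants into alpha (skipped under cached mmap).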
    if (!useCachedMmap) {
        if (result->weight.get() == nullptr) {
            if (nullptr == buffer) {
                MNN_PRINT("Alloc memory error for extract idst int8\n");
                return nullptr;
            }
            if(weightPtr != nullptr) {
                result->weight.set(buffer, false);
            } else {
                result->weight.set(buffer, (int)weightLength);
            }
        }
        int outputCount = 0;
        if (result->asymmetric) {
            outputCount   = result->alpha.size() / 2;
            // clampMin is minVal in asymmetric quant: clampMin = -(2^bit);
            // old versions stored 0 and used clampMin = -128.
            float clampMin = quan->aMin() == 0 ? -128 : quan->aMin();
            if (clampMin < 0) {
                for (int o = 0; o < outputCount; ++o) {
                    result->alpha.get()[2 * o] = result->alpha.get()[2 * o] - clampMin * result->alpha.get()[2 * o + 1];
                }
            }
        } else {
            outputCount   = result->alpha.size(); // backward compatibility with earlier symmetric quantization
        }
        if (!quan->has_scaleInt()) {
            float extraFactor = quan->quantScale();
            // For old type-4 models, quan->quantScale is 0, which would otherwise introduce a bug here.
            if (oldType4) {
                extraFactor = 1.0f;
            } else if (extraFactor != 1.0f) {
                for (int o=0; o<result->alpha.size(); ++o) {
                    result->alpha.get()[o] *= extraFactor;
                }
            }
        }
    }
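    // forceInt8: the caller wants raw int8 weights plus quant params, so skip dequantization.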
    if (forceInt8) {
        return result;
    }
    if (!quan->has_scaleInt() || forceFloat) {
        // Back to float
        if(weightPtr != nullptr) {
            result->weightFloat.set((float *)weightPtr, false);
        } else {
            result->weightFloat.reset((int)weightLength);
        }
        if (nullptr == result->weightFloat.get()) {
            MNN_PRINT("Alloc memory error for extract idst int8/ Back to float\n");
            return nullptr;
        }
        int outputCount = 0;
        if (result->asymmetric) {
            outputCount = result->alpha.size() / 2;
        } else {
            outputCount = result->alpha.size();
        }
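        // Weights are grouped per output channel; each channel has its own scale (and min when asymmetric).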
        int partWeightSize = (int)weightLength / outputCount;
        for (int o = 0; o < outputCount; ++o) {
            float min = 0.0f;
            float alpha = 0.0f;
            if (result->asymmetric) {
                min = result->alpha.get()[2*o];
                alpha = result->alpha.get()[2*o+1];
            } else {
                alpha = result->alpha.get()[o];
            }
            auto dstW   = result->weightFloat.get() + o * partWeightSize;
            auto srcW   = result->weight.get() + o * partWeightSize;
            for (int v=0; v < partWeightSize; ++v) {
                dstW[v] = (float)srcW[v] * alpha + min;
            }
        }
        result->weight.release();
        result->alpha.release();
    }
    return result;
}
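
A minimal calling sketch for orientation (not part of the source): the MNN namespace, the header path, and the helper name are assumptions; only the load() signature is taken from the listing above.

// Hypothetical usage sketch; header path and namespace are assumptions.
#include "core/ConvolutionCommon.hpp"

// Request dequantized float weights. weightPtr = nullptr lets load()
// allocate the destination buffer itself.
std::shared_ptr<MNN::ConvolutionCommon::Int8Common> getFloatWeights(const MNN::Op* op,
                                                                    MNN::Backend* backend) {
    auto common = MNN::ConvolutionCommon::load(op, backend,
                                               /*forceFloat=*/true,
                                               /*forceInt8=*/false,
                                               /*weightPtr=*/nullptr);
    if (common == nullptr || common->weightFloat.get() == nullptr) {
        return nullptr; // decode failed or the weights could not be dequantized
    }
    // The returned Int8Common owns the float buffer; keep it alive while
    // common->weightFloat is in use.
    return common;
}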