tools/converter/source/caffe/LSTM.cpp (56 lines of code) (raw):
//
// LSTM.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "OpConverter.hpp"
/// Converter that turns a caffe LSTM-style layer into an MNN LSTM op.
class LSTM : public OpConverter {
public:
    /// The MNN op type this converter emits.
    virtual MNN::OpType opType() {
        return MNN::OpType_LSTM;
    }

    /// The parameter kind stored in the emitted op's `main` union.
    virtual MNN::OpParameter type() {
        return MNN::OpParameter_LSTM;
    }

    /// Populates `dstOp` from the layer definition (`parameters`) and from the
    /// layer that carries the trained blobs (`weight`).
    void run(MNN::OpT* dstOp, const caffe::LayerParameter& parameters, const caffe::LayerParameter& weight);
};
void LSTM::run(MNN::OpT* dstOp, const caffe::LayerParameter& parameters, const caffe::LayerParameter& weight) {
MNN::LSTMT* lstmt = new MNN::LSTMT;
dstOp->main.value = lstmt;
auto lstmcaffe = parameters.lstm_param();
lstmt->outputCount = lstmcaffe.num_output();
lstmt->clippingThreshold = lstmcaffe.clipping_threshold();
int SizeI = 0, SizeH = 0;
// blob[0] weight_i blob[1] weight_h blob[2] bias
auto w = &weight;
int blobCnt = ((caffe::LayerParameter*)w)->blobs().size();
if (blobCnt >= 1) {
const caffe::BlobProto& wi = ((caffe::LayerParameter*)w)->blobs(0);
SizeI = wi.data_size();
if (SizeI > 0) {
lstmt->weightI = std::unique_ptr<MNN::BlobT>(new MNN::BlobT);
lstmt->weightI->dims.push_back(SizeI);
lstmt->weightI->float32s.resize(SizeI);
memcpy(lstmt->weightI->float32s.data(), wi.data().data(), sizeof(float) * SizeI);
}
}
if (blobCnt >= 2) {
const caffe::BlobProto& wh = ((caffe::LayerParameter*)w)->blobs(1);
SizeH = wh.data_size();
if (SizeH > 0) {
lstmt->weightH = std::unique_ptr<MNN::BlobT>(new MNN::BlobT);
lstmt->weightH->dims.push_back(SizeH);
lstmt->weightH->float32s.resize(SizeH);
memcpy(lstmt->weightH->float32s.data(), wh.data().data(), sizeof(float) * SizeH);
}
}
if (blobCnt >= 3) {
const caffe::BlobProto& b = ((caffe::LayerParameter*)w)->blobs(2);
int biasCnt = b.data_size();
if (biasCnt > 0) {
lstmt->bias = std::unique_ptr<MNN::BlobT>(new MNN::BlobT);
lstmt->bias->dims.push_back(biasCnt);
lstmt->bias->float32s.resize(biasCnt);
memcpy(lstmt->bias->float32s.data(), b.data().data(), sizeof(float) * biasCnt);
}
}
lstmt->weightSize = SizeI > SizeH ? SizeH : SizeI;
}
// Each object below associates one caffe layer type string with the LSTM
// converter; OpConverterRegister's constructor runs during static
// initialization of this translation unit, in definition order.
// The previous names `_a`, `_sa` and `__b` are reserved identifiers (leading
// underscore at global scope / double underscore), so ordinary names are used.
static OpConverterRegister<LSTM> gLstmRegister("Lstm");
static OpConverterRegister<LSTM> gOcrLstmRegister("OCRLSTM");
static OpConverterRegister<LSTM> gOcrLstmQRegister("OCRLSTMQ");
static OpConverterRegister<LSTM> gCuDNNLstmForwardRegister("CuDNNLstmForward");