//
// MetalDeconvolution.mm
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MetalDeconvolution.hpp"
#import "core/ConvolutionCommon.hpp"
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
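// Constant buffer layout shared with the deconv/deconv_depthwise Metal kernels.
// Kernel geometry and activation are filled once in the constructor; the
// shape-dependent fields (sizes, batch, padding) are refreshed in onResize.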
struct deconv_constants {
int input_width;
int input_height;
int input_size;
int input_slice;
int output_width;
int output_height;
int output_size;
int output_slice;
int kernel_x;
int kernel_y;
int kernel_size;
int stride_x;
int stride_y;
int pad_x;
int pad_y;
int dilation_x;
int dilation_y;
int delta_ky;
int delta_kx;
int delta_iy;
int delta_ix;
int batch;
int activation;
};
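// Least common multiple via the subtraction-based Euclidean GCD; assumes positive inputs.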
static int leastCommonMultiple(int m, int n) {
int a = m, b = n;
while(a != b){
if (a > b){
a = a - b;
} else {
b = b - a;
}
}
return m * n / a;
}
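// Repack deconvolution weights from the serialized [g][i][o][h][w] layout into the
// channel-4-interleaved [g][o/4][i/4][h][w][16] layout the Metal kernel reads,
// converting from FType (float) to TType (float or __fp16) along the way.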
template <typename FType, typename TType>
static void weightForDeconv(int group, int oc, int ic, int kh, int kw,
const FType *src, uint8_t* dstOrigin) {
auto goc = oc / group;
auto gic = ic / group;
auto goc_4 = UP_DIV(goc, 4);
auto gic_4 = UP_DIV(gic, 4);
auto dst = (TType *)dstOrigin;
for (int g = 0; g < group; g++) {
for (int i = 0; i < gic; i++) {
for (int o = 0; o < goc; o++) {
for (int h = 0; h < kh; h++) {
for (int w = 0; w < kw; w++) {
auto zo = o / 4, ro = o % 4;
auto zi = i / 4, ri = i % 4;
// to [g][o/4][i/4][h][w][16]
dst[(g * goc_4 * gic_4 * kh * kw + zo * gic_4 * kh * kw + zi * kh * kw + h * kw + w) * 16 +
ro * 4 + ri] =
// from [g][i][o][h][w]
// src[ g * goc * gic * kh * kw + i * goc * kh * kw + o *
// kh * kw + h * kw + w];
*src++;
}
}
}
}
}
}
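// Repack depthwise weights from [g][h][w] into the channel-4-interleaved [g/4][h][w][4] layout.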
template <typename FType, typename TType>
static void weightForDepthwise(int group, int kh, int kw, const FType *src, uint8_t* dstOrigin) {
auto dst = (TType *)dstOrigin;
for (int g = 0; g < group; g++) {
auto z = g / 4, r = g % 4;
auto z_dst = dst + z * kh * kw * 4 + r;
for (int h = 0; h < kh; h++) {
for (int w = 0; w < kw; w++) {
// to [g/4][h][w][4]
// from [g][h][w]
// dst[(z * kh * kw + h * kw + w) * 4 + r] =
// src[ g * kh * kw + h * kw + w];
z_dst[(h * kw + w) * 4] = *src++;
}
}
}
}
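// Choose the repacking routine and source weights: prefer the dequantized weightFloat
// when quantized weights were loaded. ic is recovered from the flattened weight size.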
template <typename TType>
void weightForDeconv(std::shared_ptr<MNN::Tensor> t, bool depthwise, const Convolution2D *deconv,
ConvolutionCommon::Int8Common *qnt) {
auto common = deconv->common();
auto kw = common->kernelX();
auto kh = common->kernelY();
auto group = common->group();
auto oc = common->outputCount();
auto size = qnt ? qnt->weightFloat.size() : deconv->weight()->size();
auto buffer = MetalBackend::getBuffer(t.get());
auto ic = size / kw / kh / (oc / group);
auto dst = (uint8_t*)[buffer.first contents] + buffer.second;
if (depthwise) {
weightForDepthwise<float, TType>(group, kh, kw,
qnt ? qnt->weightFloat.get() : deconv->weight()->data(), dst);
} else {
weightForDeconv<float, TType>(group, oc, ic, kh, kw,
qnt ? qnt->weightFloat.get() : deconv->weight()->data(), dst);
}
}
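// Copy the bias into a device tensor padded to a multiple of 4 channels, converting
// to half precision when the backend runs fp16.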
static std::shared_ptr<MNN::Tensor> biasForDeconv(Backend *backend, const Convolution2D *deconv, bool fp16) {
auto bias = deconv->bias();
auto oc = deconv->common()->outputCount();
int bytes = 4;
if (fp16) {
bytes = 2;
}
auto length = UP_DIV(oc, 4) * 4;
std::shared_ptr<MNN::Tensor> t(MNN::Tensor::createDevice<float>({length}));
auto res = backend->onAcquireBuffer(t.get(), Backend::STATIC);
if (!res) {
return nullptr;
}
auto buffer = MetalBackend::getBuffer(t.get());
auto dstO = (uint8_t*)[buffer.first contents] + buffer.second;
// Zero the buffer first so the lanes padding oc up to a multiple of 4 hold no garbage
// from a reused allocation (the weight buffers below are zeroed the same way).
::memset(dstO, 0, length * bytes);
auto src = bias->data();
if (fp16) {
auto dst = (__fp16 *)dstO;
for (int i = 0; i < oc; i++) {
dst[i] = src[i];
}
} else {
::memcpy(dstO, src, oc * sizeof(float));
}
return t;
}
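// Load the weights (dequantizing to float when the model stores quantized data),
// repack them for the Metal kernel, upload the bias, select the pipeline, and
// pre-fill the shape-independent constants; onResize fills in the rest.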
MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : MetalExecution(backend) {
auto mtbn = static_cast<MetalBackend *>(backend);
auto context = (__bridge MNNMetalContext *)mtbn->context();
auto deconv = op->main_as_Convolution2D();
auto common = deconv->common();
mOp = op;
mDepthwise = op->type() == MNN::OpType_DeconvolutionDepthwise;
mPadMode = common->padMode();
// Force a downgrade to float, matching the CPU backend's behavior.
std::shared_ptr<ConvolutionCommon::Int8Common> qnt = nullptr;
if (deconv->quanParameter()) {
qnt = ConvolutionCommon::load(op, backend, true);
}
auto kw = common->kernelX();
auto kh = common->kernelY();
auto group = common->group();
auto oc = common->outputCount();
auto size = qnt ? qnt->weightFloat.size() : deconv->weight()->size();
auto ic = size / kw / kh / (oc / group);
auto goc = oc / group;
auto gic = ic / group;
auto goc_4 = UP_DIV(goc, 4);
auto gic_4 = UP_DIV(gic, 4);
int weightSize = group * goc_4 * gic_4 * kw * kh * 16;
if (mDepthwise) {
weightSize = UP_DIV(group, 4) * 4 * kw * kh;
}
mWeight.reset(MNN::Tensor::createDevice<float>({weightSize}));
bool res = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC);
if (!res) {
mValid = false;
return;
}
auto weightBuffer = MetalBackend::getBuffer(mWeight.get());
auto ptr = (uint8_t*)weightBuffer.first.contents + weightBuffer.second;
if (mtbn->useFp16InsteadFp32()) {
::memset(ptr, 0, weightSize * sizeof(int16_t));
weightForDeconv<__fp16>(mWeight, mDepthwise, deconv, qnt.get());
} else {
::memset(ptr, 0, weightSize * sizeof(float));
weightForDeconv<float>(mWeight, mDepthwise, deconv, qnt.get());
}
mBias = biasForDeconv(backend, deconv, mtbn->useFp16InsteadFp32());
if (nullptr == mBias) {
mValid = false;
return;
}
if (mDepthwise) {
mPipeline = [context pipelineWithName:@"deconv_depthwise" fp16:mtbn->useFp16InsteadFp32()];
} else {
mPipeline = [context pipelineWithName:@"deconv" fp16:mtbn->useFp16InsteadFp32()];
}
mConstBuffer = [context newDeviceBuffer:sizeof(deconv_constants) access:CPUWriteOnly];
auto param = (deconv_constants*)mConstBuffer.contents;
mGroup = common->group();
param->kernel_x = common->kernelX();
param->kernel_y = common->kernelY();
param->kernel_size = common->kernelX() * common->kernelY();
param->stride_x = common->strideX();
param->stride_y = common->strideY();
param->dilation_x = common->dilateX();
param->dilation_y = common->dilateY();
param->activation = common->relu() ? 1 : (common->relu6() ? 2 : 0); // 0: none, 1: ReLU, 2: ReLU6
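// A deconv output pixel (ox, oy) only receives contributions from kernel taps where
// (oy + pad_y - ky * dilation_y) is divisible by stride_y. Stepping ky by
// delta_ky = lcm(dilation_y, stride_y) / dilation_y keeps that divisibility, and each
// such step moves the matching input row by delta_iy (likewise for the x axis).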
auto deltaKy = leastCommonMultiple(common->dilateY(), common->strideY()) / common->dilateY();
auto deltaKx = leastCommonMultiple(common->dilateX(), common->strideX()) / common->dilateX();
param->delta_kx = deltaKx;
param->delta_ky = deltaKy;
param->delta_iy = deltaKy * common->dilateY() / common->strideY();
param->delta_ix = deltaKx * common->dilateX() / common->strideX();
}
ErrorCode MetalDeconvolution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto backend = static_cast<MetalBackend *>(this->backend());
auto context = (__bridge MNNMetalContext *)backend->context();
auto input = inputs[0], output = outputs[0];
int iw = input->width(), ih = input->height(), iz = UP_DIV(input->channel(), 4);
int ow = output->width(), oh = output->height(), oz = UP_DIV(output->channel(), 4);
int ob = output->batch();
auto pad = ConvolutionCommon::convolutionTransposePad(input, output, mOp->main_as_Convolution2D()->common());
const int padX = pad.first;
const int padY = pad.second;
// const buffer
auto param = (deconv_constants*)mConstBuffer.contents;
param->input_width = iw;
param->input_height = ih;
param->input_size = iw * ih;
param->input_slice = iz;
param->output_width = ow;
param->output_height = oh;
param->output_size = ow * oh;
param->output_slice = oz;
param->batch = ob;
param->pad_x = padX;
param->pad_y = padY;
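// One thread per output element: x/y span the spatial output, z packs channel slices and batch.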
mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger) ow, (NSUInteger)oh, (NSUInteger)oz * ob)];
return NO_ERROR;
}
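// Buffer indices must line up with the deconv/deconv_depthwise shader signatures:
// 0 input, 1 output, 2 constants, 3 weights, 4 bias.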
void MetalDeconvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
auto input = inputs[0], output = outputs[0];
[encoder setComputePipelineState:mPipeline];
MetalBackend::setTensor(input, encoder, 0);
MetalBackend::setTensor(output, encoder, 1);
[encoder setBuffer:mConstBuffer offset:0 atIndex:2];
MetalBackend::setTensor(mWeight.get(), encoder, 3);
MetalBackend::setTensor(mBias.get(), encoder, 4);
[encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
class MetalDeconvolutionCreator : public MetalBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
if (inputs.size() > 1) {
MNN_PRINT("multi input deconv for metal not supoort!\n");
return nullptr;
}
return new MetalDeconvolution(backend, op);
}
};
REGISTER_METAL_OP_CREATOR(MetalDeconvolutionCreator, OpType_Deconvolution);
REGISTER_METAL_OP_CREATOR(MetalDeconvolutionCreator, OpType_DeconvolutionDepthwise);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */