//
// MetalScale.mm
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MetalScale.hpp"
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
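
// Pack the Scale op's per-channel parameters into one static GPU buffer:
// channel4 floats of scale followed by channel4 floats of bias, where channel4
// is the channel count rounded up to a multiple of 4 and the zero-initialized
// tail entries stay untouched. mBiasOffset (= channel4 / 4) is passed to the
// shader through the constant buffer so the kernel can locate the bias section.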
MetalScale::MetalScale(Backend *backend, const Scale *scale) : MetalExecution(backend) {
    auto mtbn        = static_cast<MetalBackend *>(backend);
    auto bufferAlloc = mtbn->getStaticBufferPool();
    auto context     = (__bridge MNNMetalContext *)mtbn->context();
    auto channel4    = UP_DIV(scale->channels(), 4) * 4;
    mBiasOffset      = channel4 / 4;
    mScaleBias       = bufferAlloc->alloc(2 * channel4 * sizeof(float));
    if (mScaleBias.first == nullptr) {
        mValid = false;
        return;
    }
    auto scalePtr = MetalBackend::getMemPtr(mScaleBias);
    ::memset(scalePtr, 0, 2 * channel4 * sizeof(float));
    ::memcpy(scalePtr, scale->scaleData()->data(), scale->channels() * sizeof(float));
    auto biasPtr = scalePtr + channel4 * sizeof(float);
    if (nullptr != scale->biasData()) {
        ::memcpy(biasPtr, scale->biasData()->data(), scale->channels() * sizeof(float));
    }
    mConst    = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
    mPipeline = [context pipelineWithName:@"scale_ca" fp16:mtbn->useFp16InsteadFp32()];
}

MetalScale::~MetalScale() {
    auto mtbn        = static_cast<MetalBackend *>(backend());
    auto bufferAlloc = mtbn->getStaticBufferPool();
    if (nullptr != mScaleBias.first) {
        bufferAlloc->free(mScaleBias);
    }
}
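
// Write the output geometry into the constant buffer (plane size w * h, channel
// blocks z = UP_DIV(c, 4), batch, and the bias offset), then derive the
// threadgroup configuration for a (w * h) x (z * batch) dispatch grid.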
ErrorCode MetalScale::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto output  = outputs[0];
    // shape
    int w = output->width();
    int h = output->height();
    int c = output->channel();
    int z = UP_DIV(c, 4);
    ((int *)mConst.contents)[0] = w * h;
    ((int *)mConst.contents)[1] = z;
    ((int *)mConst.contents)[2] = output->batch();
    ((int *)mConst.contents)[3] = mBiasOffset;
    mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(w * h, z * output->batch(), 1)];
    return NO_ERROR;
}
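
// Bind the input (0), output (1), shape constants (2) and packed scale/bias
// buffer (3), then dispatch with the sizes chosen in onResize.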
void MetalScale::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto input = inputs[0], output = outputs[0];
    [encoder setComputePipelineState:mPipeline];
    MetalBackend::setTensor(input, encoder, 0);
    MetalBackend::setTensor(output, encoder, 1);
    [encoder setBuffer:mConst offset:0 atIndex:2];
    MetalBackend::setMem(mScaleBias, encoder, 3);
    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}

class MetalScaleCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend,
                                const std::vector<Tensor *> &outputs) const {
        return new MetalScale(backend, op->main_as_Scale());
    }
};
REGISTER_METAL_OP_CREATOR(MetalScaleCreator, OpType_Scale);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */