//
// MetalMatMul.mm
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MetalMatMul.hpp"
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#if MNN_METAL_ENABLED
namespace MNN {
// Shape and stride constants shared with the matmul kernels through mConstBuffer.
struct matP {
    int size[4];
    int stride[4];
};
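/*
 A minimal sketch (an assumption for illustration, not the shipped MNN shader) of
 how a kernel consuming these constants would index the buffers. x spans the
 size[0] = h output columns, y the size[1] = e output rows, and size[2] = l is the
 reduction length; the strides are filled in by onResize below so that transposed
 inputs need no separate copy:

     kernel void matmul(const device ftype *A [[buffer(0)]],
                        const device ftype *B [[buffer(1)]],
                        device ftype *C       [[buffer(2)]],
                        constant matP &cst    [[buffer(3)]],
                        uint2 gid [[thread_position_in_grid]]) {
         if ((int)gid.x >= cst.size[0] || (int)gid.y >= cst.size[1]) return;
         float acc = 0;
         for (int k = 0; k < cst.size[2]; ++k) {
             acc += A[gid.y * cst.stride[0] + k * cst.stride[1]]
                  * B[k * cst.stride[3] + gid.x * cst.stride[2]];
         }
         C[gid.y * cst.size[0] + gid.x] = acc;
     }
*/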
MetalMatMul::MetalMatMul(Backend *backend, const MatMul *matmul, bool withBias) : MetalExecution(backend) {
    mTransposeA = matmul->transposeA();
    mTransposeB = matmul->transposeB();
    auto mkbn = static_cast<MetalBackend *>(backend);
    mConstBuffer = mkbn->getConstBuffer(sizeof(matP));
    auto context = (__bridge MNNMetalContext *)mkbn->context();
    // Select the kernel variant up front: the bias kernel takes one extra input.
    if (withBias) {
        mPipeline = [context pipelineWithName:@"matmul_bias" fp16:mkbn->useFp16InsteadFp32()];
    } else {
        mPipeline = [context pipelineWithName:@"matmul" fp16:mkbn->useFp16InsteadFp32()];
    }
}
MetalMatMul::~MetalMatMul() {
    // Hand the constant buffer back to the backend for reuse.
    auto mkbn = static_cast<MetalBackend *>(backend());
    mkbn->returnConstBuffer(mConstBuffer);
}
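// onResize derives the matmul geometry from the tensor shapes and uploads it to
// mConstBuffer once per resize, so onEncode only has to bind buffers and dispatch.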
ErrorCode MetalMatMul::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    Tensor* C = outputs[0];
    auto w0 = inputs[0]->length(1);
    auto h0 = inputs[0]->length(0);
    // C is e x h; l is the shared (reduction) dimension, read from whichever
    // axis of A survives the optional transpose.
    auto e = C->length(0);
    auto h = C->length(1);
    auto l = w0;
    if (mTransposeA) {
        l = h0;
    }
    // Zero-initialize so the unused size[3] slot is not copied to the GPU as garbage.
    matP buffer = {};
    buffer.size[0] = h;
    buffer.size[1] = e;
    buffer.size[2] = l;
    // A's strides: element (y, k) resolves to y * stride[0] + k * stride[1].
    if (mTransposeA) {
        buffer.stride[0] = 1;
        buffer.stride[1] = e;
    } else {
        buffer.stride[0] = l;
        buffer.stride[1] = 1;
    }
    // B's strides: element (k, x) resolves to k * stride[3] + x * stride[2].
    if (mTransposeB) {
        buffer.stride[2] = l;
        buffer.stride[3] = 1;
    } else {
        buffer.stride[2] = 1;
        buffer.stride[3] = h;
    }
    ::memcpy(mConstBuffer.contents, &buffer, sizeof(matP));
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    // One thread per output element: x spans the h columns, y the e rows.
    mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(h, e, 1)];
    return NO_ERROR;
}
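// The bind order below must match the kernel's argument table: A and B first,
// the optional bias next, then the output, then the matP constants.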
void MetalMatMul::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];
    [encoder setComputePipelineState:mPipeline];
    MetalBackend::setTensor(input0, encoder, 0);
    MetalBackend::setTensor(input1, encoder, 1);
    if (inputs.size() > 2) {
        // Bias variant: bias binds at index 2, shifting output and constants up one slot.
        MetalBackend::setTensor(inputs[2], encoder, 2);
        MetalBackend::setTensor(output, encoder, 3);
        [encoder setBuffer:mConstBuffer offset:0 atIndex:4];
    } else {
        MetalBackend::setTensor(output, encoder, 2);
        [encoder setBuffer:mConstBuffer offset:0 atIndex:3];
    }
    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}
class MetalMatMulCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op,
                                Backend *backend, const std::vector<Tensor *>& outputs) const override {
        if (inputs.size() < 2) {
            MNN_PRINT("Metal does not support MatMul with fewer than 2 inputs\n");
            return nullptr;
        }
        // A third input, if present, is treated as the bias.
        return new MetalMatMul(backend, op->main_as_MatMul(), inputs.size() > 2);
    }
};
REGISTER_METAL_OP_CREATOR(MetalMatMulCreator, OpType_MatMul);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */