source/backend/metal/MetalBinary.mm (126 lines of code) (raw):

//
//  MetalBinary.mm
//  MNN
//
//  Created by MNN on 2019/01/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalBinary.hpp"
#import "backend/metal/MNNMetalContext.h"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"

#if MNN_METAL_ENABLED
namespace MNN {

/// Execution for elementwise binary ops on Metal.
/// Holds the pre-built compute pipeline and a small constant buffer carrying
/// per-resize parameters (input strides, element count, activation flag).
MetalBinary::MetalBinary(Backend *backend, id<MTLComputePipelineState> pipeline, int activationType)
    : MetalExecution(backend) {
    auto mtbn    = static_cast<MetalBackend *>(backend);
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    // 4 ints: in0 stride, in1 stride, output element count, activation type.
    mConstBuffer     = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
    mPipeline        = pipeline;
    mActivationType  = activationType;
}

ErrorCode MetalBinary::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];

    const int input0_data_count = TensorUtils::getRawSize(input0);
    const int input1_data_count = TensorUtils::getRawSize(input1);

    int outdatacount = output->elementSize();
    // Stride trick for scalar broadcast: a one-element input is indexed with
    // stride 0 in the kernel (always reads element 0), otherwise stride 1.
    ((int *)mConstBuffer.contents)[0] = input0_data_count == 1 ? 0 : 1;
    ((int *)mConstBuffer.contents)[1] = input1_data_count == 1 ? 0 : 1;
    ((int *)mConstBuffer.contents)[2] = outdatacount;
    ((int *)mConstBuffer.contents)[3] = mActivationType;
    // One thread per output element; let the context pick group/local sizes.
    mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(outdatacount, 1, 1)];
    return NO_ERROR;
}

void MetalBinary::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                           id<MTLComputeCommandEncoder> encoder) {
    auto input0 = inputs[0], input1 = inputs[1], output = outputs[0];
    [encoder setComputePipelineState:mPipeline];
    MetalBackend::setTensor(input0, encoder, 0);
    MetalBackend::setTensor(input1, encoder, 1);
    MetalBackend::setTensor(output, encoder, 2);
    [encoder setBuffer:mConstBuffer offset:0 atIndex:3];
    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}

// Maps a BinaryOpOperation to the Metal expression substituted for CUSTOM in
// gBinaryTemplate below (V0/V1 are the two loaded operands). Local helper
// macro; #undef'd right after use so it cannot collide with other CHECK
// macros (e.g. glog/gtest) elsewhere in this translation unit.
#define CHECK(t, i) if (originOp == t) return i;
NSString* MetalBinary::convert(int originOp, bool inputFloat) {
    if (BinaryOpOperation_MOD == originOp) {
        if (inputFloat) {
            return @"fmod((float)V0,(float)V1)";
        }
        /**
         auto res = x % y;
         if ((res < 0 && y > 0) || (res > 0 && y < 0)) {
            res += y;
         }
         */
        return @"select(V0%V1,(V0%V1)+V1,(V0%V1<0&&V1>0)||(V0%V1>0&&V1<0))";
    }
    CHECK(BinaryOpOperation_ADD, @"V0+V1");
    CHECK(BinaryOpOperation_ATAN2, @"atan2(V0,V1)");
    CHECK(BinaryOpOperation_SUB, @"V0-V1");
    CHECK(BinaryOpOperation_MUL, @"V0*V1");
    CHECK(BinaryOpOperation_FLOORMOD, @"V0-floor((float)V0/(float)V1)*V1");
    CHECK(BinaryOpOperation_FLOORDIV, @"floor((float)V0/(float)V1)");
    CHECK(BinaryOpOperation_MINIMUM, @"min(V0,V1)");
    CHECK(BinaryOpOperation_MAXIMUM, @"max(V0,V1)");
    CHECK(BinaryOpOperation_DIV, @"V1==0?0:V0/V1");
    CHECK(BinaryOpOperation_REALDIV, @"V1==0?0:V0/V1");
    CHECK(BinaryOpOperation_POW, @"pow(V0,V1)");
    CHECK(BinaryOpOperation_SquaredDifference, @"(V0-V1)*(V0-V1)");
    CHECK(BinaryOpOperation_EQUAL, @"(V0==V1)?1:0");
    CHECK(BinaryOpOperation_LESS, @"(V0<V1)?1:0");
    CHECK(BinaryOpOperation_LESS_EQUAL, @"(V0<=V1)?1:0");
    CHECK(BinaryOpOperation_GREATER, @"(V0>V1)?1:0");
    CHECK(BinaryOpOperation_GREATER_EQUAL, @"(V0>=V1)?1:0");
    CHECK(BinaryOpOperation_NOTEQUAL, @"(V0!=V1)?1:0");
    CHECK(BinaryOpOperation_LOGICALOR, @"V0||V1");
    CHECK(BinaryOpOperation_BITWISE_AND, @"V0&V1");
    CHECK(BinaryOpOperation_BITWISE_OR, @"V0|V1");
    CHECK(BinaryOpOperation_BITWISE_XOR, @"V0^V1");
    // Unsupported op: caller reports the error and declines creation.
    return nil;
}
#undef CHECK

// Shader template. T0/T1/T2 (operand/output scalar types) and CUSTOM (the
// expression from convert()) are injected via preprocessor macros at compile
// time. s = {in0 stride, in1 stride, element count, activation flag}; when
// s.w == 1 a ReLU clamp is applied to the result.
static const char* gBinaryTemplate = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
kernel void main0(const device T0 *in0 [[buffer(0)]],
                  const device T1 *in1 [[buffer(1)]],
                  device T2 *out [[buffer(2)]],
                  constant int4& s [[buffer(3)]],
                  uint gid [[thread_position_in_grid]]) {
    if ((int)gid >= s.z) return;
    auto V0 = in0[s.x * int(gid)];
    auto V1 = in1[s.y * int(gid)];
    auto val = CUSTOM;
    if(s.w == 1) {
        val = (val < (T2)0 ? (T2)0 : val);
    }
    out[int(gid)] = val;
}
)metal";

class MetalBinaryCreator : public MetalBackend::Creator {
public:
    /// Builds (or fetches from the runtime cache) the specialized pipeline for
    /// this op's operand/output types and operation, then wraps it in a
    /// MetalBinary execution. Returns nullptr when the op is unsupported or
    /// shader compilation fails.
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend,
                                const std::vector<Tensor *>& outputs) const {
        auto binaryop = op->main_as_BinaryOp();
        auto mtbn     = static_cast<MetalBackend *>(backend);
        NSString* T2 = MetalCast::getScalarType(outputs[0]->getType(), mtbn->useFp16InsteadFp32());
        NSString* T0 = MetalCast::getScalarType(inputs[0]->getType(), mtbn->useFp16InsteadFp32());
        NSString* T1 = MetalCast::getScalarType(inputs[1]->getType(), mtbn->useFp16InsteadFp32());
        // Cache key: every macro value that specializes gBinaryTemplate.
        std::vector<std::string> keys = {
            std::string([T0 UTF8String]),
            std::string([T1 UTF8String]),
            std::string([T2 UTF8String]),
            std::to_string(binaryop->opType()),
            "binary"
        };
        auto pipeline = mtbn->runtime()->findPipeline(keys);
        if (nil == pipeline) {
            MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
            NSString* custom = MetalBinary::convert(binaryop->opType(), inputs[0]->getType().code == halide_type_float);
            if (nil == custom) {
                MNN_ERROR("Metal Don't support binary - %d \n", binaryop->opType());
                return nullptr;
            }
            compileOptions.preprocessorMacros = @{
                @"T0" : T0,
                @"T1" : T1,
                @"T2" : T2,
                @"CUSTOM" : custom,
            };
            pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryTemplate, "main0", compileOptions);
            mtbn->runtime()->insertPipeline(keys, pipeline);
        }
        if (nil == pipeline) {
            MNN_ERROR("Make Binary shader error\n");
            return nullptr;
        }
        return new MetalBinary(backend, pipeline, binaryop->activationType());
    }
};
REGISTER_METAL_OP_CREATOR(MetalBinaryCreator, OpType_BinaryOp);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */