source/backend/metal/MetalUnary.mm (163 lines of code) (raw):

//
//  MetalUnary.mm
//  MNN
//
//  Created by MNN on 2019/01/30.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "MetalCast.hpp"
#import "core/Macro.h"
#import "MNNMetalContext.h"
#import "MetalBackend.hpp"
#import "MetalUnary.hpp"

#if MNN_METAL_ENABLED
namespace MNN {

// Metal source template shared by every unary op.
// - `T` (the vector element type) and `FUNC` (the per-element function) are
//   injected as preprocessor macros by MetalUnaryCreator at compile time.
// - Each thread handles one vector: the input is read as T, widened to float4,
//   passed through FUNC, then narrowed back to T.
// - `unary_shape` mirrors the three ints written by MetalUnary::onResize, in the
//   order width, height, size.
// NOTE: nothing may be added inside the raw string below without changing the
// generated shader. This includes comments, because the kernel uses
// backslash-continued lines.
static const char* gUnaryTemplate = R"metal(
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct unary_shape {
    int width;
    int height;
    int size;
};
static inline float4 MNNEXP(float4 tmp) {
    tmp = clamp(tmp, (float4)-87.0, (float4)87.0);
    return exp(tmp);
}
static inline float4 MNNTANH(float4 value) {
    float4 tmp = MNNEXP((float4)(2.0)*value);
    return (tmp-(float4)1.0)/(tmp+(float4)1.0);
}
static inline float4 neg(float4 value) { return -value; }
static inline float4 square(float4 value) { return value * value; }
static inline float4 expm1(float4 value) {return MNNEXP(value) - 1;}
static inline float4 reciprocal(float4 value) {return 1.0/(value);}
static inline float4 sigmoid(float4 value) {return 1.f / (1.f + MNNEXP(-value));}
static inline float4 silu(float4 value) {return value / (1.f + MNNEXP(-value));}
static inline float4 log1p(float4 value) {return log(1.f + value);}
static inline float4 hardswish(float4 value) {
    return (float4)(1.0/6.0) * (value * min(max(value+(float4)3, 0), (float4)6));
}
static inline float4 gelu(float4 value) {
    float4 temp = (float4)0.044715 * value * value * value;
    temp = (float4)0.79788458 * (temp + value);
    temp = clamp(temp, (float4)-5.0, (float4)5.0);
    float4 result = ((float4)1.0 + MNNTANH(temp)) * value * (float4)0.5;
    return result;
}
kernel void main0(const device T *in [[buffer(0)]], \
                  device T *out [[buffer(1)]], \
                  device unary_shape& s [[buffer(2)]], \
                  uint3 gid [[thread_position_in_grid]]) { \
    if (gid.x < (uint)s.width) { \
        int off = gid.z * s.size + gid.y * s.width + gid.x; \
        out[off] = (T)(FUNC((float4)(in[off]))); \
    } \
}
)metal";

/// Maps a UnaryOpOperation to the name of the shader function that implements it.
/// The name is either a Metal builtin or one of the helpers in gUnaryTemplate.
/// @param type  The unary operation requested by the op.
/// @return The function name used as the `FUNC` macro, or nil if unsupported.
///         For unsupported types the enum name is also printed.
static NSString *kernelForType(UnaryOpOperation type) {
#define op_case(type, imp)        \
    case UnaryOpOperation_##type: \
        return @#imp
    switch (type) {
        op_case(ABS, abs);
        op_case(NEG, neg);
        op_case(FLOOR, floor);
        op_case(CEIL, ceil);
        op_case(ROUND, round);
        op_case(SQUARE, square);
        op_case(SQRT, sqrt);
        op_case(RSQRT, rsqrt);
        op_case(EXP, MNNEXP);
        op_case(EXPM1, expm1);
        op_case(LOG, log);
        op_case(SIN, sin);
        op_case(COS, cos);
        op_case(TAN, tan);
        op_case(TANH, MNNTANH);
        op_case(SIGMOID, sigmoid);
        op_case(SILU, silu);
        op_case(ASIN, asin);
        op_case(ACOS, acos);
        op_case(ATAN, atan);
        op_case(SIGN, sign);
        op_case(RECIPROCAL, reciprocal);
        op_case(LOG1P, log1p);
        op_case(ACOSH, acosh);
        op_case(COSH, cosh);
        op_case(SINH, sinh);
        op_case(ASINH, asinh);
        op_case(ATANH, atanh);
        op_case(HARDSWISH, hardswish);
        op_case(GELU, gelu);
        // NOTE(review): GELU_STANDARD reuses the same tanh-approximate `gelu`
        // helper as GELU. If an exact (erf-based) form is expected, confirm
        // that this precision is acceptable.
        op_case(GELU_STANDARD, gelu);
        default:
            FUNC_PRINT_ALL(EnumNameUnaryOpOperation(type), s);
            return nil;
    }
// The macro is only meaningful inside the switch above; keep it out of the rest of the TU.
#undef op_case
}

/// Creates a unary execution bound to an already-compiled pipeline.
/// Also allocates the 3-int uniform buffer that backs `unary_shape`.
MetalUnary::MetalUnary(Backend *backend, id<MTLComputePipelineState> pipeline) : MetalExecution(backend) {
    auto mtbn    = static_cast<MetalBackend *>(backend);
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    mConstBuffer = [context newDeviceBuffer:3 * sizeof(int) access:CPUWriteOnly];
    mPipeline    = pipeline;
}

/// Fills the shape uniform and picks the dispatch size for the current input.
/// The kernel works on vectors of 4 elements, so the work size is
/// ceil(elementSize / 4) laid out as a 1-D grid.
ErrorCode MetalUnary::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto mtbn     = static_cast<MetalBackend *>(backend());
    auto context  = (__bridge MNNMetalContext *)mtbn->context();
    auto input    = inputs[0];
    auto element  = input->elementSize();
    auto sizeDiv4 = UP_DIV(element, 4);

    // Write every field of `unary_shape` (width, height, size).
    // The kernel reads `s.size` when computing its offset, and the constructor
    // never initializes this buffer. The grid is 1-D, so height is 1 and
    // size == width * height.
    auto shape = (int *)mConstBuffer.contents;
    shape[0]   = sizeDiv4; // width
    shape[1]   = 1;        // height
    shape[2]   = sizeDiv4; // size (width * height)

    mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake(sizeDiv4, 1, 1)];
    return NO_ERROR;
}

/// Binds input (buffer 0), output (buffer 1) and the shape uniform (buffer 2),
/// then dispatches the threadgroups computed in onResize.
void MetalUnary::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    auto input = inputs[0], output = outputs[0];
    [encoder setComputePipelineState:mPipeline];
    MetalBackend::setTensor(input, encoder, 0);
    MetalBackend::setTensor(output, encoder, 1);
    [encoder setBuffer:mConstBuffer offset:0 atIndex:2];
    [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second];
}

/// Builds MetalUnary executions for OpType_UnaryOp, OpType_TanH and OpType_Sigmoid.
/// Pipelines are specialized per (vector type, function name) and cached in the
/// runtime under the key {T, FUNC, "unary"}.
class MetalUnaryCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *>& outputs) const {
        // TanH and Sigmoid have dedicated op types; everything else carries a UnaryOp payload.
        UnaryOpOperation optype;
        if (op->type() == OpType_TanH) {
            optype = UnaryOpOperation_TANH;
        } else if (op->type() == OpType_Sigmoid) {
            optype = UnaryOpOperation_SIGMOID;
        } else {
            optype = op->main_as_UnaryOp()->opType();
        }
        // The erf family has no shader implementation here; returning nullptr declines the op.
        if (UnaryOpOperation_ERF == optype || UnaryOpOperation_ERFC == optype || UnaryOpOperation_ERFINV == optype) {
            return nullptr;
        }
        auto kernel = kernelForType(optype);
        if (nil == kernel) {
            return nullptr;
        }
        auto mtbn   = static_cast<MetalBackend *>(backend);
        NSString* T = MetalCast::getVecType(outputs[0]->getType(), mtbn->useFp16InsteadFp32());
        std::vector<std::string> keys = {
            std::string([T UTF8String]),
            std::string([kernel UTF8String]),
            "unary"
        };
        auto pipeline = mtbn->runtime()->findPipeline(keys);
        if (nil == pipeline) {
            // Cache miss: specialize the template by injecting T and FUNC as macros.
            MTLCompileOptions *compileOptions = [[MTLCompileOptions alloc] init];
            compileOptions.preprocessorMacros = @{
                @"T" : T,
                @"FUNC" : kernel,
            };
            pipeline = mtbn->makeComputePipelineWithSourceOption(gUnaryTemplate, "main0", compileOptions);
            mtbn->runtime()->insertPipeline(keys, pipeline);
        }
        if (nil == pipeline) {
            MNN_ERROR("Make Unary shader error\n");
            return nullptr;
        }
        return new MetalUnary(backend, pipeline);
    }
};
REGISTER_METAL_OP_CREATOR(MetalUnaryCreator, OpType_UnaryOp);
REGISTER_METAL_OP_CREATOR(MetalUnaryCreator, OpType_TanH);
REGISTER_METAL_OP_CREATOR(MetalUnaryCreator, OpType_Sigmoid);
} // namespace MNN
#endif /* MNN_METAL_ENABLED */