source/backend/metal/MNNMetalContext.mm (423 lines of code) (raw):
//
// MNNMetalContext.mm
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#import "backend/metal/MNNMetalContext.h"
#import "backend/metal/MetalBackend.hpp"
#import "core/Macro.h"
#import <sys/utsname.h>
#import <mach/mach_time.h>
//#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#if MNN_METAL_ENABLED
#include "ShaderMap.hpp"
#include <sstream>
using namespace MNN;
// Class extension for MNNMetalContext: storage for the device handle, the
// platform flag, and the two pipeline caches filled in -initWithSharedContext:dev:.
@interface MNNMetalContext ()
// public
// Metal device used for every buffer, library and pipeline created by this context.
@property (strong, nonatomic) id<MTLDevice> device;
// Cached result of +isIphone, captured once in -initWithSharedContext:dev:.
@property (assign, nonatomic) BOOL isIphone;
// private
// Compute pipelines keyed by kernel function name, built with MNN_METAL_FULL_PRECISION 1.
@property (strong, nonatomic) NSDictionary<NSString *, id<MTLComputePipelineState>> *cachesFp32;
// Compute pipelines keyed by kernel function name, built with MNN_METAL_FULL_PRECISION 0.
@property (strong, nonatomic) NSDictionary<NSString *, id<MTLComputePipelineState>> *cachesFp16;
@end
@implementation MNNMetalContext
/**
 * Compiles every built-in shader source and registers one compute pipeline
 * state per kernel function in `libraryMap`, keyed by function name.
 *
 * Each shader is compiled with a precision define followed by the shared
 * "shader_MetalDefine_metal" and "shader_MetalConvolutionActivation_metal"
 * sources; those two entries are only used as a prefix and are not compiled
 * on their own.
 *
 * @param device     Metal device used to compile libraries and build pipelines.
 * @param libraryMap Output dictionary. It is emptied if any shader fails to compile.
 * @param usefp16    When true the sources are compiled with MNN_METAL_FULL_PRECISION 0,
 *                   otherwise with MNN_METAL_FULL_PRECISION 1.
 */
static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *, id<MTLComputePipelineState>>* libraryMap, bool usefp16) {
    AUTOTIME;
    ShaderMap shader;
    auto first = shader.search("shader_MetalDefine_metal");
    auto second = shader.search("shader_MetalConvolutionActivation_metal");
    auto& raw = shader.getRaw();
    for (auto& iter : raw) {
        // The two shared sources are prepended to every other shader below.
        if (iter.first == "shader_MetalDefine_metal" ||
            iter.first == "shader_MetalConvolutionActivation_metal") {
            continue;
        }
        std::ostringstream total;
        total << "#define MNN_METAL_FULL_PRECISION " << (usefp16 ? 0 : 1) << "\n";
        total << first << "\n" << second << "\n" << iter.second;
        auto totalString = total.str();
        auto totalNSString = [[NSString alloc] initWithUTF8String:totalString.c_str()];

        NSError *err = nil;
        id<MTLLibrary> library = [device newLibraryWithSource:totalNSString options:nil error:&err];
        if (nil == library) {
            if (err) {
                printf("Error Key = %s\n", iter.first.c_str());
                NSLog(@"Warning: Metallib Library error: %@", err);
            }
            // A single broken shader invalidates the whole cache: leave the map empty.
            [libraryMap removeAllObjects];
            return;
        }

        for (NSString *functionName in [library functionNames]) {
            id<MTLFunction> function = [library newFunctionWithName:functionName];
            if (!function) {
                MNN_ERROR("Create Function in metal error\n");
                continue;
            }
            NSError *error = nil;
            id<MTLComputePipelineState> pipeline = [device newComputePipelineStateWithFunction:function error:&error];
            if (nil == pipeline) {
                // Never store nil in the dictionary; report the failure and make sure
                // no stale entry with the same name stays visible.
                NSLog(@"Warning: Create pipeline failed for %@: %@", functionName, error);
                [libraryMap removeObjectForKey:functionName];
                continue;
            }
            libraryMap[functionName] = pipeline;
        }
    }
}
/// Returns YES when the machine identifier reported by uname() contains "iPhone".
+ (BOOL)isIphone{
    struct utsname info;
    uname(&info);
    NSString *machine = [NSString stringWithCString:info.machine encoding:NSASCIIStringEncoding];
    NSRange match = [machine rangeOfString:@"iPhone"];
    return (match.location != NSNotFound) ? YES : NO;
}
/**
 * Binds this context to the device of a shared Metal context and builds the
 * FP16 and FP32 pipeline caches from the built-in shaders.
 *
 * @param context Shared context providing the Metal device. Must not be null.
 * @param device  NOTE(review): currently unused; the device always comes from `context`.
 * @return YES when a device is available, NO otherwise.
 */
- (BOOL) initWithSharedContext:(const MNNMetalSharedContext*)context dev:(id<MTLDevice>)device {
    MNN_ASSERT(nullptr != context);
    // Guard explicitly so a null context fails gracefully even if the assertion is compiled out.
    if (nullptr == context) {
        MNN_ERROR("MNNMetalContext: shared context is null\n");
        return NO;
    }
    _device = context->device;
    _isIphone = [[self class] isIphone];
    if (nil == _device) {
        // Without a device nothing can be compiled; keep the caches empty and report failure.
        _cachesFp16 = @{};
        _cachesFp32 = @{};
        return NO;
    }

    NSMutableDictionary<NSString *, id<MTLComputePipelineState>> *fp16Pipelines = [NSMutableDictionary dictionary];
    NSMutableDictionary<NSString *, id<MTLComputePipelineState>> *fp32Pipelines = [NSMutableDictionary dictionary];
    createLibrary(_device, fp16Pipelines, true);
    createLibrary(_device, fp32Pipelines, false);
    // Freeze the caches: lookups after initialization are read-only.
    _cachesFp16 = [NSDictionary dictionaryWithDictionary:fp16Pipelines];
    _cachesFp32 = [NSDictionary dictionaryWithDictionary:fp32Pipelines];
    return YES;
}
/// Plain initializer; device and caches are set up later by -initWithSharedContext:dev:.
- (instancetype)init {
    return [super init];
}
#pragma mark device
/// Maps an MNN access mode to the Metal resource options used when creating buffers.
/// CPUWriteOnly -> write-combined CPU cache; CPUTransparent -> private storage when
/// available; everything else (CPUReadWrite) -> default CPU cache mode.
- (MTLResourceOptions)optionForAccess:(MNN::MetalAccess)access {
    switch (access) {
        case CPUWriteOnly:
            return MTLResourceOptionCPUCacheModeWriteCombined;
        case CPUTransparent: {
            if (@available(iOS 9.0, *)) {
                return MTLResourceStorageModePrivate;
            }
            return MTLResourceOptionCPUCacheModeDefault;
        }
        default: // access == CPUReadWrite
            return MTLResourceOptionCPUCacheModeDefault;
    }
}
/// Allocates an uninitialized device buffer of `size` bytes with the options matching `access`.
- (id<MTLBuffer>)newDeviceBuffer:(NSUInteger)size access:(MNN::MetalAccess)access {
    MTLResourceOptions resourceOptions = [self optionForAccess:access];
    id<MTLBuffer> allocated = [_device newBufferWithLength:size options:resourceOptions];
    return allocated;
}
/// Allocates a device buffer of `size` bytes initialized from `bytes`, with the options matching `access`.
- (id<MTLBuffer>)newDeviceBuffer:(NSUInteger)size bytes:(const void *)bytes access:(MNN::MetalAccess)access {
    MTLResourceOptions resourceOptions = [self optionForAccess:access];
    id<MTLBuffer> allocated = [_device newBufferWithBytes:bytes length:size options:resourceOptions];
    return allocated;
}
/// Looks up a prebuilt pipeline by kernel name in the FP16 or FP32 cache; nil when absent.
- (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name fp16:(BOOL)fp16 {
    NSDictionary<NSString *, id<MTLComputePipelineState>> *cache = fp16 ? _cachesFp16 : _cachesFp32;
    return cache[name];
}
/**
 * Compiles `source` at runtime and builds a compute pipeline for the function `name`.
 *
 * @param source  Metal shading language source code.
 * @param name    Name of the kernel function inside `source`.
 * @param options Compile options forwarded to the Metal compiler (may be nil).
 * @return The pipeline state, or nil if compilation, function lookup or pipeline
 *         creation fails. Every failure is logged.
 */
- (id<MTLComputePipelineState>)pipelineWithSourceOption:(NSString *)source name:(NSString *)name options:(MTLCompileOptions *)options {
    NSError *err = nil;
    auto library = [_device newLibraryWithSource:source options:options error:&err];
    // The error object is logged whenever it is set, even if a library was produced.
    if (err) {
        NSLog(@"Warning: pipelineWithSource error: %@, source is: %@", err, source);
    }
    if (nil == library) {
        return nil;
    }
    id<MTLFunction> function = [library newFunctionWithName:name];
    if (nil == function) {
        NSLog(@"Warning: Create function failed: %@", name);
        return nil;
    }
    err = nil;
    id<MTLComputePipelineState> result = [_device newComputePipelineStateWithFunction:function error:&err];
    if (nil == result) {
        // Previously this failure was silent and callers only saw a nil pipeline.
        NSLog(@"Warning: Create pipeline failed: %@, error: %@", name, err);
    }
    return result;
}
/**
 * Commits `buffer`, blocks until it completes and returns the elapsed wall-clock
 * time in microseconds.
 *
 * mach_absolute_time() counts in device-specific ticks, so the difference is
 * converted to nanoseconds with mach_timebase_info() before scaling to microseconds.
 */
- (NSUInteger)timeUsed:(id<MTLCommandBuffer>)buffer {
    // Tick -> nanosecond ratio; queried once, falls back to 1/1 if the query fails.
    static const mach_timebase_info_data_t timebase = [] {
        mach_timebase_info_data_t info;
        if (mach_timebase_info(&info) != KERN_SUCCESS || info.denom == 0) {
            info.numer = 1;
            info.denom = 1;
        }
        return info;
    }();

    auto start = mach_absolute_time();
    [buffer commit];
    [buffer waitUntilCompleted];
    // NSUInteger time = (NSUInteger)((buffer.GPUEndTime - buffer.GPUStartTime)* 1000000.f);//us
    auto end = mach_absolute_time();

    uint64_t elapsedNs = (end - start) * timebase.numer / timebase.denom;
    return (NSUInteger)(elapsedNs / 1000);
}
/// Creates a command buffer on `cmdqueue` labelled "<width>_<height>_<depth>" of `localIndex`.
- (id<MTLCommandBuffer>) newCmdBuffer:(MTLSize) localIndex queue:(id<MTLCommandQueue>) cmdqueue {
    id<MTLCommandBuffer> commandBuffer = [cmdqueue commandBuffer];
    commandBuffer.label = [NSString stringWithFormat:@"%d_%d_%d",
                                                     (int)localIndex.width,
                                                     (int)localIndex.height,
                                                     (int)localIndex.depth];
    return commandBuffer;
}
/**
 * Finds the tuned entry of `kernelName` whose global work size is closest to `gws`.
 *
 * Closeness is the sum of absolute per-dimension differences (L1 distance) over the
 * dimensions of `gws`. The first entry with the smallest distance wins.
 *
 * @param tuneMap    Kernel name -> list of (tuned gws, tuned result) pairs.
 * @param gws        Global work size to match.
 * @param kernelName Kernel to look up in `tuneMap`.
 * @param res        Receives the tuned result of the closest entry; untouched on failure.
 * @return true when an entry was found and written to `res`, false otherwise.
 */
bool getCloseThreadgroup(const std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>>>> &tuneMap, const std::vector<uint32_t> &gws, const std::string &kernelName, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>& res){
    auto iter = tuneMap.find(kernelName);
    if (iter == tuneMap.end()) {
        return false;
    }
    // Bind by reference: the candidate list can be large and is only read.
    const auto &candidates = iter->second;
    const size_t dims = gws.size();

    uint64_t minDistance = UINT64_MAX;
    size_t bestIndex = candidates.size(); // sentinel: nothing selected yet
    for (size_t i = 0; i < candidates.size(); ++i) {
        const auto &tunedGws = candidates[i].first;
        // An entry with fewer dimensions cannot be compared without reading out of bounds.
        if (tunedGws.size() < dims) {
            continue;
        }
        // 64-bit accumulator with unsigned subtraction: no signed overflow, no wrap-around.
        uint64_t distance = 0;
        for (size_t j = 0; j < dims; ++j) {
            distance += (gws[j] > tunedGws[j]) ? (gws[j] - tunedGws[j]) : (tunedGws[j] - gws[j]);
        }
        if (distance < minDistance) {
            minDistance = distance;
            bestIndex = i;
        }
    }
    if (bestIndex == candidates.size()) {
        return false;
    }
    res = candidates[bestIndex].second;
    return true;
}
/// Private helper: encodes `count` dispatches of `pipeline` with the given grid and
/// threadgroup size into a fresh command buffer, runs it, and returns the elapsed
/// time as reported by -timeUsed:.
- (NSUInteger)timeDispatchOfPipeline:(id<MTLComputePipelineState>)pipeline
                                grid:(MTLSize)grid
                               group:(MTLSize)group
                                loop:(NSUInteger)count
                              buffer:(NSArray *)buffers
                             offsets:(int *)offset_arr
                               queue:(id<MTLCommandQueue>)cmdqueue {
    id<MTLCommandBuffer> commandBuffer = [self newCmdBuffer:group queue:cmdqueue];
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    for (NSUInteger iteration = 0; iteration < count; iteration++) {
        [encoder setComputePipelineState:pipeline];
        for (NSUInteger idx = 0; idx < buffers.count; idx++) {
            [encoder setBuffer:[buffers objectAtIndex:idx] offset:offset_arr[idx] atIndex:idx];
        }
        [encoder dispatchThreadgroups:grid threadsPerThreadgroup:group];
    }
    [encoder endEncoding];
    return [self timeUsed:commandBuffer];
}

/**
 * Chooses the grid (threadgroup count) and threadgroup size for dispatching
 * `pipeline` over `threads`, auto-tuning when the runtime allows it.
 *
 * Lookup order:
 *  1. Exact hit in the runtime's tuned cache for (kernelName, threads): returned as stored.
 *  2. Closest tuned entry of the same kernel: its threadgroup size is reused and the
 *     grid is recomputed for the current `threads`, because the stored grid belongs
 *     to a different problem size.
 *  3. Otherwise start from -computeBestGroupAndLocal:threads: and, unless the tune
 *     level is Never or there are no buffers, benchmark power-of-4 threadgroup
 *     shapes and keep the fastest. The result is inserted into the tuned cache.
 *
 * @param pipeline   Pipeline to dispatch.
 * @param threads    Total thread counts per dimension.
 * @param count      Dispatches per benchmark run (forced to 50 at tune level Heavy).
 * @param buffers    Buffers bound at indices 0..n-1 while benchmarking.
 * @param rt         Runtime owning the tuned caches and the tune level.
 * @param kernelName Cache key for the kernel.
 * @param offset_arr Byte offset for each entry of `buffers`.
 * @param cmdqueue   Queue used for the benchmark command buffers.
 * @return (grid, threadgroup size, measured or cached time; UINT_MAX when nothing was measured).
 */
- (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MetalRuntime *) rt shaderName:(std::string) kernelName offsets:(int *) offset_arr queue:(id<MTLCommandQueue>) cmdqueue {
    const NSUInteger gid_x = threads.width;
    const NSUInteger gid_y = threads.height;
    const NSUInteger gid_z = threads.depth;

    auto& tunedThreadGroup = rt->getTunedThreadGroup();
    std::vector<uint32_t> gws = {(uint32_t)gid_x, (uint32_t)gid_y, (uint32_t)gid_z};
    std::pair<std::string, std::vector<uint32_t>> info = std::make_pair(kernelName, gws);

    // 1. Exact cache hit: the stored grid was computed for exactly this problem size.
    auto exactIter = tunedThreadGroup.find(info);
    if (exactIter != tunedThreadGroup.end()) {
        const auto &groupNum = std::get<0>(exactIter->second);
        const auto &groupSize = std::get<1>(exactIter->second);
        MTLSize _groupNum = {(NSUInteger)groupNum[0], (NSUInteger)groupNum[1], (NSUInteger)groupNum[2]};
        MTLSize _groupSize = {(NSUInteger)groupSize[0], (NSUInteger)groupSize[1], (NSUInteger)groupSize[2]};
        return std::make_tuple(_groupNum, _groupSize, (NSUInteger)std::get<2>(exactIter->second));
    }

    // 2. Close hit: reuse only the threadgroup size; the grid must cover *our* threads.
    std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t> closeTuned;
    if (getCloseThreadgroup(rt->getTunedThreadGroupVec(), gws, kernelName, closeTuned)) {
        const auto &groupSize = std::get<1>(closeTuned);
        // Clamp to 1 so a malformed cache entry cannot cause a division by zero.
        MTLSize _groupSize = {ALIMAX((NSUInteger)groupSize[0], (NSUInteger)1),
                              ALIMAX((NSUInteger)groupSize[1], (NSUInteger)1),
                              ALIMAX((NSUInteger)groupSize[2], (NSUInteger)1)};
        MTLSize _groupNum = {UP_DIV(gid_x, _groupSize.width),
                             UP_DIV(gid_y, _groupSize.height),
                             UP_DIV(gid_z, _groupSize.depth)};
        return std::make_tuple(_groupNum, _groupSize, (NSUInteger)std::get<2>(closeTuned));
    }

    // 3. Heuristic starting point, optionally refined by benchmarking.
    std::pair<MTLSize, MTLSize> thread = [self computeBestGroupAndLocal:pipeline threads:threads]; // Grid and ThreadGroup
    const auto tuneLevel = rt->getTuneLevel();
    if (tuneLevel == Heavy) {
        count = 50;
    }

    NSUInteger min_time = UINT_MAX;
    if (tuneLevel != Never && buffers.count > 0) {
        // Baseline: time the heuristic choice first.
        MNN_ASSERT(thread.second.width >= 1);
        MNN_ASSERT(thread.second.height >= 1);
        MNN_ASSERT(thread.second.depth >= 1);
        min_time = [self timeDispatchOfPipeline:pipeline
                                           grid:thread.first
                                          group:thread.second
                                           loop:count
                                         buffer:buffers
                                        offsets:offset_arr
                                          queue:cmdqueue];

        // Search-space limits per tune level. Normal uses the defaults below.
        NSUInteger magic_l = 1;
        NSUInteger magic_z = 16;
        NSUInteger magic_y = 4;
        NSUInteger magic_x = 4;
        if (tuneLevel == Heavy) {
            magic_l = 2;
            magic_z = UINT_MAX;
            magic_y = UINT_MAX;
            magic_x = UINT_MAX;
        } else if (tuneLevel == Wide) {
            // Slow kernels get the narrow search so tuning time stays bounded.
            const bool isMuchTime = min_time > 5000;
            magic_y = (isMuchTime ? 4 : 16);
            magic_x = (isMuchTime ? 4 : 16);
        } else if (tuneLevel == Fast) {
            magic_z = 4;
        }

        const NSUInteger maxThreads = pipeline.maxTotalThreadsPerThreadgroup;
        for (NSUInteger z = 1; z < gid_z * magic_l && z <= magic_z; z *= 4) {
            for (NSUInteger y = 1; y < gid_y * magic_l && y <= magic_y; y *= 4) {
                for (NSUInteger x = 1; x < gid_x * magic_l && x <= magic_x; x *= 4) {
                    if (x * y * z > maxThreads) {
                        continue;
                    }
                    if (x == 1 && y == 1 && z == 1) {
                        continue;
                    }
                    MTLSize local = {x, y, z};
                    MTLSize global = {UP_DIV(gid_x, x), UP_DIV(gid_y, y), UP_DIV(gid_z, z)};
                    NSUInteger time = [self timeDispatchOfPipeline:pipeline
                                                              grid:global
                                                             group:local
                                                              loop:count
                                                            buffer:buffers
                                                           offsets:offset_arr
                                                             queue:cmdqueue];
                    if (time < min_time) {
                        min_time = time;
                        thread.first = global;
                        thread.second = local;
                    }
                }
            }
        }
    }
    //MNN_PRINT("tune prit: %d us, %d %d %d\n", min_time, thread.second.width, thread.second.height, thread.second.depth);

    // Remember the decision so the next call with the same kernel and size is a cache hit.
    if (tunedThreadGroup.find(info) == tunedThreadGroup.end()) {
        std::vector<uint32_t> groupNum = {(uint32_t)thread.first.width, (uint32_t)thread.first.height, (uint32_t)thread.first.depth};
        std::vector<uint32_t> groupSize = {(uint32_t)thread.second.width, (uint32_t)thread.second.height, (uint32_t)thread.second.depth};
        tunedThreadGroup.insert(std::make_pair(info, std::make_tuple(groupNum, groupSize, (uint32_t)min_time)));
    }
    return std::make_tuple(thread.first, thread.second, min_time);
}
/// Encodes `count` dispatches of `pipeline` with grid `globals` and threadgroup size
/// `locals`, binding every buffer at offset 0, and returns the time reported by -timeUsed:.
- (NSUInteger)PipelinetimeUsed: (id<MTLComputePipelineState>)pipeline global:(MTLSize)globals local:(MTLSize)locals loop:(NSUInteger)count buffer:(NSArray *)buffers queue:(id<MTLCommandQueue>) cmdqueue{
    const MTLSize groupSize = MTLSizeMake(locals.width, locals.height, locals.depth);
    const MTLSize gridSize = MTLSizeMake(globals.width, globals.height, globals.depth);

    id<MTLCommandBuffer> commandBuffer = [self newCmdBuffer:groupSize queue:cmdqueue];
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    for (int remaining = (int)count; remaining != 0; remaining--) {
        [encoder setComputePipelineState:pipeline];
        NSUInteger slot = 0;
        for (id boundBuffer in buffers) {
            [encoder setBuffer:boundBuffer offset:0 atIndex:slot];
            slot++;
        }
        [encoder dispatchThreadgroups:gridSize threadsPerThreadgroup:groupSize];
    }
    [encoder endEncoding];
    return [self timeUsed:commandBuffer];
}
/// Returns the number of trailing zero bits of `integer`, i.e. the exponent of the
/// largest power of two that divides it. Returns 0 for an input of 0.
static NSUInteger smallest_log2(NSUInteger integer) {
    NSUInteger trailingZeros = 0;
    for (NSUInteger rest = integer; rest != 0 && (rest & 1) == 0; rest >>= 1) {
        ++trailingZeros;
    }
    return trailingZeros;
}
/// Returns (grid, threadgroup size) for `t` threads: the threadgroup size comes from
/// -computeBestGroup:threads: with every dimension clamped to at least 1, and the grid
/// is the per-dimension ceiling of threads / threadgroup size.
- (std::pair<MTLSize, MTLSize>)computeBestGroupAndLocal:(id<MTLComputePipelineState>) bw threads:(MTLSize)t {
    MTLSize group = [self computeBestGroup:bw threads:t];
    group.width = ALIMAX(group.width, 1);
    group.height = ALIMAX(group.height, 1);
    group.depth = ALIMAX(group.depth, 1);

    const NSUInteger gridX = UP_DIV(t.width, group.width);
    const NSUInteger gridY = UP_DIV(t.height, group.height);
    const NSUInteger gridZ = UP_DIV(t.depth, group.depth);
    return std::pair<MTLSize, MTLSize>(MTLSizeMake(gridX, gridY, gridZ), group);
}
// Heuristically picks a threadgroup size for `t` threads based on the pipeline's
// threadExecutionWidth. Branches are tried in order; the first match wins, so the
// later branches rely on the earlier conditions having failed.
// Naming: p* = trailing-zero count of a dimension (exponent of the largest power of
// two dividing it), s* = ceil(log2(dimension)).
// NOTE(review): log2() of a zero width/height is -infinity and its conversion to
// NSUInteger is not well defined — presumably callers never pass a zero dimension; confirm.
- (MTLSize)computeBestGroup:(id<MTLComputePipelineState>) bw threads:(MTLSize)t {
// Trailing-zero count of the execution width (its log2 when the width is a power of two).
auto pwarp = smallest_log2(bw.threadExecutionWidth);
auto px = smallest_log2(t.width), sx = (NSUInteger)ceil(log2(t.width));
auto py = smallest_log2(t.height), sy = (NSUInteger)ceil(log2(t.height));
// accurately match on x
if (px >= pwarp) {
return {bw.threadExecutionWidth, 1, 1};
}
// accurately match on xy
else if (px + py >= pwarp && sx < pwarp / 2) {
NSUInteger x = pow(2, px);
return {x, bw.threadExecutionWidth / x, 1};
}
// similarly match on x
else if (sx >= pwarp) {
return {bw.threadExecutionWidth, 1, 1};
}
// similarly match on xy
else if (sx + sy >= pwarp) {
NSUInteger x = pow(2, sx);
return {x, bw.threadExecutionWidth / x, 1};
}
// on xyz (for most shaders do not protect gid.z, z axis must be accurately match)
auto pz = smallest_log2(t.depth);
// For z the "similar" exponent is the exact one, per the comment above.
auto sz = pz;
if (px + py + pz >= pwarp) {
NSUInteger x = pow(2, px), y = pow(2, py);
return {x, y, bw.threadExecutionWidth / x / y};
} else if (sx + sy + sz >= pwarp) {
// z takes at most the exponent left over after x; the remainder goes to y.
NSUInteger x = pow(2, sx), z = pow(2, MIN(sz, pwarp - sx));
return {x, bw.threadExecutionWidth / x / z, z};
} else {
// Fewer threads than one execution width: use the x/y extents as they are.
NSUInteger z = pow(2, sz);
return {t.width, t.height, z};
}
}
#if MNN_METAL_DEBUG
#pragma mark debug
// Debug helper: prints `tensor` by calling its own print() method.
- (void)printTensor:(const Tensor *)tensor {
tensor->print();
}
/// Prints `bytes` bytes at `content` as elements of type T, four per row. Each row
/// starts with its row index; `fmt` is the printf format used for a single element.
template <typename T>
void printBuffer(const void *content, unsigned long bytes, const char *fmt) {
    const T *values = static_cast<const T *>(content);
    const unsigned long elementCount = bytes / sizeof(T);
    for (int idx = 0; idx < elementCount; idx++) {
        const bool startsRow = (idx % 4 == 0);
        const bool endsRow = ((idx + 1) % 4 == 0);
        if (startsRow) {
            printf("%3d > ", idx / 4);
        }
        printf(fmt, values[idx]);
        fputs(endsRow ? ",\n" : " ", stdout);
    }
}
/// Prints the contents of a halide buffer. Host memory is printed with the buffer's own
/// type; otherwise the Metal buffer behind `buffer.device` is printed, with float data
/// read as 16-bit values.
- (void)printBuffer:(halide_buffer_t)buffer {
    if (buffer.host) {
        NSUInteger byteCount = buffer.dim[0].stride * buffer.dim[0].extent * buffer.type.bytes();
        [self printBytes:buffer.host length:byteCount type:buffer.type.code bits:buffer.type.bits];
        return;
    }
    const bool isFloat = (buffer.type.code == halide_type_float);
    int elementBits = isFloat ? 16 : buffer.type.bits;
    [self printBuffer:(id<MTLBuffer>)((MetalRuntimeAllocator::MetalBufferAlloc *)buffer.device)->getBuffer()
                 type:buffer.type.code
                 bits:elementBits];
}
/// Prints the full contents of a Metal buffer, read as elements of the given type and bit width.
- (void)printBuffer:(id<MTLBuffer>)buffer type:(halide_type_code_t)type bits:(int)bits {
    const void *rawContents = buffer.contents;
    NSUInteger byteCount = buffer.length;
    [self printBytes:rawContents length:byteCount type:type bits:bits];
}
/// Prints `length` bytes at `bytes` as a table of elements chosen by `type` and `bits`.
/// Supported: signed/unsigned 8/16/32-bit integers, 16-bit half floats, and float for
/// any other float width. Unsupported combinations print nothing.
- (void)printBytes:(const void *)bytes length:(NSUInteger)length type:(halide_type_code_t)type bits:(int)bits {
    switch (type) {
        case halide_type_int:
            switch (bits) {
                case 8: // int8
                    printBuffer<int8_t>(bytes, length, "%3d");
                    break;
                case 16: // int16
                    printBuffer<int16_t>(bytes, length, "%d");
                    break;
                case 32: // int32
                    printBuffer<int32_t>(bytes, length, "%d");
                    break;
                default:
                    break;
            }
            break;
        case halide_type_uint:
            switch (bits) {
                case 8: // uint8
                    printBuffer<uint8_t>(bytes, length, "%3d");
                    break;
                case 16: // uint16
                    printBuffer<uint16_t>(bytes, length, "%d");
                    break;
                case 32: // uint32
                    printBuffer<uint32_t>(bytes, length, "%d");
                    break;
                default:
                    break;
            }
            break;
        case halide_type_float:
            if (bits == 16) { // half
                printBuffer<__fp16>(bytes, length, "%.4f");
            } else { // float
                printBuffer<float>(bytes, length, "%.4f");
            }
            break;
        default:
            break;
    }
}
/// Debug helper: logs the label of `encoder` to stdout.
- (void)printEncoder:(id<MTLCommandEncoder>)encoder {
    const char *encoderLabel = encoder.label.UTF8String;
    printf("[METAL] %s encoded.\n", encoderLabel);
}
#endif
@end
#endif /* MNN_METAL_ENABLED */