csrc/mps_ops.mm (54 lines of code) (raw):

#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h> #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 static inline MPSGraph* get_graph() { static MPSGraph* cur = nil; if(!cur) { cur = [[MPSGraph alloc] init]; } return cur; } static inline id<MTLDevice> get_device() { NSError *error = nil; static id<MTLDevice> device = nil; if(!device) { device = MTLCreateSystemDefaultDevice(); } if(!device) { NSLog(@"Failed to get MPS device"); abort(); } return device; } static inline id<MTLLibrary> get_library() { NSError *error = nil; static id<MTLLibrary> library = nil; if(!library) { library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; } if(!library) { NSLog(@"Failed to load bitsandbytes.metallib"); abort(); } return library; } /*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) { id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out; }*/ // MPSGraph function for quantize extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) { id<MTLDevice> device = get_device(); id<MTLLibrary> library = get_library(); static id<MTLFunction> kernel = nil; if(!kernel) { kernel = [library newFunctionWithName:@"quantize"]; if(!kernel) { NSLog(@"Failed to load bitsandbytes.metallib"); abort(); } } NSLog(@"Not implemented"); return nil; }