optimum/quanto/library/extensions/mps/unpack.mm (90 lines of code) (raw):
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unpack.h"
#include <torch/extension.h>
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
// Metal shader source for a custom kernel that masks and right-shifts a
// uint8 buffer element-wise: output[i] = (input[i] & mask) >> shift.
// Declared `const char *`: assigning a string literal to a non-const
// `char *` is ill-formed in C++11 and later.
static const char *MASK_AND_SHIFT = R"MPS_MASK&SHIFT(
#include <metal_stdlib>
using namespace metal;
[[host_name("mask_and_rshift")]]
kernel void mask_and_rshift(constant uint8_t* input [[buffer(0)]],
device uint8_t* output [[buffer(1)]],
constant uint8_t& mask [[buffer(2)]],
constant int& shift [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
    output[index] = (input[index] & mask) >> shift;
}
)MPS_MASK&SHIFT";
// Returns the `MTLBuffer` backing a `torch::Tensor`'s storage.
// On the MPS backend the storage pointer is really an `id<MTLBuffer>`,
// so a bit-level reinterpretation (no ownership transfer) suffices.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
  const void *storage_ptr = tensor.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, storage_ptr);
}
// Applies `output[i] = (input[i] & mask) >> shift` to every byte of `input`
// on the MPS device, writing the result into `output`.
//
// Both tensors are expected to be contiguous uint8 MPS tensors with the same
// number of elements (enforced by the caller `unpack`). Returns `output` to
// allow chaining.
torch::Tensor& mask_and_shift(const torch::Tensor& input, torch::Tensor& output, uint8_t mask, int shift) {
    @autoreleasepool {
        // One GPU thread per input element. `numel()` is int64_t; keep the
        // full width to avoid narrowing and signed/unsigned comparisons.
        const NSUInteger num_threads = input.numel();
        // Dispatching a zero-sized grid is invalid, so bail out early.
        if (num_threads == 0) {
            return output;
        }
        // Compile the shader and build the pipeline state only once: Metal
        // shader compilation is expensive and the source never changes.
        // C++ "magic statics" make this initialization thread-safe, and a
        // TORCH_CHECK failure propagates to the caller (the initialization
        // is retried on the next call).
        static id<MTLComputePipelineState> pso = ([]() -> id<MTLComputePipelineState> {
            id<MTLDevice> device = MTLCreateSystemDefaultDevice();
            TORCH_CHECK(device, "Failed to acquire the default Metal device");
            NSError *error = nil;
            // Load the custom mask-and-shift shader.
            id<MTLLibrary> library = [device newLibraryWithSource:[NSString stringWithUTF8String:MASK_AND_SHIFT]
                                                          options:nil
                                                            error:&error];
            // Guard against a nil `error`: messaging nil yields a NULL
            // C-string, which TORCH_CHECK must not stream.
            TORCH_CHECK(library, "Failed to create custom kernel library, error: ",
                        error ? error.localizedDescription.UTF8String : "unknown error");
            id<MTLFunction> kernel = [library newFunctionWithName:@"mask_and_rshift"];
            TORCH_CHECK(kernel, "Failed to create function state object for mask_and_rshift");
            // Create a compute pipeline state object for the mask-and-shift kernel.
            id<MTLComputePipelineState> state = [device newComputePipelineStateWithFunction:kernel error:&error];
            TORCH_CHECK(state, error ? error.localizedDescription.UTF8String : "unknown error");
            return state;
        })();
        // This is required if torch already encoded something in the command buffer.
        torch::mps::synchronize();
        // Get a reference to the command buffer for the MPS stream.
        id<MTLCommandBuffer> command_buffer = torch::mps::get_command_buffer();
        TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference");
        // Get a reference to the dispatch queue for the MPS stream, which
        // serializes our encoding against torch's own GPU work.
        dispatch_queue_t serial_queue = torch::mps::get_dispatch_queue();
        dispatch_sync(serial_queue, ^(){
            // Start a compute pass.
            id<MTLComputeCommandEncoder> compute_encoder = [command_buffer computeCommandEncoder];
            TORCH_CHECK(compute_encoder, "Failed to create compute command encoder");
            // Encode the pipeline state object and its parameters. The buffer
            // offsets account for tensors viewing into a larger storage.
            [compute_encoder setComputePipelineState:pso];
            [compute_encoder setBuffer:getMTLBufferStorage(input) offset:input.storage_offset() * input.element_size() atIndex:0];
            [compute_encoder setBuffer:getMTLBufferStorage(output) offset:output.storage_offset() * output.element_size() atIndex:1];
            [compute_encoder setBytes:&mask length:sizeof(uint8_t) atIndex:2];
            [compute_encoder setBytes:&shift length:sizeof(int) atIndex:3];
            // Clamp the threadgroup width to the grid size so small tensors
            // do not request more threads per group than elements.
            MTLSize grid_size = MTLSizeMake(num_threads, 1, 1);
            NSUInteger thread_group_size = MIN(pso.maxTotalThreadsPerThreadgroup, num_threads);
            MTLSize group_size = MTLSizeMake(thread_group_size, 1, 1);
            // Encode the compute command.
            [compute_encoder dispatchThreads:grid_size
                       threadsPerThreadgroup:group_size];
            [compute_encoder endEncoding];
            // Commit the work.
            torch::mps::commit();
        });
        // Block until the GPU has finished so callers can read `output`.
        torch::mps::synchronize();
    }
    return output;
}
// Unpacks two 4-bit values from each byte of `input`: the low nibbles first,
// then the high nibbles, concatenated along dim 0.
torch::Tensor unpack_4bit(const torch::Tensor &input) {
    torch::Tensor low = torch::empty_like(input);
    torch::Tensor high = torch::empty_like(input);
    mask_and_shift(input, low, 0x0F, 0);
    mask_and_shift(input, high, 0xF0, 4);
    return torch::cat({low, high}, 0);
}
// Unpacks four 2-bit values from each byte of `input`, lowest bit pair first,
// concatenated along dim 0. Field i occupies bits [2*i, 2*i+1] of each byte.
torch::Tensor unpack_2bit(const torch::Tensor &input) {
    torch::Tensor fields[4];
    for (int i = 0; i < 4; i++) {
        const int shift = 2 * i;
        fields[i] = torch::empty_like(input);
        mask_and_shift(input, fields[i], (uint8_t)(0x03 << shift), shift);
    }
    return torch::cat({fields[0], fields[1], fields[2], fields[3]}, 0);
}
// C++ op dispatching the Metal unpack operation.
//
// `input` must be a contiguous uint8 tensor on the MPS device; `bits` selects
// the packing width (2 or 4). Throws std::invalid_argument (surfaced as
// ValueError in Python) for any other width.
torch::Tensor unpack(const torch::Tensor &input, int bits) {
    // The Metal kernel reads raw buffer bytes, so the input must live on the
    // MPS device, be contiguous, and hold uint8 data.
    TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(input.scalar_type() == torch::kUInt8, "Unsupported data type: ", input.scalar_type());
    if (bits == 4) {
        return unpack_4bit(input);
    }
    if (bits == 2) {
        return unpack_2bit(input);
    }
    throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
}