SemanticSegmentationSample/Common/SemanticMapToImage.swift (79 lines of code) (raw):

import CoreML
import CoreImage
import CoreImage.CIFilterBuiltins
import Metal

/// Converts a semantic segmentation map (one `Int32` class index per pixel) into a
/// `CIImage` by running the `SemanticMapToColor` Metal compute kernel.
///
/// The kernel is looked up in the app's default Metal library. It is bound with:
/// - texture index 0: the class-index map (`.r32Uint`)
/// - texture index 1: the output image (`.bgra8Unorm`)
/// - buffer index 0: the number of classes, as a 32-bit integer
class SemanticMapToImage {
    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let pipelineState: MTLComputePipelineState

    /// Shared converter; `nil` when Metal or the `SemanticMapToColor` kernel is unavailable.
    public static let shared: SemanticMapToImage? = SemanticMapToImage()

    /// Failures that can occur while converting a semantic map to an image.
    enum MetalConversionError: Error {
        /// The command buffer could not be created, or its execution reported an error.
        case commandBufferError
        /// The compute pass could not be encoded (invalid input or resource creation failure).
        case encoderError
        /// Core Image could not wrap the output texture.
        case coreImageError
    }

    /// Creates the Metal device, command queue and compute pipeline.
    ///
    /// Returns `nil` if there is no default Metal device, the default library cannot be
    /// loaded, the `SemanticMapToColor` function is missing, or the pipeline fails to build.
    public init?() {
        guard let metalDevice = MTLCreateSystemDefaultDevice(),
              let queue = metalDevice.makeCommandQueue(),
              let library = metalDevice.makeDefaultLibrary(),
              let kernelFunction = library.makeFunction(name: "SemanticMapToColor"),
              let pipeline = try? metalDevice.makeComputePipelineState(function: kernelFunction)
        else {
            return nil
        }
        device = metalDevice
        commandQueue = queue
        pipelineState = pipeline
    }

    /// Renders `semanticMap` as a color image.
    ///
    /// This call blocks until the GPU has finished executing the kernel.
    ///
    /// - Parameters:
    ///   - semanticMap: Per-pixel class indices; the first two dimensions are read as
    ///     rows (height) and columns (width).
    ///   - numClasses: Number of classes, forwarded to the kernel.
    /// - Returns: The rendered image, flipped vertically relative to the texture contents.
    /// - Throws: `MetalConversionError` if the command buffer cannot be created or fails,
    ///   the compute pass cannot be encoded, or Core Image cannot wrap the result.
    public func mapToImage(semanticMap: MLShapedArray<Int32>, numClasses: Int) throws -> CIImage {
        guard let commandBuffer = commandQueue.makeCommandBuffer() else {
            throw MetalConversionError.commandBufferError
        }
        guard let outputTexture = encodeComputePipeline(commandBuffer: commandBuffer,
                                                        semanticMap: semanticMap,
                                                        numClasses: numClasses) else {
            throw MetalConversionError.encoderError
        }

        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()

        // A failed execution leaves the output texture undefined; surface it instead of
        // returning a garbage image.
        guard commandBuffer.error == nil else {
            throw MetalConversionError.commandBufferError
        }

        guard let image = CIImage(mtlTexture: outputTexture,
                                  options: [.colorSpace: CGColorSpaceCreateDeviceRGB()]) else {
            throw MetalConversionError.coreImageError
        }

        // Mirror vertically, then translate back so the extent origin stays at (0, 0).
        let flip = CGAffineTransform(scaleX: 1, y: -1)
        let shiftUp = CGAffineTransform(translationX: 0, y: image.extent.height)
        return image.transformed(by: flip).transformed(by: shiftUp)
    }

    /// Encodes the `SemanticMapToColor` compute pass into `commandBuffer`.
    ///
    /// All resources are created before the encoder so that an early `nil` return never
    /// leaves an encoder without a matching `endEncoding()`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Buffer that receives the compute pass. It is not committed here.
    ///   - semanticMap: Per-pixel class indices to upload as the source texture.
    ///   - numClasses: Number of classes; must be representable as `Int32`.
    /// - Returns: The texture the kernel writes to, or `nil` if the input is invalid or a
    ///   Metal resource could not be created.
    func encodeComputePipeline(commandBuffer: MTLCommandBuffer,
                               semanticMap: MLShapedArray<Int32>,
                               numClasses: Int) -> MTLTexture? {
        guard let exactClassCount = Int32(exactly: numClasses),
              let inputTexture = sourceTexture(semanticMap),
              let outputTexture = makeTexture(width: inputTexture.width,
                                              height: inputTexture.height,
                                              pixelFormat: .bgra8Unorm),
              let commandEncoder = commandBuffer.makeComputeCommandEncoder()
        else {
            return nil
        }

        commandEncoder.setComputePipelineState(pipelineState)
        commandEncoder.setTexture(inputTexture, index: 0)
        commandEncoder.setTexture(outputTexture, index: 1)

        // Upload exactly the 4 bytes of an Int32 so the value's size matches `length`.
        var classCount = exactClassCount
        commandEncoder.setBytes(&classCount, length: MemoryLayout<Int32>.size, index: 0)

        // One thread per output pixel; threadgroups are as large as the pipeline allows.
        let groupWidth = pipelineState.threadExecutionWidth
        let groupHeight = pipelineState.maxTotalThreadsPerThreadgroup / groupWidth
        let threadsPerThreadgroup = MTLSizeMake(groupWidth, groupHeight, 1)
        let threadsPerGrid = MTLSize(width: outputTexture.width,
                                     height: outputTexture.height,
                                     depth: 1)
        commandEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        commandEncoder.endEncoding()

        return outputTexture
    }

    /// Uploads `semanticMap` into a new `.r32Uint` texture.
    ///
    /// The array is read in row-major order: `shape[0]` rows of `shape[1]` contiguous
    /// elements, so the texture is `shape[1]` wide and `shape[0]` tall.
    ///
    /// - Returns: The populated texture, or `nil` if the map has fewer than two
    ///   dimensions, has an empty dimension, or the texture could not be created.
    func sourceTexture(_ semanticMap: MLShapedArray<Int32>) -> MTLTexture? {
        let shape = semanticMap.shape
        guard shape.count >= 2 else { return nil }
        let (height, width) = (shape[0], shape[1])

        guard let texture = makeTexture(width: width, height: height) else { return nil }

        let region = MTLRegionMake2D(0, 0, width, height)
        let array = MLMultiArray(semanticMap)
        // NOTE(review): assumes the converted MLMultiArray is densely packed
        // (row stride == width elements) — confirm if non-contiguous inputs are expected.
        texture.replace(region: region,
                        mipmapLevel: 0,
                        withBytes: array.dataPointer,
                        bytesPerRow: width * MemoryLayout<Int32>.stride)
        return texture
    }

    /// Creates a non-mipmapped 2D texture that shaders can both read and write.
    ///
    /// - Parameters:
    ///   - width: Texture width in pixels; must be positive.
    ///   - height: Texture height in pixels; must be positive.
    ///   - pixelFormat: Pixel format, `.r32Uint` by default.
    /// - Returns: The new texture, or `nil` for non-positive dimensions or if the device
    ///   fails to allocate it.
    func makeTexture(width: Int, height: Int, pixelFormat: MTLPixelFormat = .r32Uint) -> MTLTexture? {
        // Zero-sized descriptors are rejected by Metal validation; fail gracefully instead.
        guard width > 0, height > 0 else { return nil }

        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: pixelFormat,
                                                                         width: width,
                                                                         height: height,
                                                                         mipmapped: false)
        textureDescriptor.usage = [.shaderRead, .shaderWrite]
        return device.makeTexture(descriptor: textureDescriptor)
    }
}