Sources/TensorUtils/Weights.swift (77 lines of code) (raw):

import CoreML

/// A read-only collection of named tensors loaded from a weights file.
public struct Weights {
    /// Errors raised while locating or parsing a weights file.
    enum WeightsError: LocalizedError {
        /// The extension, container format, or tensor dtype named by `message` is not handled.
        case notSupported(message: String)
        /// The file is truncated, or its header/offsets are inconsistent with its contents.
        case invalidFile

        public var errorDescription: String? {
            switch self {
            case let .notSupported(message):
                String(localized: "The weight format '\(message)' is not supported by this application.", comment: "Error when weight format is not supported")
            case .invalidFile:
                String(localized: "The weights file is invalid or corrupted.", comment: "Error when weight file is invalid")
            }
        }
    }

    /// File extensions accepted by `from(fileURL:)`, compared case-insensitively.
    private static let acceptedExtensions: Set<String> = ["safetensors", "gguf", "mlx"]
    /// ASCII "GGUF".
    private static let ggufMagic: [UInt8] = [0x47, 0x47, 0x55, 0x46]
    /// 0x93 followed by ASCII "NUMPY"; files starting with it are reported as "mlx".
    private static let numpyMagic: [UInt8] = [0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59]

    private let dictionary: [String: MLMultiArray]

    init(_ dictionary: [String: MLMultiArray]) {
        self.dictionary = dictionary
    }

    /// Returns the tensor stored under `key`, or `nil` when there is none.
    subscript(key: String) -> MLMultiArray? {
        dictionary[key]
    }

    /// Loads the weights stored in the file at `fileURL`.
    ///
    /// The extension must be one of "safetensors", "gguf" or "mlx". The leading bytes are then
    /// sniffed: GGUF and NumPy-style containers are rejected, everything else is parsed as
    /// safetensors.
    ///
    /// - Parameter fileURL: Location of the weights file.
    /// - Returns: The tensors found in the file.
    /// - Throws: `WeightsError.notSupported` for an unknown extension, a GGUF/NumPy container or
    ///   an unsupported tensor dtype; `WeightsError.invalidFile` for truncated or inconsistent
    ///   files; any error raised while reading the file, decoding the JSON header or creating
    ///   an `MLMultiArray`.
    public static func from(fileURL: URL) throws -> Weights {
        let pathExtension = fileURL.pathExtension
        guard acceptedExtensions.contains(pathExtension.lowercased()) else {
            throw WeightsError.notSupported(message: "\(pathExtension)")
        }

        let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)

        // `starts(with:)` is simply false for files shorter than the magic, so tiny files fall
        // through to the safetensors parser, which reports them as invalid instead of trapping.
        if data.starts(with: ggufMagic) {
            throw WeightsError.notSupported(message: "gguf")
        }
        if data.starts(with: numpyMagic) {
            throw WeightsError.notSupported(message: "mlx")
        }
        return try Safetensor.from(data: data)
    }
}

/// Parser for the safetensors layout: an 8-byte little-endian header length, a JSON header,
/// then the raw tensor bytes addressed by offsets relative to the end of the header.
struct Safetensor {
    typealias Error = Weights.WeightsError

    /// Number of bytes occupied by the header-length prefix.
    private static let prefixLength = 8

    struct Header {
        /// One entry of the JSON header. Every field is optional so that entries which do not
        /// describe a tensor still decode; such entries are skipped by `Safetensor.from(data:)`.
        struct Offset: Decodable {
            let dataOffsets: [Int]?
            let dtype: String?
            let shape: [Int]?

            /// Core ML element type and element width in bytes for `dtype`, or `nil` when the
            /// dtype is not handled.
            ///
            /// NOTE(review): "U32" and "U64" are mapped onto int32 / float64 of the same width;
            /// the bytes are not converted. Confirm this reinterpretation is intended.
            private var layout: (type: MLMultiArrayDataType, byteWidth: Int)? {
                switch dtype {
                case "I32", "U32":
                    return (MLMultiArrayDataType.int32, 4)
                case "F16":
                    return (MLMultiArrayDataType.float16, 2)
                case "F32":
                    return (MLMultiArrayDataType.float32, 4)
                case "F64", "U64":
                    return (MLMultiArrayDataType.float64, 8)
                default:
                    return nil
                }
            }

            /// Unsupported: "I8", "U8", "I16", "U16", "BF16"
            ///
            /// - Throws: `Error.notSupported` carrying the dtype string ("empty" when absent).
            var dataType: MLMultiArrayDataType? {
                get throws {
                    guard let layout else {
                        throw Error.notSupported(message: "\(dtype ?? "empty")")
                    }
                    return layout.type
                }
            }

            /// Size in bytes of a single element, or `nil` when `dtype` is not handled.
            var elementByteWidth: Int? {
                layout?.byteWidth
            }
        }

        /// Decodes the JSON header, converting snake_case keys (e.g. `data_offsets`) to camelCase.
        static func from(data: Data) throws -> [String: Offset?] {
            let decoder = JSONDecoder()
            decoder.keyDecodingStrategy = .convertFromSnakeCase
            return try decoder.decode([String: Offset?].self, from: data)
        }
    }

    /// Parses `data` as a safetensors buffer.
    ///
    /// Header entries lacking a two-element `data_offsets` or a `shape` are skipped.
    ///
    /// - Parameter data: Complete file contents; may be a slice with a non-zero `startIndex`.
    /// - Returns: The tensors keyed by their header names.
    /// - Throws: `Error.invalidFile` when the buffer is truncated, the header length or a
    ///   tensor's offsets fall outside the buffer, a dimension is negative, sizes overflow, or
    ///   a tensor's byte length does not match its shape; `Error.notSupported` for an
    ///   unsupported dtype; decoding and `MLMultiArray` errors are propagated.
    static func from(data: Data) throws -> Weights {
        guard data.count >= prefixLength else { throw Error.invalidFile }

        // Decode the little-endian u64 by hand: no alignment or host-endianness assumptions.
        let declaredHeaderSize = data.prefix(prefixLength).reversed().reduce(UInt64(0)) { ($0 << 8) | UInt64($1) }
        guard let headerSize = Int(exactly: declaredHeaderSize),
              headerSize <= data.count - prefixLength
        else { throw Error.invalidFile }

        let headerStart = data.startIndex + prefixLength
        let payloadStart = headerStart + headerSize
        let payloadCount = data.count - prefixLength - headerSize
        let header = try Header.from(data: data.subdata(in: headerStart..<payloadStart))

        var tensors = [String: MLMultiArray]()
        for (name, entry) in header {
            // The dtype is only inspected once offsets and shape are present, so non-tensor
            // entries are skipped rather than rejected as unsupported.
            guard let entry,
                  let offsets = entry.dataOffsets, offsets.count == 2,
                  let dims = entry.shape,
                  let dataType = try entry.dataType,
                  let byteWidth = entry.elementByteWidth
            else { continue }

            let (begin, end) = (offsets[0], offsets[1])
            guard begin >= 0, begin <= end, end <= payloadCount else { throw Error.invalidFile }

            // The multi-array reads `count * width` bytes; refuse anything that disagrees.
            let (expectedBytes, overflow) = try elementCount(of: dims).multipliedReportingOverflow(by: byteWidth)
            guard !overflow, expectedBytes == end - begin else { throw Error.invalidFile }

            let strides = try rowMajorStrides(for: dims)
            let tensorBytes = data.subdata(in: (payloadStart + begin)..<(payloadStart + end)) as NSData

            // MLMultiArray wraps the pointer without copying, so the backing NSData must live
            // as long as the array does: the deallocator closure holds on to it.
            tensors[name] = try MLMultiArray(
                dataPointer: UnsafeMutableRawPointer(mutating: tensorBytes.bytes),
                shape: dims.map { NSNumber(value: $0) },
                dataType: dataType,
                strides: strides.map { NSNumber(value: $0) },
                deallocator: { _ in withExtendedLifetime(tensorBytes) {} }
            )
        }
        return Weights(tensors)
    }

    /// Product of all dimensions (1 for an empty shape); rejects negative dimensions and overflow.
    private static func elementCount(of dims: [Int]) throws -> Int {
        try dims.reduce(1) { count, dim in
            let (product, overflow) = count.multipliedReportingOverflow(by: dim)
            guard dim >= 0, !overflow else { throw Error.invalidFile }
            return product
        }
    }

    /// Contiguous row-major strides for `dims`; an empty or one-dimensional shape yields `[1]`.
    private static func rowMajorStrides(for dims: [Int]) throws -> [Int] {
        var strides = [1]
        for dim in dims.dropFirst().reversed() {
            let (next, overflow) = strides[0].multipliedReportingOverflow(by: dim)
            guard !overflow else { throw Error.invalidFile }
            strides.insert(next, at: 0)
        }
        return strides
    }
}