Sources/SwiftBuffet/Generator.swift (360 lines of code) (raw):
import Foundation
/// Generates Swift code from protocol buffer messages and enums.
///
/// - Parameters:
/// - messages: An array of `ProtoMessage` representing protocol buffer messages.
/// - enums: An array of `ProtoEnum` representing protocol buffer enums.
/// - Returns: A string containing the generated Swift code.
func generateSwiftCode(
from messages: [ProtoMessage],
enums: [ProtoEnum],
with swiftPrefix: String,
includeProto: Bool,
includeLocalIDFor localIDMessages: [String]?,
includeBackingData: Bool,
with protoPrefix: String
) -> String {
var output = "import Foundation\n\n"
write(
messages,
to: &output,
with: swiftPrefix,
includeProto: includeProto,
includeLocalIDFor: localIDMessages,
includeBackingData: includeBackingData,
with: protoPrefix
)
write(
enums,
to: &output,
with: swiftPrefix,
includeProto: includeProto,
with: protoPrefix
)
return output
}
/// Writes the Swift code for protocol buffer enums.
///
/// - Parameters:
/// - enums: An array of `ProtoEnum` to be written.
/// - output: A mutable string where the generated code will be appended.
internal func write(
_ enums: [ProtoEnum],
to output: inout String,
with swiftPrefix: String,
includeProto: Bool,
with protoPrefix: String
) {
if enums.isEmpty == false {
output += "// MARK: - Enums\n"
}
for protoEnum in enums.sorted(by: { $0.name < $1.name }) {
let strippedCases = stripCommonPrefix(from: protoEnum.cases)
let pair = zip(
strippedCases.map(\.name),
protoEnum.cases.map(\.value)
)
if let parent = protoEnum.parentName {
output += "extension \(swiftPrefix)\(parent) {\n"
output += " "
}
output += "public enum \(swiftPrefix)\(protoEnum.name): Int, CaseIterable, Hashable, Equatable, Sendable {\n"
for (caseName, caseValue) in pair {
if protoEnum.parentName != nil {
output += " "
}
output += " case \(caseName) = \(caseValue)\n"
}
if includeProto {
writeEnumProtoInit(
for: protoEnum,
to: &output,
with: protoPrefix
)
}
if protoEnum.parentName != nil {
output += " }\n"
}
output += "}\n\n"
}
}
/// Writes the initializer for a protocol buffer enum.
///
/// - Parameters:
/// - protoEnum: The `ProtoEnum` to write the initializer for.
/// - output: A mutable string where the generated code will be appended.
internal func writeEnumProtoInit(for protoEnum: ProtoEnum, to output: inout String, with protoPrefix: String) {
output += "\n"
let padding = if protoEnum.parentName != nil {
" "
} else {
""
}
output += padding + " internal init?(proto: \(protoPrefix)\(protoEnum.fullName)) {\n"
output += padding + " self.init(rawValue: proto.rawValue)\n"
output += padding + " }\n"
}
/// Writes the Swift code for protocol buffer messages.
///
/// - Parameters:
/// - messages: An array of `ProtoMessage` to be written.
/// - output: A mutable string where the generated code will be appended.
/// - Returns: A boolean indicating whether a TimeInterval helper is needed.
internal func write(
_ messages: [ProtoMessage],
to output: inout String,
with swiftPrefix: String,
includeProto: Bool,
includeLocalIDFor messageNames: [String]?,
includeBackingData: Bool,
with protoPrefix: String
) {
if messages.isEmpty == false {
output += "// MARK: - Structs\n"
}
for message in messages.sorted(by: { $0.name < $1.name }) {
output += "public struct \(swiftPrefix)\(message.name): Hashable, Equatable, Sendable {\n"
writeProperties(
for: message,
includeLocalID: messageNames?.contains(message.name) ?? false,
includeBackingData: includeBackingData,
to: &output
)
writeBasicInit(
for: message,
to: &output
)
if includeProto {
writeMessageProtoInit(
for: message,
includeBackingData: includeBackingData,
to: &output,
with: protoPrefix
)
}
output += "}\n\n"
}
}
/// Adds properties to a message struct.
///
/// - Parameters:
/// - message: The `ProtoMessage` to add properties for.
/// - output: A mutable string where the generated code will be appended.
/// - Returns: A boolean indicating whether the message has a TimeInterval property.
internal func writeProperties(
for message: ProtoMessage,
includeLocalID: Bool,
includeBackingData: Bool,
to output: inout String
) {
for field in message.fields {
if let comment = field.comment {
output += "\n"
output += comment
.replacingOccurrences(of: "/**", with: "")
.replacingOccurrences(of: "*/", with: "")
.replacingOccurrences(of: "*", with: "")
.split(separator: "\n")
.map { $0.trimmingCharacters(in: .whitespaces) }
.filter { !$0.isEmpty }
.map { String(" // \($0)") }
.joined(separator: "\n")
output += "\n"
}
if field.isDeprecated {
output += #"/// This property has been marked as **deprecated** in the proto file"#
output += "\n"
}
output += " public let \(field.caseCorrectName): \(field.caseCorrectedType)\n"
}
if includeLocalID {
output += " public let _localID = UUID()\n"
}
if includeBackingData {
output += " public private(set) var _backingData: Data?\n"
}
}
/// Adds a basic initializer to a message struct.
///
/// - Parameters:
/// - message: The `ProtoMessage` to add the initializer for.
/// - output: A mutable string where the generated code will be appended.
internal func writeBasicInit(for message: ProtoMessage, to output: inout String) {
var fields = message.fields
let last = fields.popLast()
output += "\n public init(\n"
if last == nil {
output += ") {\n"
} else {
for field in fields {
output += " \(field.caseCorrectName): \(field.caseCorrectedType)"
output += ",\n"
}
if let field = last {
output += " \(field.caseCorrectName): \(field.caseCorrectedType)"
output += "\n"
}
output += " ) {\n"
}
for field in message.fields {
output += " self.\(field.caseCorrectName) = \(field.caseCorrectName)\n"
}
output += " }\n"
}
/// Writes the initializer for a protocol buffer message.
///
/// - Parameters:
/// - message: The `ProtoMessage` to write the initializer for.
/// - output: A mutable string where the generated code will be appended.
internal func writeMessageProtoInit(
for message: ProtoMessage,
includeBackingData: Bool,
to output: inout String,
with protoPrefix: String
) {
output += "\n public init?(data: Data) {\n"
output += " if let proto = try? \(protoPrefix)\(message.name)(serializedBytes: data) {\n"
output += " self.init(proto: proto)\n"
if includeBackingData {
output += " self._backingData = data\n"
}
output += " } else {\n"
output += " return nil\n"
output += " }\n"
output += " }\n\n"
output += " internal init?(proto: \(protoPrefix)\(message.name)) {\n"
for field in message.fields {
if field.isOptional {
output += " if proto.has\(field.caseCorrectProtoName.capitalizingFirstLetter()) {\n"
output += " " // additional padding
}
if field.isRepeated {
output += " self.\(field.caseCorrectName) = proto.\(field.caseCorrectProtoName).compactMap { "
if field.isPrimitiveType || field.type.contains("int") {
output += "\(field.caseCorrectedBaseType)($0)"
} else {
output += "\(field.caseCorrectedBaseType)(proto: $0)"
}
output += " }\n"
} else if field.isMap {
output += " self.\(field.caseCorrectName) = proto.\(field.caseCorrectProtoName).reduce(into: \(field.caseCorrectedType)()) { $0[$1.key] = $1.value }\n"
} else if field.caseCorrectedBaseType == "TimeInterval" {
output += " self.\(field.caseCorrectName) = proto.\(field.caseCorrectProtoName).timeInterval\n"
} else if field.caseCorrectedBaseType == "Date" {
output += " self.\(field.caseCorrectName) = proto.\(field.caseCorrectProtoName).date\n"
} else if field.isURL {
if field.isOptional {
output += " self.\(field.caseCorrectName) = URL(string: proto.\(field.caseCorrectName))\n"
} else {
output += " if let \(field.caseCorrectName) = URL(string: proto.\(field.caseCorrectName)) {\n"
output += " self.\(field.caseCorrectName) = \(field.caseCorrectName)\n"
output += " } else {\n"
output += " return nil\n"
output += " }\n"
}
}else if field.type.contains("int") {
output += " self.\(field.caseCorrectName) = Int(exactly: proto.\(field.caseCorrectName))!\n"
} else if field.isPrimitiveType {
output += " self.\(field.caseCorrectName) = proto.\(field.caseCorrectProtoName)\n"
} else if field.isOptional == false {
output += " if let \(field.caseCorrectName) = \(field.caseCorrectedBaseType)(proto: proto.\(field.caseCorrectProtoName)) {\n"
output += " self.\(field.caseCorrectName) = \(field.caseCorrectProtoName)\n"
output += " } else {\n"
output += " return nil\n"
output += " }\n"
} else {
output += " self.\(field.caseCorrectName) = \(field.caseCorrectedBaseType)(proto: proto.\(field.caseCorrectProtoName))\n"
}
if field.isOptional {
output += " } else {\n"
output += " self.\(field.caseCorrectName)"
writeDefaultValue(for: field, to: &output)
output += "\n"
output += " }\n"
}
}
output += " }\n"
}
internal func writeCodingKeys(for message: ProtoMessage, to output: inout String) {
output += "\n enum CodingKeys: String, CodingKey {\n"
for field in message.fields {
output += " case \(field.caseCorrectName) = \"\(snakeToCamelCase(field.name))\"\n"
}
output += " }\n"
}
/// Writes the custom initializer and encoder for messages with TimeInterval fields.
///
/// - Parameters:
/// - message: The `ProtoMessage` to write the custom initializer and encoder for.
/// - output: A mutable string where the generated code will be appended.
internal func writeCodableInit(for message: ProtoMessage, to output: inout String) {
output += "\n"
output += " public init(from decoder: Decoder) throws {\n"
output += " let container = try decoder.container(keyedBy: CodingKeys.self)\n"
for field in message.fields {
if field.isRepeated {
output += " self.\(field.caseCorrectName) = try container.decodeIfPresent(\(field.caseCorrectedType).self, forKey: .\(field.caseCorrectName)) ?? []\n"
} else if field.isMap {
output += " self.\(field.caseCorrectName) = try container.decodeIfPresent(\(field.caseCorrectedType).self, forKey: .\(field.caseCorrectName)) ?? [:]\n"
} else if field.caseCorrectedType == "TimeInterval" {
output += " if let \(field.caseCorrectName)String = try container.decodeIfPresent(String.self, forKey: .\(field.caseCorrectName)) {\n"
output += " self.\(field.caseCorrectName) = TimeInterval(from: \(field.caseCorrectName)String) ?? 0\n"
output += " } else {\n"
if field.isOptional {
output += " self.\(field.caseCorrectName) = nil\n"
} else {
output += " self.\(field.caseCorrectName) = 0\n"
}
output += " }\n"
} else if field.caseCorrectedType.contains("Date") {
output += " if let \(field.caseCorrectName)String = try container.decodeIfPresent(String.self, forKey: .\(field.caseCorrectName)) {\n"
output += " self.\(field.caseCorrectName) = dateFormatter.date(from: \(field.caseCorrectName)String)\n"
output += " } else {\n"
if field.isOptional {
output += " self.\(field.caseCorrectName) = nil\n"
} else {
output += " self.\(field.caseCorrectName) = Date()\n"
}
output += " }\n"
} else if field.isOptional {
output += " self.\(field.caseCorrectName) = try container.decodeIfPresent(\(field.caseCorrectedType.replacingOccurrences(of: "?", with: "")).self, forKey: .\(field.caseCorrectName))\n"
} else if field.type == "bool" {
output += " self.\(field.caseCorrectName) = try container.decodeIfPresent(\(field.caseCorrectedType).self, forKey: .\(field.caseCorrectName)) ?? false\n"
} else {
output += " self.\(field.caseCorrectName) = try container.decode(\(field.caseCorrectedType).self, forKey: .\(field.caseCorrectName))\n"
}
}
output += " }\n\n"
output += " public func encode(to encoder: Encoder) throws {\n"
output += " var container = encoder.container(keyedBy: CodingKeys.self)\n"
for field in message.fields {
switch field.caseCorrectedType {
case "TimeInterval":
output += " let \(field.caseCorrectName)String = String(self.\(field.caseCorrectName)) + \"s\"\n"
output += " try container.encode(\(field.caseCorrectName)String, forKey: .\(field.caseCorrectName))\n"
default:
output += " try container.encode(self.\(field.caseCorrectName), forKey: .\(field.caseCorrectName))\n"
}
}
output += " }\n"
}
/// Writes the custom initializer and encoder for messages with TimeInterval fields.
///
/// - Parameters:
/// - protoEnum: The `ProtoEnum` to write the custom initializer and encoder for.
/// - output: A mutable string where the generated code will be appended.
internal func writeCodableInit(for protoEnum: ProtoEnum, to output: inout String, with swiftPrefix: String) {
let strippedCases = stripCommonPrefix(from: protoEnum.cases)
let pair = zip(
strippedCases.map(\.name),
protoEnum.cases.map(\.name)
)
output += "\n"
output += " public init(from decoder: Decoder) throws {\n"
output += " let container = try decoder.singleValueContainer()\n\n"
output += " if let stringValue = try? container.decode(String.self) {\n"
output += " // Convert string to enum\n"
output += " switch stringValue {\n"
for (caseName, stringName) in pair {
output += " case \"\(stringName)\":\n"
output += " self = .\(caseName)\n"
}
output += " default:\n"
output += " self = .unspecified\n"
output += " }\n"
output += " } else if let intValue = try? container.decode(Int.self) {\n"
output += " // Convert integer to enum\n"
output += " self = \(swiftPrefix)\(protoEnum.name)(rawValue: intValue) ?? .unspecified\n"
output += " } else {\n"
output += " throw DecodingError.dataCorruptedError(in: container, debugDescription: \"Invalid value for MyEnum\")\n"
output += " }\n"
output += " }\n\n"
output += " public func encode(to encoder: Encoder) throws {\n"
output += " var container = encoder.singleValueContainer()\n"
output += " switch self {\n"
for (caseName, stringName) in pair {
output += " case .\(caseName):\n"
output += " try container.encode(\"\(stringName)\")\n"
}
output += " }\n"
output += " }\n"
}
/// Writes the TimeInterval helper extension if needed.
///
/// - Parameters:
/// - messages: An array of `ProtoMessage` to check for TimeInterval fields.
/// - output: A mutable string where the generated code will be appended.
internal func writeTimeIntervalHelper(to output: inout String) {
if let fileContents = readFileContents(filename: "TimeInterval+String.swift") {
output += "// MARK: - TimeInterval Extension\n"
output += fileContents
.replacingOccurrences(of: "import Foundation", with: "")
.trimmingCharacters(in: .whitespacesAndNewlines)
} else {
output += "// File not found."
}
}
internal func writeDateFormatter(to output: inout String) {
output += "\n\n"
output += "var dateFormatter: ISO8601DateFormatter {\n"
output += " let formatter = ISO8601DateFormatter()\n"
output += " formatter.formatOptions = [.withFullDate, .withFullTime, .withTimeZone]\n"
output += " return formatter\n"
output += "}\n"
}
func writeDefaultValue(for field: ProtoField, to output: inout String) {
if field.isRepeated {
output += " = []"
} else if field.type == "bool" {
output += " = false"
} else if field.isOptional {
output += " = nil"
}
}