in thrifty-kotlin-codegen/src/main/kotlin/com/microsoft/thrifty/kgen/KotlinCodeGenerator.kt [1648:1910]
/**
 * Appends to [block] a Kotlin expression that evaluates to the Thrift constant [value],
 * interpreted as a value of [type].
 *
 * Typedefs are rendered as their underlying type. Lists, sets and maps are rendered
 * element-by-element through recursive calls:
 * - Empty collections become `emptyList()` / `emptySet()` / `emptyMap()`.
 * - When a custom collection class is configured (`listClassName`, `setClassName`,
 *   `mapClassName`), that class is constructed with the element count and populated
 *   inside an `apply` block.
 * - Otherwise the `listOf` / `setOf` / `mapOf` factories are used.
 *
 * When [value] does not directly match [type], an identifier value is treated as a
 * reference to a named constant in `schema`, optionally prefixed with a program name
 * (e.g. `other.FOO`). It is rendered as a package-qualified reference to that constant.
 *
 * @throws IllegalStateException if [type] is `void`, or if [value] is invalid for [type]
 *   and is not a reference to a known constant of the same type (or that constant has
 *   no JVM namespace).
 * @throws AssertionError if [type] is a service, or if an enum constant is neither an
 *   identifier nor an integer.
 * @throws NotImplementedError if [type] is a struct; struct constants are not supported.
 */
fun recursivelyRenderConstValue(block: CodeBlock.Builder, type: ThriftType, value: ConstValueElement) {
    type.accept(object : ThriftType.Visitor<Unit> {
        override fun visitVoid(voidType: BuiltinType) {
            error("Can't have void as a constant")
        }

        override fun visitBool(boolType: BuiltinType) {
            if (value is IdentifierValueElement && (value.value == "true" || value.value == "false")) {
                block.add("%L", value.value)
            } else if (value is IntValueElement) {
                // Integer values are accepted as booleans; any non-zero value renders as `true`.
                block.add("%L", value.value != 0L)
            } else {
                constOrError("Invalid boolean constant")
            }
        }

        override fun visitByte(byteType: BuiltinType) {
            if (value is IntValueElement) {
                block.add("%L", value.value)
            } else {
                constOrError("Invalid byte constant")
            }
        }

        override fun visitI16(i16Type: BuiltinType) {
            if (value is IntValueElement) {
                block.add("%L", value.value)
            } else {
                constOrError("Invalid I16 constant")
            }
        }

        override fun visitI32(i32Type: BuiltinType) {
            if (value is IntValueElement) {
                block.add("%L", value.value)
            } else {
                constOrError("Invalid I32 constant")
            }
        }

        override fun visitI64(i64Type: BuiltinType) {
            if (value is IntValueElement) {
                if (value.value == Long.MIN_VALUE) {
                    // -9223372036854775808L is not a legal Kotlin literal: the magnitude is
                    // parsed before negation and overflows Long.
                    block.add("%T.MIN_VALUE", Long::class)
                } else {
                    // The explicit `L` suffix makes the literal a Long even where no expected
                    // type reaches it (e.g. inside a `k to v` pair passed to mapOf).
                    block.add("%LL", value.value)
                }
            } else {
                constOrError("Invalid I64 constant")
            }
        }

        override fun visitDouble(doubleType: BuiltinType) {
            when (value) {
                is IntValueElement -> block.add("%L.toDouble()", value.value)
                is DoubleValueElement -> block.add("%L", value.value)
                else -> constOrError("Invalid double constant")
            }
        }

        override fun visitString(stringType: BuiltinType) {
            if (value is LiteralValueElement) {
                block.add("%S", value.value)
            } else {
                constOrError("Invalid string constant")
            }
        }

        override fun visitBinary(binaryType: BuiltinType) {
            // TODO: Implement support for binary constants in the ANTLR grammar
            if (value is LiteralValueElement) {
                block.add("%T.decodeHex(%S)", ByteString::class, value.value)
            } else {
                constOrError("Invalid binary constant")
            }
        }

        override fun visitEnum(enumType: EnumType) {
            val member = try {
                when (value) {
                    // Enum references may or may not be scoped with their typename; either way, we must remove
                    // the type reference to get the member name on its own.
                    is IdentifierValueElement -> enumType.findMemberByName(value.value.split(".").last())
                    is IntValueElement -> {
                        val id = value.value
                        // Long.toInt() silently wraps out-of-range values, which could match an
                        // unrelated member; treat such values as unmatched instead.
                        if (id >= Int.MIN_VALUE && id <= Int.MAX_VALUE) enumType.findMemberById(id.toInt()) else null
                    }
                    else -> throw AssertionError("Value kind $value is not possibly an enum")
                }
            } catch (e: NoSuchElementException) {
                null
            }
            if (member != null) {
                // %T lets KotlinPoet manage the import and escaping of the enum type name.
                block.add("%T.%L", enumType.typeName, member.name)
            } else {
                constOrError("Invalid enum constant")
            }
        }

        override fun visitList(listType: ListType) {
            visitCollection(
                listType.elementType,
                listClassName,
                "listOf",
                "emptyList",
                "Invalid list constant")
        }

        override fun visitSet(setType: SetType) {
            visitCollection(
                setType.elementType,
                setClassName,
                "setOf",
                "emptySet",
                "Invalid set constant")
        }

        /**
         * Renders a list/set constant. Uses [emptyFactory] for empty input, the
         * [customClassName] when one is configured, and [factoryMethod] otherwise.
         */
        private fun visitCollection(
            elementType: ThriftType,
            customClassName: ClassName?,
            factoryMethod: String,
            emptyFactory: String,
            error: String) {
            if (value is ListValueElement) {
                if (value.value.isEmpty()) {
                    block.add("%L()", emptyFactory)
                    return
                }
                if (customClassName != null) {
                    val concreteName = customClassName.parameterizedBy(elementType.typeName)
                    emitCustomCollection(elementType, value, concreteName)
                } else {
                    emitDefaultCollection(elementType, value, factoryMethod)
                }
            } else {
                constOrError(error)
            }
        }

        /** Emits `Type(size).apply { add(e1) ... }`, one `add` call per line. */
        private fun emitCustomCollection(elementType: ThriftType, value: ListValueElement, collectionType: TypeName) {
            block.add("%T(%L).apply·{\n⇥", collectionType, value.value.size)
            for (element in value.value) {
                block.add("add(")
                recursivelyRenderConstValue(block, elementType, element)
                block.add(")\n")
            }
            block.add("⇤}")
        }

        /** Emits `factoryMethod(e1, e2, ...)`. */
        private fun emitDefaultCollection(elementType: ThriftType, value: ListValueElement, factoryMethod: String) {
            block.add("%L(⇥", factoryMethod)
            for ((n, elementValue) in value.value.withIndex()) {
                if (n > 0) {
                    block.add(", ")
                }
                recursivelyRenderConstValue(block, elementType, elementValue)
            }
            block.add("⇤)")
        }

        override fun visitMap(mapType: MapType) {
            val keyType = mapType.keyType
            val valueType = mapType.valueType
            if (value is MapValueElement) {
                if (value.value.isEmpty()) {
                    block.add("emptyMap()")
                    return
                }
                val customType = mapClassName
                if (customType != null) {
                    val concreteType = customType
                        .parameterizedBy(keyType.typeName, valueType.typeName)
                    emitCustomMap(mapType, value, concreteType)
                } else {
                    emitDefaultMap(mapType, value)
                }
            } else {
                constOrError("Invalid map constant")
            }
        }

        /** Emits `mapOf(k1 to v1, k2 to v2, ...)`. */
        private fun emitDefaultMap(mapType: MapType, value: MapValueElement) {
            val keyType = mapType.keyType
            val valueType = mapType.valueType
            block.add("mapOf(⇥")
            var n = 0
            for ((k, v) in value.value) {
                if (n++ > 0) {
                    block.add(",·")
                }
                recursivelyRenderConstValue(block, keyType, k)
                // Non-breaking spaces keep the infix `to` on the same line as its operands.
                block.add("·to·")
                recursivelyRenderConstValue(block, valueType, v)
            }
            block.add("⇤)")
        }

        /** Emits `Type(size).apply { put(k, v) ... }`, one `put` call per line. */
        private fun emitCustomMap(mapType: MapType, value: MapValueElement, mapTypeName: TypeName) {
            val keyType = mapType.keyType
            val valueType = mapType.valueType
            block.add("%T(%L).apply·{\n⇥", mapTypeName, value.value.size)
            for ((k, v) in value.value) {
                block.add("put(")
                recursivelyRenderConstValue(block, keyType, k)
                block.add(", ")
                recursivelyRenderConstValue(block, valueType, v)
                block.add(")\n")
            }
            block.add("⇤}")
        }

        override fun visitStruct(structType: StructType) {
            TODO("Struct constants are not implemented: $value at ${value.location}")
        }

        override fun visitTypedef(typedefType: TypedefType) {
            typedefType.trueType.accept(this)
        }

        override fun visitService(serviceType: ServiceType) {
            throw AssertionError("Cannot have a const value of a service type, wat r u doing")
        }

        /**
         * Fallback for values that don't directly match the expected type. If [value] is
         * an identifier (optionally `program.NAME`), it resolves a same-typed constant
         * from `schema` and emits a package-qualified reference to it. Otherwise it
         * throws [IllegalStateException] with [error] and the value's location.
         */
        private fun constOrError(error: String) {
            val message = "$error: $value at ${value.location}"
            if (value !is IdentifierValueElement) {
                throw IllegalStateException(message)
            }
            val name: String
            val expectedProgram: String?
            val text = value.value
            val ix = text.indexOf(".")
            if (ix != -1) {
                expectedProgram = text.substring(0, ix)
                name = text.substring(ix + 1)
            } else {
                expectedProgram = null
                name = text
            }
            val c = schema.constants.asSequence()
                .firstOrNull {
                    it.name == name
                        && it.type.trueType == type.trueType
                        && (expectedProgram == null || expectedProgram == it.location.programName)
                } ?: throw IllegalStateException(message)
            val packageName = c.getNamespaceFor(NamespaceScope.KOTLIN, NamespaceScope.JAVA, NamespaceScope.ALL)
                ?: throw IllegalStateException("No JVM namespace found for ${c.name} at ${c.location}")
            // Pass names as %L arguments rather than splicing them into the format string.
            block.add("%L.%L", packageName, name)
        }
    })
}