Guest User

ir generation code

a guest
Nov 17th, 2025
8
0
173 days
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Kotlin 20.32 KB | None | 0 0
  1. package com.martmists.serialization
  2.  
  3. import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
  4. import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
  5. import org.jetbrains.kotlin.builtins.StandardNames
  6. import org.jetbrains.kotlin.descriptors.ClassKind
  7. import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
  8. import org.jetbrains.kotlin.descriptors.Modality
  9. import org.jetbrains.kotlin.ir.builders.*
  10. import org.jetbrains.kotlin.ir.builders.declarations.addConstructor
  11. import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
  12. import org.jetbrains.kotlin.ir.builders.declarations.buildClass
  13. import org.jetbrains.kotlin.ir.builders.declarations.buildField
  14. import org.jetbrains.kotlin.ir.declarations.*
  15. import org.jetbrains.kotlin.ir.expressions.IrClassReference
  16. import org.jetbrains.kotlin.ir.expressions.IrConst
  17. import org.jetbrains.kotlin.ir.expressions.IrExpression
  18. import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
  19. import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionExpressionImpl
  20. import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
  21. import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
  22. import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl
  23. import org.jetbrains.kotlin.ir.types.*
  24. import org.jetbrains.kotlin.ir.util.*
  25. import org.jetbrains.kotlin.name.ClassId
  26. import org.jetbrains.kotlin.name.FqName
  27. import org.jetbrains.kotlin.name.Name
  28.  
  29. @OptIn(UnsafeDuringIrConstructionAPI::class)
  30. class CodecIRGenerator : IrGenerationExtension {
  31.     override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
  32.         val recordAnnotation = FqName("com.martmists.serialization.Record")
  33.  
  34.         val recordClasses = mutableListOf<Pair<IrClass, IrClass>>()
  35.         moduleFragment.files.forEach { file ->
  36.             file.declarations.filterIsInstance<IrClass>().forEach { klass ->
  37.                 if (!klass.annotations.hasAnnotation(recordAnnotation)) return@forEach
  38.  
  39.                 if (!klass.isData) {
  40.                     pluginContext.reportError(
  41.                         klass,
  42.                         "@Record can only be applied to data classes"
  43.                     )
  44.                     return@forEach
  45.                 }
  46.  
  47.                 val companion = klass.declarations.filterIsInstance<IrClass>().firstOrNull { it.isCompanion } ?: createCompanionObject(klass, pluginContext)
  48.  
  49.                 if (companion.declarations.any { it is IrField && it.name.asString() == "CODEC" }) {
  50.                     pluginContext.reportError(
  51.                         companion,
  52.                         "@Record classes should not have a CODEC field"
  53.                     )
  54.                 }
  55.  
  56.                 addEmptyCodecField(klass, companion, pluginContext)
  57.                 recordClasses += klass to companion
  58.             }
  59.         }
  60.  
  61.         recordClasses.forEach { (klass, companion) ->
  62.             initializeCodecField(klass, companion, pluginContext)
  63.         }
  64.     }
  65.  
  66.     private fun createCompanionObject(
  67.         klass: IrClass,
  68.         pluginContext: IrPluginContext
  69.     ): IrClass {
  70.         val companion = pluginContext.irFactory.buildClass {
  71.             startOffset = klass.startOffset
  72.             endOffset = klass.endOffset
  73.             name = Name.identifier("Companion")
  74.             isCompanion = true
  75.             kind = ClassKind.OBJECT
  76.         }
  77.         companion.createThisReceiverParameter()
  78.         companion.addConstructor {
  79.             isPrimary = true
  80.         }.apply {
  81.             body = pluginContext.irFactory.createExpressionBody(
  82.                 companion.startOffset,
  83.                 companion.endOffset,
  84.                 IrBlockBuilder(pluginContext, Scope(symbol), klass.startOffset, klass.endOffset).buildStatement(
  85.                     klass.startOffset,
  86.                     klass.endOffset
  87.                 ) {
  88.                     irReturnUnit()
  89.                 }
  90.             )
  91.         }
  92.         companion.parent = klass
  93.         klass.declarations.add(companion)
  94.         return companion
  95.     }
  96.  
  97.     private fun addEmptyCodecField(klass: IrClass, companion: IrClass, pluginContext: IrPluginContext) {
  98.         val field = pluginContext.irFactory.buildField {
  99.             startOffset = klass.startOffset
  100.             endOffset = klass.endOffset
  101.             name = Name.identifier("CODEC")
  102.             type = pluginContext.irBuiltIns.anyType
  103.         }
  104.         field.parent = companion
  105.         companion.declarations.add(field)
  106.     }
  107.  
  108.     private fun IrBuilder.getCodec(pluginContext: IrPluginContext, type: IrType): IrExpression {
  109.         val codecLocationAnnotation = FqName("com.martmists.serialization.CodecLocation")
  110.  
  111.         val codecKlass = pluginContext.referenceClass(
  112.             ClassId(
  113.                 FqName("com.mojang.serialization"),
  114.                 Name.identifier("Codec"),
  115.             )
  116.         )!!
  117.  
  118.         return when {
  119.             type == pluginContext.irBuiltIns.intType -> irGetStaticCodec(codecKlass, "INT")
  120.             type == pluginContext.irBuiltIns.floatType -> irGetStaticCodec(codecKlass, "FLOAT")
  121.             type == pluginContext.irBuiltIns.booleanType -> irGetStaticCodec(codecKlass, "BOOLEAN")
  122.             type == pluginContext.irBuiltIns.stringType -> irGetStaticCodec(codecKlass, "STRING")
  123.  
  124.             type.classOrNull == pluginContext.irBuiltIns.listClass -> {
  125.                 val elementType = (type as IrSimpleType).arguments.first().typeOrFail
  126.                 val elementCodec = getCodec(pluginContext, elementType)
  127.                 irCall(codecKlass.owner.functions.first { it.name.asString() == "list" && it.parameters.size == 1 }).apply {
  128.                     arguments[0] = elementCodec
  129.                 }
  130.             }
  131.  
  132.             else -> {
  133.                 val ann = type.annotations.firstOrNull { it.isAnnotationWithEqualFqName(codecLocationAnnotation) }
  134.                 val (typeSymbol, fieldName) = if (ann == null) {
  135.                     val klass = type.classOrNull ?: error("Type has no class: $type")
  136.                     klass to "CODEC"
  137.                 } else {
  138.                     val klass = ann.arguments[0] as IrClassReference
  139.                     val name = ann.arguments[1] as IrConst
  140.                     (klass.symbol as IrClassSymbol) to (name.value as String)
  141.                 }
  142.  
  143.                 val companion = typeSymbol.owner.companionObject()
  144.                 val codecField = (companion ?: typeSymbol.owner).declarations.filterIsInstance<IrProperty>().firstOrNull { it.name.asString() == fieldName }?.backingField
  145.                 ?: error("Type ${typeSymbol.owner.name} does not have a CODEC property")
  146.  
  147.                 val receiver = if (companion != null) irGetObjectValue(companion.defaultType, companion.symbol) else null
  148.  
  149.                 irGetField(receiver, codecField)
  150.             }
  151.         }
  152.     }
  153.  
  154.     private fun IrBuilder.irGetStaticCodec(codecKlass: IrClassSymbol, fieldName: String): IrExpression {
  155.         val field = codecKlass.owner.declarations
  156.             .filterIsInstance<IrProperty>()
  157.             .first { it.name.asString() == fieldName }
  158.             .backingField!!
  159.         return irGetField(null, field)
  160.     }
  161.  
  162.     private fun initializeCodecField(
  163.         klass: IrClass,
  164.         companion: IrClass,
  165.         pluginContext: IrPluginContext,
  166.     ) {
  167.         val ctor = klass.primaryConstructor ?: run {
  168.             pluginContext.reportError(klass, "@Record requires a primary constructor")
  169.             return
  170.         }
  171.  
  172.         val field = companion.fields.first { it.name.asString() == "CODEC" }
  173.  
  174.         val codecBuilderKlass = pluginContext.referenceClass(
  175.             ClassId(
  176.                 FqName("com.mojang.serialization.codecs"),
  177.                 Name.identifier("RecordCodecBuilder")
  178.             )
  179.         )!!
  180.         val codecBuilderInstanceKlass = pluginContext.referenceClass(
  181.             ClassId(
  182.                 FqName("com.mojang.serialization.codecs"),
  183.                 FqName("RecordCodecBuilder.Instance"),
  184.                 false,
  185.             )
  186.         )!!
  187.         val codecKlass = pluginContext.referenceClass(
  188.             ClassId(
  189.                 FqName("com.mojang.serialization"),
  190.                 Name.identifier("Codec"),
  191.             )
  192.         )!!
  193.         val mapCodecKlass = pluginContext.referenceClass(
  194.             ClassId(
  195.                 FqName("com.mojang.serialization"),
  196.                 Name.identifier("MapCodec"),
  197.             )
  198.         )!!
  199.         val appKlass = pluginContext.referenceClass(
  200.             ClassId(
  201.                 FqName("com.mojang.datafixers"),
  202.                 FqName("Products.P${ctor.nonDispatchParameters.size}"),
  203.                 false,
  204.             )
  205.         )!!
  206.  
  207.         val create = codecBuilderKlass.functions.first { it.owner.name.asString() == "create" }
  208.         val group = codecBuilderInstanceKlass.functions.first {
  209.             it.owner.name.asString() == "group" && it.owner.nonDispatchParameters.size == ctor.parameters.size
  210.         }
  211.         val fieldOf = codecKlass.functions.first { it.owner.name.asString() == "fieldOf" }
  212.         val optionalFieldOf = codecKlass.functions.first { it.owner.name.asString() == "optionalFieldOf" }
  213.         val forGetter = mapCodecKlass.functions.first { it.owner.name.asString() == "forGetter" }
  214.         val apply = appKlass.functions.first { it.owner.name.asString() == "apply" }
  215.  
  216.         field.parent = companion
  217.         field.initializer = pluginContext.irFactory.createExpressionBody(
  218.             IrBlockBuilder(pluginContext, Scope(companion.symbol), klass.startOffset, klass.endOffset).buildStatement(
  219.                 klass.startOffset,
  220.                 klass.endOffset
  221.             ) {
  222.                 irCall(create).apply {
  223.                     typeArguments[0] = klass.defaultType
  224.  
  225.                     val codecType = codecKlass.typeWith(klass.defaultType)
  226.                     val lambdaType = pluginContext.referenceClass(StandardNames.getFunctionClassId(1))!!.typeWith(codecBuilderInstanceKlass.typeWith(klass.defaultType), codecType)
  227.                     val lambda = pluginContext.irFactory.createSimpleFunction(
  228.                         klass.startOffset,
  229.                         klass.endOffset,
  230.                         IrDeclarationOrigin.DEFINED,
  231.                         Name.special("<anonymous>"),
  232.                         DescriptorVisibilities.LOCAL,
  233.                         isInline = false,
  234.                         isExpect = false,
  235.                         returnType = codecType,
  236.                         modality = Modality.FINAL,
  237.                         symbol = IrSimpleFunctionSymbolImpl(),
  238.                         isTailrec = false,
  239.                         isSuspend = false,
  240.                         isOperator = false,
  241.                         isInfix = false,
  242.                     ).apply {
  243.                         origin = IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA
  244.                         parent = companion
  245.  
  246.                         val param = addValueParameter {
  247.                             name = Name.identifier("it")
  248.                             type = codecBuilderInstanceKlass.typeWith(klass.defaultType)
  249.                         }
  250.  
  251.                         body = IrBlockBuilder(pluginContext, Scope(symbol), startOffset, endOffset).irBlockBody {
  252.                             val builder = irGet(param)
  253.  
  254.                             val product = irCall(group).apply {
  255.                                 dispatchReceiver = builder
  256.                                 ctor.nonDispatchParameters.forEachIndexed { index, arg ->
  257.                                     typeArguments[index] = arg.type
  258.                                 }
  259.                                 arguments.addAll(1,
  260.                                     ctor.nonDispatchParameters.map { arg ->
  261.                                         val codec = getCodec(pluginContext, arg.type)
  262.  
  263.                                         val field = if (arg.hasDefaultValue()) {
  264.                                             irCall(optionalFieldOf).apply {
  265.                                                 dispatchReceiver = codec
  266.                                                 arguments.addAll(1,
  267.                                                     listOf(
  268.                                                         irString(arg.name.asString()),
  269.                                                         arg.defaultValue!!.expression,
  270.                                                     )
  271.                                                 )
  272.                                             }
  273.                                         } else {
  274.                                             irCall(fieldOf).apply {
  275.                                                 dispatchReceiver = codec
  276.                                                 arguments.addAll(1,
  277.                                                     listOf(
  278.                                                         irString(arg.name.asString()),
  279.                                                     )
  280.                                                 )
  281.                                             }
  282.                                         }
  283.  
  284.                                         irCall(forGetter).apply {
  285.                                             typeArguments[0] = klass.defaultType
  286.                                             dispatchReceiver = field
  287.                                             val prop = klass.properties.first { it.name == arg.name }
  288.  
  289.                                             val getterReturnType = prop.getter!!.returnType
  290.                                             val functionTypeForGetter = pluginContext.referenceClass(StandardNames.getFunctionClassId(1))!!
  291.                                                 .typeWith(klass.defaultType, getterReturnType)
  292.  
  293.                                             val getterLambda = pluginContext.irFactory.createSimpleFunction(
  294.                                                 klass.startOffset, klass.endOffset,
  295.                                                 IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA,
  296.                                                 Name.special("<get_${prop.name.asString()}>_lambda"),
  297.                                                 DescriptorVisibilities.LOCAL,
  298.                                                 isInline = false,
  299.                                                 isExpect = false,
  300.                                                 returnType = getterReturnType,
  301.                                                 modality = Modality.FINAL,
  302.                                                 symbol = IrSimpleFunctionSymbolImpl(),
  303.                                                 isTailrec = false,
  304.                                                 isSuspend = false,
  305.                                                 isOperator = false,
  306.                                                 isInfix = false,
  307.                                             ).apply {
  308.                                                 parent = companion
  309.                                                 val p = addValueParameter {
  310.                                                     name = Name.identifier("receiver")
  311.                                                     type = klass.defaultType
  312.                                                 }
  313.                                                 body = IrBlockBuilder(pluginContext, Scope(symbol), startOffset, endOffset).irBlockBody {
  314.                                                     val getterCall = irCall(prop.getter!!).apply {
  315.                                                         dispatchReceiver = irGet(p)
  316.                                                     }
  317.                                                     +irReturn(getterCall)
  318.                                                 }
  319.                                             }
  320.  
  321.                                             val getterFunctionExpression = IrFunctionExpressionImpl(
  322.                                                 klass.startOffset, klass.endOffset,
  323.                                                 functionTypeForGetter,
  324.                                                 getterLambda,
  325.                                                 IrStatementOrigin.LAMBDA
  326.                                             )
  327.  
  328.                                             arguments.addAll(1,
  329.                                                 listOf(
  330.                                                     getterFunctionExpression
  331.                                                 )
  332.                                             )
  333.                                         }
  334.                                     }
  335.                                 )
  336.                             }
  337.  
  338.                             val result = irCall(apply).apply {
  339.                                 typeArguments[0] = klass.defaultType
  340.                                 dispatchReceiver = product
  341.  
  342.                                 val ctorParamTypes = ctor.nonDispatchParameters.map { it.type }
  343.                                 val ctorReturnType = klass.defaultType
  344.                                 val funcClassId = StandardNames.getFunctionClassId(ctorParamTypes.size)
  345.                                 val ctorFunctionType = pluginContext.referenceClass(funcClassId)!!.typeWith(*(ctorParamTypes + ctorReturnType).toTypedArray())
  346.  
  347.                                 val ctorLambda = pluginContext.irFactory.createSimpleFunction(
  348.                                     klass.startOffset, klass.endOffset,
  349.                                     IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA,
  350.                                     Name.special("<ctor_lambda>"),
  351.                                     DescriptorVisibilities.LOCAL,
  352.                                     isInline = false,
  353.                                     isExpect = false,
  354.                                     returnType = ctorReturnType,
  355.                                     modality = Modality.FINAL,
  356.                                     symbol = IrSimpleFunctionSymbolImpl(),
  357.                                     isTailrec = false,
  358.                                     isSuspend = false,
  359.                                     isOperator = false,
  360.                                     isInfix = false,
  361.                                 ).apply {
  362.                                     parent = companion
  363.                                     val params = ctor.nonDispatchParameters.mapIndexed { i, p ->
  364.                                         addValueParameter {
  365.                                             name = p.name
  366.                                             type = p.type
  367.                                         }
  368.                                     }
  369.  
  370.                                     body = IrBlockBuilder(pluginContext, Scope(symbol), startOffset, endOffset).irBlockBody {
  371.                                         val ctorCall = irCallConstructor(ctor.symbol, emptyList()).apply {
  372.                                             params.forEachIndexed { idx, param ->
  373.                                                 arguments[idx] = irGet(param)
  374.                                             }
  375.                                         }
  376.                                         +irReturn(ctorCall)
  377.                                     }
  378.                                 }
  379.  
  380.                                 val ctorFunctionExpression = IrFunctionExpressionImpl(
  381.                                     klass.startOffset, klass.endOffset,
  382.                                     ctorFunctionType,
  383.                                     ctorLambda,
  384.                                     IrStatementOrigin.LAMBDA
  385.                                 )
  386.  
  387.                                 arguments.addAll(1,
  388.                                     listOf(
  389.                                         builder,
  390.                                         ctorFunctionExpression
  391.                                     )
  392.                                 )
  393.                             }
  394.  
  395.                             +irReturn(result)
  396.                         }
  397.                     }
  398.  
  399.                     arguments[0] = IrFunctionExpressionImpl(companion.startOffset, companion.endOffset, lambdaType, lambda, IrStatementOrigin.LAMBDA)
  400.                 }
  401.             }
  402.         )
  403.     }
  404.  
  405.     private fun IrPluginContext.reportError(target: IrDeclaration, message: String) {
  406.         // TODO: diagnosticReporter.at(target).report(...)
  407.         throw Exception(message)
  408.     }
  409. }
  410.  
Add Comment
Please, Sign In to add comment