DizzyJump

Untitled

Nov 6th, 2025
275
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 13.56 KB | None | 0 0
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Collections.Immutable;
  4. using System.Linq;
  5. using System.Text;
  6. using Microsoft.CodeAnalysis;
  7. using Microsoft.CodeAnalysis.CSharp;
  8. using Microsoft.CodeAnalysis.CSharp.Syntax;
  9. using Microsoft.CodeAnalysis.Text;
  10. using SourceGenerators.Utils;
  11.  
  12. namespace SourceGenerators;
  13.  
  14. [Generator]
  15. public class VContainerAutoRegistrationGenerator : IIncrementalGenerator
  16. {
  17.     private const string Attribute = "CodeBase.CompositionRoot.Helpers.VContainerAutoRegistrationAttribute";
  18.  
  19.     private static Dictionary<string, (TargetLifetime, RegisterType)> preparedInterfaces = new()
  20.     {
  21.         {"Scellecs.Morpeh.IInitializer", (TargetLifetime.Singleton, RegisterType.Self)},
  22.     };
  23.    
  24.     public void Initialize(IncrementalGeneratorInitializationContext context)
  25.     {
  26.         // Шаг 1: Собираем все классы-кандидаты
  27.         var classProvider = context.SyntaxProvider
  28.             .CreateSyntaxProvider(
  29.                 predicate: static (s, _) => IsClassCandidate(s),
  30.                 transform: static (ctx, _) => GetClassSymbol(ctx))
  31.             .Where(static m => m is not null);
  32.  
  33.         // Шаг 2: Комбинируем классы с информацией о компиляции
  34.         var combinedProvider = classProvider
  35.             .Combine(context.CompilationProvider);
  36.  
  37.         // Шаг 3: Обрабатываем каждый класс и находим подходящий интерфейс
  38.         var processedClasses = combinedProvider
  39.             .Select(static (combined, token) =>
  40.                 ProcessClass(combined.Left, combined.Right))
  41.             .Where(static m => m is not null);
  42.  
  43.         // Шаг 4: Группируем регистрации по интерфейсам для эффективной кодогенерации
  44.         var groupedRegistrations = processedClasses
  45.             .Collect()
  46.             .Select(static (registrations, _) =>
  47.                 registrations
  48.                     .GroupBy(r => r.InterfaceSymbol.ToDisplayString())
  49.                     .Select(g => new InterfaceRegistrationGroup
  50.                     {
  51.                         InterfaceFullName = g.Key,
  52.                         InterfaceSymbol = g.First().InterfaceSymbol,
  53.                         ClassRegistrations = g.ToList(),
  54.                         Lifetime = g.First().Lifetime,
  55.                         RegisterType = g.First().RegisterType
  56.                     })
  57.                     .ToList());
  58.  
  59.         // Шаг 5: Генерируем код
  60.         context.RegisterSourceOutput(groupedRegistrations,
  61.             static (spc, source) => GenerateRegistrationCode(source, spc));
  62.     }
  63.    
  64.     private static bool IsClassCandidate(SyntaxNode node)
  65.     {
  66.         return node is ClassDeclarationSyntax classDecl
  67.                && !classDecl.IsAbstract()
  68.                && !classDecl.IsStatic();
  69.     }
  70.    
  71.     private static ClassInfo? GetClassSymbol(GeneratorSyntaxContext context)
  72.     {
  73.         var classDeclaration = (ClassDeclarationSyntax)context.Node;
  74.         var classSymbol = context.SemanticModel.GetDeclaredSymbol(classDeclaration) as INamedTypeSymbol;
  75.            
  76.         if (classSymbol == null || classSymbol.IsAbstract)
  77.             return null;
  78.  
  79.         return new ClassInfo
  80.         {
  81.             Symbol = classSymbol,
  82.             FullName = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
  83.             Name = classSymbol.Name
  84.         };
  85.     }
  86.    
  87.     private static ClassRegistration? ProcessClass(ClassInfo? classInfo, Compilation compilation)
  88.     {
  89.         if (!classInfo.HasValue)
  90.             return null;
  91.        
  92.         var classSymbol = classInfo.Value.Symbol;
  93.            
  94.         // Собираем все интерфейсы, которые реализует класс
  95.         var allInterfaces = classSymbol.AllInterfaces;
  96.            
  97.         // Ищем первый подходящий интерфейс
  98.         foreach (var interfaceSymbol in allInterfaces)
  99.         {
  100.             var interfaceFullName = interfaceSymbol.ToDisplayString();
  101.                
  102.             // Проверяем наличие атрибута VContainerAutoRegistration
  103.             var hasAttribute = HasAutoRegistrationAttribute(interfaceSymbol, compilation);
  104.             if (hasAttribute.hasAttribute)
  105.             {
  106.                 return new ClassRegistration
  107.                 {
  108.                     ClassSymbol = classSymbol,
  109.                     InterfaceSymbol = interfaceSymbol,
  110.                     Lifetime = hasAttribute.lifetime,
  111.                     RegisterType = hasAttribute.registerType,
  112.                     FoundBy = "Attribute"
  113.                 };
  114.             }
  115.  
  116.             // Проверяем вхождение в preparedInterfaces
  117.             if (preparedInterfaces.TryGetValue(interfaceFullName, out var preparedSettings))
  118.             {
  119.                 return new ClassRegistration
  120.                 {
  121.                     ClassSymbol = classSymbol,
  122.                     InterfaceSymbol = interfaceSymbol,
  123.                     Lifetime = preparedSettings.Item1,
  124.                     RegisterType = preparedSettings.Item2,
  125.                     FoundBy = "PreparedList"
  126.                 };
  127.             }
  128.         }
  129.  
  130.         return null;
  131.     }
  132.    
  133.     private static (bool hasAttribute, TargetLifetime lifetime, RegisterType registerType)
  134.         HasAutoRegistrationAttribute(INamedTypeSymbol interfaceSymbol, Compilation compilation)
  135.     {
  136.         var autoRegistrationAttribute = compilation.GetTypeByMetadataName(Attribute);
  137.            
  138.         if (autoRegistrationAttribute == null)
  139.             return (false, default, default);
  140.  
  141.         var attribute = interfaceSymbol.GetAttributes()
  142.             .FirstOrDefault(attr =>
  143.                 attr.AttributeClass?.Equals(autoRegistrationAttribute, SymbolEqualityComparer.Default) == true);
  144.            
  145.         if (attribute == null)
  146.             return (false, default, default);
  147.  
  148.         var lifetime = GetLifetimeFromAttribute(attribute);
  149.         var registerType = GetRegisterTypeFromAttribute(attribute);
  150.  
  151.         return (true, lifetime, registerType);
  152.     }
  153.    
  154.     private static TargetLifetime GetLifetimeFromAttribute(AttributeData attribute)
  155.     {
  156.         foreach (var namedArgument in attribute.NamedArguments)
  157.         {
  158.             if (namedArgument.Key == "Lifetime" && namedArgument.Value.Value is int lifetimeValue)
  159.             {
  160.                 return (TargetLifetime)lifetimeValue;
  161.             }
  162.         }
  163.         return TargetLifetime.Transient;
  164.     }
  165.    
  166.     private static RegisterType GetRegisterTypeFromAttribute(AttributeData attribute)
  167.     {
  168.         foreach (var namedArgument in attribute.NamedArguments)
  169.         {
  170.             if (namedArgument.Key == "RegisterAs" && namedArgument.Value.Value is int registerAsValue)
  171.             {
  172.                 return (RegisterType)registerAsValue;
  173.             }
  174.         }
  175.         return RegisterType.ImplementedInterfaces;
  176.     }
  177.    
  178.     private static void GenerateRegistrationCode(
  179.         List<InterfaceRegistrationGroup> groups,
  180.         SourceProductionContext context)
  181.     {
  182.         if (!groups.Any())
  183.             return;
  184.  
  185.         // Генерируем основной файл регистраций
  186.         GenerateMainRegistrationFile(groups, context);
  187.  
  188.         // Генерируем отдельные файлы для каждой группы интерфейсов
  189.         foreach (var group in groups)
  190.         {
  191.             GenerateInterfaceRegistrationFile(group, context);
  192.         }
  193.     }
  194.    
  195.     private static void GenerateMainRegistrationFile(
  196.         List<InterfaceRegistrationGroup> groups,
  197.         SourceProductionContext context)
  198.     {
  199.         var sourceBuilder = StringBuilderPool.Get();
  200.            
  201.         sourceBuilder.AppendLine("// <auto-generated />");
  202.         sourceBuilder.AppendLine("#nullable enable");
  203.         sourceBuilder.AppendLine("using VContainer;");
  204.         sourceBuilder.AppendLine();
  205.         sourceBuilder.AppendLine("namespace VContainer.AutoRegistration.Generated");
  206.         sourceBuilder.AppendLine("{");
  207.         sourceBuilder.AppendLine("    public static partial class AutoGeneratedRegistrations");
  208.         sourceBuilder.AppendLine("    {");
  209.         sourceBuilder.AppendLine("        public static IContainerBuilder RegisterAllAutoDiscoveredTypes(this IContainerBuilder builder)");
  210.         sourceBuilder.AppendLine("        {");
  211.  
  212.         foreach (var group in groups)
  213.         {
  214.             var methodName = GetSafeMethodName(group.InterfaceSymbol);
  215.             sourceBuilder.AppendLine($"            builder.{methodName}();");
  216.         }
  217.  
  218.         sourceBuilder.AppendLine("            return builder;");
  219.         sourceBuilder.AppendLine("        }");
  220.         sourceBuilder.AppendLine("    }");
  221.         sourceBuilder.AppendLine("}");
  222.  
  223.         context.AddSource("AutoGeneratedRegistrations.Main.g.cs",
  224.             SourceText.From(sourceBuilder.ToStringAndReturn(), Encoding.UTF8));
  225.     }
  226.    
  227.     private static string GetSafeMethodName(INamedTypeSymbol interfaceSymbol)
  228.     {
  229.         var interfaceName = SafeInterfaceName(interfaceSymbol);
  230.            
  231.         return $"Register{interfaceName}Implementations";
  232.     }
  233.  
  234.     private static string SafeInterfaceName(INamedTypeSymbol interfaceSymbol)
  235.     {
  236.         return interfaceSymbol.ToDisplayString()
  237.             .Replace(".", "_")
  238.             .Replace("<", "_")
  239.             .Replace(">", "_")
  240.             .Replace(",", "_");
  241.     }
  242.  
  243.     private static string GetSafeFileName(INamedTypeSymbol interfaceSymbol)
  244.     {
  245.         var interfaceName = SafeInterfaceName(interfaceSymbol);
  246.         return interfaceName;
  247.     }
  248.    
  249.     private static void GenerateInterfaceRegistrationFile(
  250.             InterfaceRegistrationGroup group,
  251.             SourceProductionContext context)
  252.         {
  253.             var sourceBuilder = StringBuilderPool.Get();
  254.            
  255.             sourceBuilder.AppendLine("// <auto-generated />");
  256.             sourceBuilder.AppendLine("#nullable enable");
  257.             sourceBuilder.AppendLine("using VContainer;");
  258.             sourceBuilder.AppendLine();
  259.             sourceBuilder.AppendLine("namespace VContainer.AutoRegistration.Generated");
  260.             sourceBuilder.AppendLine("{");
  261.             sourceBuilder.AppendLine("    public static partial class AutoGeneratedRegistrations");
  262.             sourceBuilder.AppendLine("    {");
  263.  
  264.             var methodName = GetSafeMethodName(group.InterfaceSymbol);
  265.             var interfaceFullName = group.InterfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
  266.  
  267.             sourceBuilder.AppendLine($"        public static IContainerBuilder {methodName}(this IContainerBuilder builder)");
  268.             sourceBuilder.AppendLine("        {");
  269.  
  270.             foreach (var registration in group.ClassRegistrations)
  271.             {
  272.                 var className = registration.ClassSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
  273.                
  274.                 sourceBuilder.AppendLine($"            // {registration.ClassSymbol.Name} -> {registration.InterfaceSymbol.Name} ({registration.FoundBy})");
  275.                
  276.                 var registrationLine = registration.RegisterType switch
  277.                 {
  278.                     RegisterType.Self =>
  279.                         $"builder.Register<{className}>(Lifetime.{registration.Lifetime}).AsSelf();",
  280.                     RegisterType.ImplementedInterfaces =>
  281.                         $"builder.Register<{className}>(Lifetime.{registration.Lifetime}).AsImplementedInterfaces();",
  282.                     RegisterType.Both =>
  283.                         $"builder.Register<{className}>(Lifetime.{registration.Lifetime}).AsImplementedInterfaces().AsSelf();",
  284.                     _ => string.Empty
  285.                 };
  286.  
  287.                 if (!string.IsNullOrEmpty(registrationLine))
  288.                 {
  289.                     sourceBuilder.AppendLine($"            {registrationLine}");
  290.                 }
  291.                 sourceBuilder.AppendLine();
  292.             }
  293.  
  294.             sourceBuilder.AppendLine("            return builder;");
  295.             sourceBuilder.AppendLine("        }");
  296.             sourceBuilder.AppendLine("    }");
  297.             sourceBuilder.AppendLine("}");
  298.  
  299.             var fileName = $"AutoGeneratedRegistrations.{GetSafeFileName(group.InterfaceSymbol)}.g.cs";
  300.             context.AddSource(fileName, SourceText.From(sourceBuilder.ToStringAndReturn(), Encoding.UTF8));
  301.         }
  302. }
  303.  
  304. // Вспомогательные структуры
  305. internal struct ClassInfo
  306. {
  307.     public INamedTypeSymbol Symbol { get; set; }
  308.     public string FullName { get; set; }
  309.     public string Name { get; set; }
  310. }
  311.  
  312. internal class ClassRegistration
  313. {
  314.     public INamedTypeSymbol ClassSymbol { get; set; } = null!;
  315.     public INamedTypeSymbol InterfaceSymbol { get; set; } = null!;
  316.     public TargetLifetime Lifetime { get; set; }
  317.     public RegisterType RegisterType { get; set; }
  318.     public string FoundBy { get; set; } = string.Empty;
  319. }
  320.  
  321. internal class InterfaceRegistrationGroup
  322. {
  323.     public string InterfaceFullName { get; set; } = string.Empty;
  324.     public INamedTypeSymbol InterfaceSymbol { get; set; } = null!;
  325.     public List<ClassRegistration> ClassRegistrations { get; set; } = new();
  326.     public TargetLifetime Lifetime { get; set; }
  327.     public RegisterType RegisterType { get; set; }
  328. }
  329.  
  330. public enum TargetLifetime
  331. {
  332.     Transient,
  333.     Scoped,
  334.     Singleton
  335. }
  336.    
  337. public enum RegisterType
  338. {
  339.     Self,
  340.     ImplementedInterfaces,
  341.     Both
  342. }
Advertisement
Add Comment
Please, Sign In to add comment