Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System;
- using System.Collections.Generic;
- using System.Collections.Immutable;
- using System.Linq;
- using System.Text;
- using Microsoft.CodeAnalysis;
- using Microsoft.CodeAnalysis.CSharp;
- using Microsoft.CodeAnalysis.CSharp.Syntax;
- using Microsoft.CodeAnalysis.Text;
- using SourceGenerators.Utils;
- namespace SourceGenerators;
/// <summary>
/// Incremental source generator that finds concrete classes implementing an auto-registration
/// interface and emits VContainer registration extension methods for them.
/// </summary>
/// <remarks>
/// <para>
/// An interface qualifies when it carries the attribute named by <see cref="Attribute"/> or when its
/// <c>ToDisplayString()</c> form is a key of <see cref="PreparedInterfaces"/>. Each qualifying class is
/// registered once, under the first qualifying interface found in its <c>AllInterfaces</c> list.
/// </para>
/// <para>
/// Classes that generated code cannot name are skipped: open generic classes, classes nested in a
/// generic type, and classes (or containing types) that are private/protected.
/// </para>
/// <para>
/// Output: one <c>Register{Interface}Implementations</c> extension method per interface (one file each)
/// plus <c>RegisterAllAutoDiscoveredTypes</c>, which calls all of them. Nothing is emitted when no
/// class qualifies.
/// </para>
/// </remarks>
[Generator]
public class VContainerAutoRegistrationGenerator : IIncrementalGenerator
{
    /// <summary>Metadata name of the attribute that marks an interface for auto-registration.</summary>
    private const string Attribute = "CodeBase.CompositionRoot.Helpers.VContainerAutoRegistrationAttribute";

    /// <summary>
    /// Interfaces that qualify without carrying the attribute (e.g. third-party interfaces that cannot be
    /// annotated), keyed by <c>ToDisplayString()</c> and mapped to the lifetime and exposure to emit.
    /// </summary>
    private static readonly Dictionary<string, (TargetLifetime Lifetime, RegisterType RegisterAs)> PreparedInterfaces = new()
    {
        { "Scellecs.Morpeh.IInitializer", (TargetLifetime.Singleton, RegisterType.Self) },
    };

    /// <inheritdoc />
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Step 1: every registrable, non-abstract, non-static class declaration is a candidate.
        var classProvider = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => IsClassCandidate(node),
                transform: static (ctx, token) => GetClassSymbol(ctx, token))
            .Where(static info => info is not null);

        // Step 2: pair each candidate with the compilation so the marker attribute can be resolved.
        var combinedProvider = classProvider.Combine(context.CompilationProvider);

        // Step 3: keep only classes that implement a qualifying interface.
        var processedClasses = combinedProvider
            .Select(static (combined, token) => ProcessClass(combined.Left, combined.Right, token))
            .Where(static registration => registration is not null);

        // Step 4: gather all registrations and group them by interface.
        var groupedRegistrations = processedClasses
            .Collect()
            .Select(static (registrations, _) => GroupByInterface(registrations));

        // Step 5: emit the source files.
        context.RegisterSourceOutput(
            groupedRegistrations,
            static (spc, groups) => GenerateRegistrationCode(groups, spc));
    }

    /// <summary>
    /// Cheap syntax-only filter: class declarations that the Utils helpers do not report as abstract or static.
    /// </summary>
    private static bool IsClassCandidate(SyntaxNode node)
    {
        return node is ClassDeclarationSyntax classDecl
            && !classDecl.IsAbstract()
            && !classDecl.IsStatic();
    }

    /// <summary>
    /// Resolves the candidate's symbol and rejects classes that cannot be registered from generated code.
    /// </summary>
    /// <returns>The candidate info, or <c>null</c> if the class is abstract, static or unregistrable.</returns>
    private static ClassInfo? GetClassSymbol(
        GeneratorSyntaxContext context,
        System.Threading.CancellationToken cancellationToken = default)
    {
        var classDeclaration = (ClassDeclarationSyntax)context.Node;

        // Symbol-level checks also catch modifiers declared on another part of a partial class.
        if (context.SemanticModel.GetDeclaredSymbol(classDeclaration, cancellationToken) is not INamedTypeSymbol classSymbol
            || classSymbol.IsAbstract
            || classSymbol.IsStatic
            || !CanBeRegisteredFromGeneratedCode(classSymbol))
        {
            return null;
        }

        return new ClassInfo
        {
            Symbol = classSymbol,
            FullName = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
            Name = classSymbol.Name
        };
    }

    /// <summary>
    /// Returns <c>true</c> when <c>builder.Register&lt;T&gt;()</c> for this class would compile inside the
    /// generated class: neither the class nor any containing type may be generic or private/protected.
    /// </summary>
    private static bool CanBeRegisteredFromGeneratedCode(INamedTypeSymbol classSymbol)
    {
        for (INamedTypeSymbol? type = classSymbol; type is not null; type = type.ContainingType)
        {
            // An open generic type (or one nested in a generic type) has no closed form to register.
            if (type.TypeParameters.Length > 0)
                return false;

            switch (type.DeclaredAccessibility)
            {
                case Accessibility.Public:
                case Accessibility.Internal:
                case Accessibility.ProtectedOrInternal:
                    break;
                default:
                    // private / protected / private protected: invisible to the generated class.
                    return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Finds the first interface of the candidate class that qualifies for auto-registration.
    /// </summary>
    /// <returns>
    /// The registration, or <c>null</c> when the candidate is missing or implements no qualifying interface.
    /// For a single interface, the attribute's settings take precedence over <see cref="PreparedInterfaces"/>.
    /// </returns>
    private static ClassRegistration? ProcessClass(
        ClassInfo? classInfo,
        Compilation compilation,
        System.Threading.CancellationToken cancellationToken = default)
    {
        if (classInfo is not { } info)
            return null;

        var classSymbol = info.Symbol;

        // Resolve the marker attribute once per class rather than once per implemented interface.
        var attributeType = compilation.GetTypeByMetadataName(Attribute);

        foreach (var interfaceSymbol in classSymbol.AllInterfaces)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var (hasAttribute, lifetime, registerType) = HasAutoRegistrationAttribute(interfaceSymbol, attributeType);
            if (hasAttribute)
                return CreateRegistration(classSymbol, interfaceSymbol, lifetime, registerType, "Attribute");

            if (PreparedInterfaces.TryGetValue(interfaceSymbol.ToDisplayString(), out var prepared))
                return CreateRegistration(classSymbol, interfaceSymbol, prepared.Lifetime, prepared.RegisterAs, "PreparedList");
        }

        return null;
    }

    /// <summary>Builds a <see cref="ClassRegistration"/> record for the generator pipeline.</summary>
    private static ClassRegistration CreateRegistration(
        INamedTypeSymbol classSymbol,
        INamedTypeSymbol interfaceSymbol,
        TargetLifetime lifetime,
        RegisterType registerType,
        string foundBy) => new()
    {
        ClassSymbol = classSymbol,
        InterfaceSymbol = interfaceSymbol,
        Lifetime = lifetime,
        RegisterType = registerType,
        FoundBy = foundBy
    };

    /// <summary>
    /// Checks whether <paramref name="interfaceSymbol"/> carries the auto-registration attribute and,
    /// if so, reads its <c>Lifetime</c> and <c>RegisterAs</c> named arguments.
    /// </summary>
    /// <param name="attributeType">The resolved attribute type; <c>null</c> when the compilation cannot resolve it.</param>
    private static (bool hasAttribute, TargetLifetime lifetime, RegisterType registerType)
        HasAutoRegistrationAttribute(INamedTypeSymbol interfaceSymbol, INamedTypeSymbol? attributeType)
    {
        if (attributeType is null)
            return (false, default, default);

        var attribute = interfaceSymbol.GetAttributes()
            .FirstOrDefault(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeType));
        if (attribute is null)
            return (false, default, default);

        return (
            true,
            ReadEnumArgument(attribute, "Lifetime", TargetLifetime.Transient),
            ReadEnumArgument(attribute, "RegisterAs", RegisterType.ImplementedInterfaces));
    }

    /// <summary>
    /// Reads the named attribute argument <paramref name="argumentName"/> as a <typeparamref name="TEnum"/>.
    /// </summary>
    /// <remarks>
    /// The value is mapped by enum member name first, so the attribute may declare its own enum whose
    /// numeric values are ordered differently (NOTE(review): confirm which enum the attribute properties use).
    /// A plain <c>int</c> is accepted only when it is a defined <typeparamref name="TEnum"/> value; anything
    /// else yields <paramref name="fallback"/>, because an undefined value would be emitted as e.g.
    /// <c>Lifetime.5</c> and break the generated code.
    /// </remarks>
    private static TEnum ReadEnumArgument<TEnum>(AttributeData attribute, string argumentName, TEnum fallback)
        where TEnum : struct, Enum
    {
        foreach (var namedArgument in attribute.NamedArguments)
        {
            if (namedArgument.Key != argumentName)
                continue;

            var constant = namedArgument.Value;
            var rawValue = constant.Value;

            if (constant.Type is INamedTypeSymbol { TypeKind: TypeKind.Enum } enumType)
            {
                var memberName = enumType.GetMembers()
                    .OfType<IFieldSymbol>()
                    .FirstOrDefault(field => field.HasConstantValue && object.Equals(field.ConstantValue, rawValue))
                    ?.Name;

                if (memberName is not null && Enum.TryParse(memberName, out TEnum byName))
                    return byName;
            }

            if (rawValue is int number)
            {
                var byValue = (TEnum)Enum.ToObject(typeof(TEnum), number);
                if (Enum.IsDefined(typeof(TEnum), byValue))
                    return byValue;
            }

            return fallback;
        }

        return fallback;
    }

    /// <summary>
    /// Drops <c>null</c> and repeated registrations and groups the rest by interface
    /// (keyed by <c>ToDisplayString()</c>), preserving first-seen order.
    /// </summary>
    private static List<InterfaceRegistrationGroup> GroupByInterface(ImmutableArray<ClassRegistration?> registrations)
    {
        // A partial class reaches the pipeline once per declaration but resolves to the same symbol;
        // keep only its first registration so the generated code does not register it twice.
        var seenClasses = new HashSet<string>(StringComparer.Ordinal);
        var unique = new List<ClassRegistration>(registrations.Length);
        foreach (var registration in registrations)
        {
            if (registration is null)
                continue;

            var classKey = registration.ClassSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
            if (seenClasses.Add(classKey))
                unique.Add(registration);
        }

        return unique
            .GroupBy(r => r.InterfaceSymbol.ToDisplayString())
            .Select(g => new InterfaceRegistrationGroup
            {
                InterfaceFullName = g.Key,
                InterfaceSymbol = g.First().InterfaceSymbol,
                ClassRegistrations = g.ToList(),
                Lifetime = g.First().Lifetime,
                RegisterType = g.First().RegisterType
            })
            .ToList();
    }

    /// <summary>
    /// Emits the aggregate registration file and one file per interface group; emits nothing for no groups.
    /// </summary>
    private static void GenerateRegistrationCode(
        List<InterfaceRegistrationGroup> groups,
        SourceProductionContext context)
    {
        if (groups.Count == 0)
            return;

        var namedGroups = AssignUniqueNames(groups);

        GenerateMainRegistrationFile(namedGroups, context);

        foreach (var (group, safeName) in namedGroups)
        {
            GenerateInterfaceRegistrationFile(group, safeName, context);
        }
    }

    /// <summary>
    /// Gives every group a sanitized name that is unique across this run. The name is used for both the
    /// method name and the file hint name, so the aggregate method and the per-interface files agree.
    /// </summary>
    private static List<(InterfaceRegistrationGroup Group, string SafeName)> AssignUniqueNames(
        List<InterfaceRegistrationGroup> groups)
    {
        // Case-insensitive: hint names are file names and should not differ only by case.
        // "Main" is reserved for the aggregate file's hint name.
        var usedNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase) { "Main" };
        var result = new List<(InterfaceRegistrationGroup Group, string SafeName)>(groups.Count);

        foreach (var group in groups)
        {
            var baseName = SafeInterfaceName(group.InterfaceSymbol);
            var safeName = baseName;
            // Different interfaces can sanitize to the same identifier (e.g. A.B_C and A_B.C).
            for (var suffix = 2; !usedNames.Add(safeName); suffix++)
            {
                safeName = $"{baseName}_{suffix}";
            }

            result.Add((group, safeName));
        }

        return result;
    }

    /// <summary>Writes <c>RegisterAllAutoDiscoveredTypes</c>, which calls every per-interface method.</summary>
    private static void GenerateMainRegistrationFile(
        List<(InterfaceRegistrationGroup Group, string SafeName)> groups,
        SourceProductionContext context)
    {
        var sourceBuilder = StringBuilderPool.Get();
        sourceBuilder.AppendLine("// <auto-generated />");
        sourceBuilder.AppendLine("#nullable enable");
        sourceBuilder.AppendLine("using VContainer;");
        sourceBuilder.AppendLine();
        sourceBuilder.AppendLine("namespace VContainer.AutoRegistration.Generated");
        sourceBuilder.AppendLine("{");
        sourceBuilder.AppendLine("    public static partial class AutoGeneratedRegistrations");
        sourceBuilder.AppendLine("    {");
        sourceBuilder.AppendLine("        public static IContainerBuilder RegisterAllAutoDiscoveredTypes(this IContainerBuilder builder)");
        sourceBuilder.AppendLine("        {");
        foreach (var (_, safeName) in groups)
        {
            sourceBuilder.AppendLine($"            builder.{GetSafeMethodName(safeName)}();");
        }
        sourceBuilder.AppendLine("            return builder;");
        sourceBuilder.AppendLine("        }");
        sourceBuilder.AppendLine("    }");
        sourceBuilder.AppendLine("}");

        context.AddSource("AutoGeneratedRegistrations.Main.g.cs",
            SourceText.From(sourceBuilder.ToStringAndReturn(), Encoding.UTF8));
    }

    /// <summary>Name of the per-interface registration method for a sanitized interface name.</summary>
    private static string GetSafeMethodName(string safeInterfaceName)
    {
        return $"Register{safeInterfaceName}Implementations";
    }

    /// <summary>
    /// Turns the interface's display name into an identifier fragment: every character that is not a
    /// letter, digit or underscore becomes '_' (covers '.', generic brackets, commas, spaces, '?', tuples).
    /// </summary>
    private static string SafeInterfaceName(INamedTypeSymbol interfaceSymbol)
    {
        var chars = interfaceSymbol.ToDisplayString().ToCharArray();
        for (var i = 0; i < chars.Length; i++)
        {
            if (!char.IsLetterOrDigit(chars[i]) && chars[i] != '_')
                chars[i] = '_';
        }

        return new string(chars);
    }

    /// <summary>Writes the registration method for a single interface group.</summary>
    private static void GenerateInterfaceRegistrationFile(
        InterfaceRegistrationGroup group,
        string safeName,
        SourceProductionContext context)
    {
        var sourceBuilder = StringBuilderPool.Get();
        sourceBuilder.AppendLine("// <auto-generated />");
        sourceBuilder.AppendLine("#nullable enable");
        sourceBuilder.AppendLine("using VContainer;");
        sourceBuilder.AppendLine();
        sourceBuilder.AppendLine("namespace VContainer.AutoRegistration.Generated");
        sourceBuilder.AppendLine("{");
        sourceBuilder.AppendLine("    public static partial class AutoGeneratedRegistrations");
        sourceBuilder.AppendLine("    {");
        sourceBuilder.AppendLine($"        public static IContainerBuilder {GetSafeMethodName(safeName)}(this IContainerBuilder builder)");
        sourceBuilder.AppendLine("        {");

        foreach (var registration in group.ClassRegistrations)
        {
            var className = registration.ClassSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
            sourceBuilder.AppendLine($"            // {registration.ClassSymbol.Name} -> {registration.InterfaceSymbol.Name} ({registration.FoundBy})");

            var registrationLine = BuildRegistrationStatement(className, registration.Lifetime, registration.RegisterType);
            if (registrationLine is not null)
            {
                sourceBuilder.AppendLine($"            {registrationLine}");
            }

            sourceBuilder.AppendLine();
        }

        sourceBuilder.AppendLine("            return builder;");
        sourceBuilder.AppendLine("        }");
        sourceBuilder.AppendLine("    }");
        sourceBuilder.AppendLine("}");

        context.AddSource($"AutoGeneratedRegistrations.{safeName}.g.cs",
            SourceText.From(sourceBuilder.ToStringAndReturn(), Encoding.UTF8));
    }

    /// <summary>
    /// Builds the VContainer registration statement for one class, or <c>null</c> for an unknown exposure mode.
    /// </summary>
    private static string? BuildRegistrationStatement(string className, TargetLifetime lifetime, RegisterType registerType)
    {
        var register = $"builder.Register<{className}>(Lifetime.{lifetime})";
        return registerType switch
        {
            RegisterType.Self => register + ".AsSelf();",
            RegisterType.ImplementedInterfaces => register + ".AsImplementedInterfaces();",
            RegisterType.Both => register + ".AsImplementedInterfaces().AsSelf();",
            _ => null
        };
    }
}
// Helper types used by the generator pipeline
/// <summary>
/// Candidate class produced by the syntax-provider transform: the declared symbol plus its
/// fully-qualified and simple names.
/// </summary>
/// <remarks>
/// Implements <see cref="IEquatable{T}"/> so equality checks on values of this struct (including
/// <c>ClassInfo?</c>) use the strongly-typed comparison below instead of boxing and the
/// reflection-based <see cref="ValueType.Equals(object)"/>.
/// </remarks>
internal struct ClassInfo : IEquatable<ClassInfo>
{
    /// <summary>The declared class symbol.</summary>
    public INamedTypeSymbol Symbol { get; set; }

    /// <summary>Fully-qualified display name of the class.</summary>
    public string FullName { get; set; }

    /// <summary>Simple (unqualified) class name.</summary>
    public string Name { get; set; }

    /// <inheritdoc />
    public bool Equals(ClassInfo other) =>
        SymbolEqualityComparer.Default.Equals(Symbol, other.Symbol)
        && string.Equals(FullName, other.FullName, StringComparison.Ordinal)
        && string.Equals(Name, other.Name, StringComparison.Ordinal);

    /// <inheritdoc />
    public override bool Equals(object? obj) => obj is ClassInfo other && Equals(other);

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            // Properties may be null on default(ClassInfo); treat null as hash 0.
            var hash = Symbol is null ? 0 : SymbolEqualityComparer.Default.GetHashCode(Symbol);
            hash = (hash * 397) ^ (FullName?.GetHashCode() ?? 0);
            hash = (hash * 397) ^ (Name?.GetHashCode() ?? 0);
            return hash;
        }
    }

    public static bool operator ==(ClassInfo left, ClassInfo right) => left.Equals(right);

    public static bool operator !=(ClassInfo left, ClassInfo right) => !left.Equals(right);
}
/// <summary>
/// A single discovered registration: the implementation class, the interface through which it was
/// found, and the lifetime / exposure mode to emit for it.
/// </summary>
internal class ClassRegistration
{
    /// <summary>Concrete class to register.</summary>
    public INamedTypeSymbol ClassSymbol { get; set; } = null!;

    /// <summary>Interface through which the class qualified.</summary>
    public INamedTypeSymbol InterfaceSymbol { get; set; } = null!;

    /// <summary>Lifetime to register the class with.</summary>
    public TargetLifetime Lifetime { get; set; }

    /// <summary>Exposure mode (self, implemented interfaces, or both).</summary>
    public RegisterType RegisterType { get; set; }

    /// <summary>Free-text tag describing how the class was discovered; empty by default.</summary>
    public string FoundBy { get; set; } = "";
}
/// <summary>
/// All registrations that were discovered through the same interface.
/// </summary>
internal class InterfaceRegistrationGroup
{
    /// <summary>Grouping key for the interface; empty by default.</summary>
    public string InterfaceFullName { get; set; } = "";

    /// <summary>The shared interface symbol.</summary>
    public INamedTypeSymbol InterfaceSymbol { get; set; } = null!;

    /// <summary>Registrations in this group; starts as an empty list.</summary>
    public List<ClassRegistration> ClassRegistrations { get; set; } = new List<ClassRegistration>();

    /// <summary>Group-level lifetime.</summary>
    public TargetLifetime Lifetime { get; set; }

    /// <summary>Group-level exposure mode.</summary>
    public RegisterType RegisterType { get; set; }
}
/// <summary>
/// Lifetime written into generated registrations. The member name is emitted verbatim after
/// <c>Lifetime.</c> in generated code, so names must match the lifetimes VContainer exposes.
/// </summary>
public enum TargetLifetime
{
    /// <summary>Emitted as <c>Lifetime.Transient</c>.</summary>
    Transient = 0,

    /// <summary>Emitted as <c>Lifetime.Scoped</c>.</summary>
    Scoped = 1,

    /// <summary>Emitted as <c>Lifetime.Singleton</c>.</summary>
    Singleton = 2
}
/// <summary>
/// How a discovered class is exposed by its generated registration.
/// </summary>
public enum RegisterType
{
    /// <summary>Emitted as <c>.AsSelf()</c>.</summary>
    Self = 0,

    /// <summary>Emitted as <c>.AsImplementedInterfaces()</c>.</summary>
    ImplementedInterfaces = 1,

    /// <summary>Emitted as <c>.AsImplementedInterfaces().AsSelf()</c>.</summary>
    Both = 2
}
Advertisement
Add Comment
Please, Sign In to add comment