From b6cf60e9093c47c14cfe64dcc8b1703ee2cfe889 Mon Sep 17 00:00:00 2001 From: Timothy Schenk Date: Sat, 18 Jan 2025 21:32:43 +0100 Subject: [PATCH] rough outline --- .../Examples.cs | 3 +- .../IntermediatePacketHandlerData.cs | 20 +- .../IntermediatePacketStructHandlerData.cs | 12 +- .../PacketMediatorGenerator.cs | 290 ++++++++++-------- 4 files changed, 189 insertions(+), 136 deletions(-) diff --git a/PacketMediator.Generator/PacketMediator.Generator.Sample/Examples.cs b/PacketMediator.Generator/PacketMediator.Generator.Sample/Examples.cs index 226bf9e..714e9bf 100644 --- a/PacketMediator.Generator/PacketMediator.Generator.Sample/Examples.cs +++ b/PacketMediator.Generator/PacketMediator.Generator.Sample/Examples.cs @@ -63,7 +63,8 @@ public class HandlerA : IPacketHandler { } public class TestHandler : IPacketHandler { - public Task HandleAsync(StructB packet, RandomSession session, CancellationToken cancellationToken) { + public async Task HandleAsync(StructB packet, RandomSession session, CancellationToken cancellationToken) { + await Task.Delay(2000, cancellationToken); throw new NotImplementedException(); } } diff --git a/RaiNote.PacketMediator/IntermediatePacketHandlerData.cs b/RaiNote.PacketMediator/IntermediatePacketHandlerData.cs index bb417a7..e40b3a2 100644 --- a/RaiNote.PacketMediator/IntermediatePacketHandlerData.cs +++ b/RaiNote.PacketMediator/IntermediatePacketHandlerData.cs @@ -4,14 +4,26 @@ using Microsoft.CodeAnalysis; namespace RaiNote.PacketMediator; -class IntermediatePacketHandlerData { - public IntermediatePacketHandlerData(INamedTypeSymbol? symbol, string packetHandlerIdentifier, IntermediatePacketStructHandlerData? packetStructHandlerData) { - Symbol = symbol; +internal class IntermediatePacketHandlerData { + public IntermediatePacketHandlerData(string packetHandlerIdentifier, + IntermediatePacketStructHandlerData? packetStructHandlerData) { PacketHandlerIdentifier = packetHandlerIdentifier; PacketStructHandlerData = packetStructHandlerData; } - public INamedTypeSymbol? Symbol { get; set; } public string PacketHandlerIdentifier { get; set; } public IntermediatePacketStructHandlerData? PacketStructHandlerData { get; set; } } + +internal record IntermediatePacketStructData( + Location SymbolLocation, + string PacketStructFullIdentifier, + string EnumValue, + string EnumTypeFullIdentifier, + string EnumMemberIdentifier, + string EnumMaxValue, + bool ImplementsInterface); + +internal record IntermediateHandlerAndStructTuple( + IntermediatePacketHandlerData HandlerData, + IntermediatePacketStructData StructData); diff --git a/RaiNote.PacketMediator/IntermediatePacketStructHandlerData.cs b/RaiNote.PacketMediator/IntermediatePacketStructHandlerData.cs index 72e1658..b7ce1d8 100644 --- a/RaiNote.PacketMediator/IntermediatePacketStructHandlerData.cs +++ b/RaiNote.PacketMediator/IntermediatePacketStructHandlerData.cs @@ -4,12 +4,12 @@ using Microsoft.CodeAnalysis; namespace RaiNote.PacketMediator; -class IntermediatePacketStructHandlerData { - public IntermediatePacketStructHandlerData(ITypeSymbol packetStructSymbol, ITypeSymbol sessionSymbol) { - PacketStructSymbol = packetStructSymbol; - SessionSymbol = sessionSymbol; +internal class IntermediatePacketStructHandlerData { + public IntermediatePacketStructHandlerData(string packetStructFullIdentifier, string sessionFullIdentifier) { + PacketStructFullIdentifier = packetStructFullIdentifier; + SessionFullIdentifier = sessionFullIdentifier; } - public ITypeSymbol PacketStructSymbol { get; set; } - public ITypeSymbol SessionSymbol { get; set; } + public string PacketStructFullIdentifier { get; set; } + public string SessionFullIdentifier { get; set; } } diff --git a/RaiNote.PacketMediator/PacketMediatorGenerator.cs b/RaiNote.PacketMediator/PacketMediatorGenerator.cs index 688f714..16919bc 100644 --- a/RaiNote.PacketMediator/PacketMediatorGenerator.cs +++ b/RaiNote.PacketMediator/PacketMediatorGenerator.cs @@ -24,67 +24,70 @@ public class PacketMediatorGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(ctx => { - ctx.AddSource("PacketMediatorStatic.g.cs", SourceText.From(@" -using System; -using System.Diagnostics; -using System.Threading; -using System.Threading.Tasks; + ctx.AddSource("PacketMediatorStatic.g.cs", + SourceText.From(""" -namespace RaiNote.PacketMediator; -[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, Inherited = false)] -public abstract class PacketIdAttribute : Attribute where TPacketIdEnum : Enum -{ - protected PacketIdAttribute(TPacketIdEnum code) - { - Code = code; - } + using System; + using System.Diagnostics; + using System.Threading; + using System.Threading.Tasks; - public TPacketIdEnum Code { get; } -} -public interface IPacket; -public interface IIncomingPacket : IPacket -{ - public void Deserialize(byte[] data); -} -public interface IOutgoingPacket : IPacket -{ - public byte[] Serialize(); -} + namespace RaiNote.PacketMediator; + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, Inherited = false)] + public abstract class PacketIdAttribute : Attribute where TPacketIdEnum : Enum + { + protected PacketIdAttribute(TPacketIdEnum code) + { + Code = code; + } -public interface IBidirectionalPacket : IOutgoingPacket, IIncomingPacket; + public TPacketIdEnum Code { get; } + } + public interface IPacket; + public interface IIncomingPacket : IPacket + { + public void Deserialize(byte[] data); + } + public interface IOutgoingPacket : IPacket + { + public byte[] Serialize(); + } -public interface IPacketHandler : IPacketHandler - where TIncomingPacket : IIncomingPacket -{ - async Task IPacketHandler.TryHandleAsync(IIncomingPacket packet, TSession session, - CancellationToken cancellationToken) - { - if (packet is not TIncomingPacket tPacket) - { - return false; - } + public interface IBidirectionalPacket : IOutgoingPacket, IIncomingPacket; - using var activity = new ActivitySource(nameof(PacketMediator)).StartActivity(nameof(HandleAsync)); - activity?.AddTag(""Handler"", ToString()); - activity?.AddTag(""Packet"", packet.ToString()); - await HandleAsync(tPacket, session, cancellationToken); + public interface IPacketHandler : IPacketHandler + where TIncomingPacket : IIncomingPacket + { + async Task IPacketHandler.TryHandleAsync(IIncomingPacket packet, TSession session, + CancellationToken cancellationToken) + { + if (packet is not TIncomingPacket tPacket) + { + return false; + } - return true; - } + using var activity = new ActivitySource(nameof(PacketMediator)).StartActivity(nameof(HandleAsync)); + activity?.AddTag("Handler", ToString()); + activity?.AddTag("Packet", packet.ToString()); + await HandleAsync(tPacket, session, cancellationToken); - public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken); -} + return true; + } -public interface IPacketHandler -{ - Task TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken); -} -", Encoding.UTF8)); + public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken); + } + + public interface IPacketHandler + { + Task TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken); + } + + """, Encoding.UTF8)); }); // Find all struct declarations var structsWithAttributes = context.SyntaxProvider .CreateSyntaxProvider( - predicate: (node, _) => node is StructDeclarationSyntax { AttributeLists.Count: > 0}, + predicate: (node, _) => node is StructDeclarationSyntax { AttributeLists.Count: > 0 }, transform: (syntaxContext, cancellationToken) => { var structDeclaration = (StructDeclarationSyntax)syntaxContext.Node; var model = syntaxContext.SemanticModel; @@ -97,8 +100,13 @@ public interface IPacketHandler }; var implementsInterface = symbol != null && symbol.AllInterfaces .Any(i => requiredInterfaces.Contains(i.Name, StringComparer.Ordinal)); - if (!implementsInterface) - return default; + if (!implementsInterface) { + // TODO: https://github.com/dotnet/roslyn/blob/main/docs/features/incremental-generators.cookbook.md#issue-diagnostics + // or Analyzer + var diagnostic = Diagnostic.Create(_rpmGen001Diagnostic, symbol?.Locations.First(), + symbol?.ToDisplayString()); + } + // Check for the marker attribute var attribute = symbol?.GetAttributes() .FirstOrDefault(attr => { @@ -116,7 +124,7 @@ public interface IPacketHandler return false; }); if (attribute == null) { - return default; + return null; } var attributeConstructorArgument = attribute.ConstructorArguments[0]; @@ -130,95 +138,94 @@ public interface IPacketHandler var enumMaxValue = enumType?.GetMembers() .OfType().Max(x => x.ConstantValue); - return (symbol, structName: structDeclaration.Identifier.Text, value: enumValue, enumType, - enumMember, - enumMaxValue, implementsInterface); + if (symbol == null || enumMember == null || enumMaxValue == null || enumType == null || + enumValue == null) + return null; + + var intermediatePacketStructData = new IntermediatePacketStructData(symbol.Locations.First(), + symbol.ToDisplayString(), + enumValue.ToString(), + enumType.ToDisplayString(), + enumMember.ToDisplayString(), + enumMaxValue.ToString(), implementsInterface); + + return intermediatePacketStructData; }) - .Where(result => result != default); + .Where(result => result != null); - var packetHandlers = context.SyntaxProvider.CreateSyntaxProvider( + var packetHandlerValues = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => node is ClassDeclarationSyntax, - (syntaxContext, cancellationToken) => { - var classDeclaration = (ClassDeclarationSyntax)syntaxContext.Node; - var model = syntaxContext.SemanticModel; - var symbol = - ModelExtensions.GetDeclaredSymbol(model, classDeclaration, cancellationToken: cancellationToken) as - INamedTypeSymbol; - - var packetStruct = (symbol?.Interfaces.Select(interfaceSyntax => { - if (!interfaceSyntax.Name.SequenceEqual("IPacketHandler")) { - return null; - } - - if (interfaceSyntax.TypeArguments.Length != 2) - return null; - - var genericStructArgument = interfaceSyntax.TypeArguments[0]; - var genericSessionArgument = interfaceSyntax.TypeArguments[1]; - - return new IntermediatePacketStructHandlerData(genericStructArgument, genericSessionArgument); - }) ?? throw new InvalidOperationException("1")).FirstOrDefault(x => x != null); - - if (packetStruct == null) - return null; - return new IntermediatePacketHandlerData(symbol, classDeclaration.Identifier.Text, packetStruct); - }).Where(result => result != null); + TransformPacketHandlers).Where(result => result != null); var location = context.SyntaxProvider.CreateSyntaxProvider(predicate: static (node, _) => InterceptorPredicate(node), transform: static (context, ct) => InterceptorTransform(context, ct)) .Where(candidate => candidate is not null); - var combinedResults = structsWithAttributes.Collect().Combine(packetHandlers.Collect()); + var combinedResults = structsWithAttributes.Collect().Combine(packetHandlerValues.Collect()).Select( + static (tuple, cancellationToken) => { + var (structDatas, handlerDatas) = tuple; + + var matchingData = handlerDatas.Select(handlerData => { + if (handlerData == null) + return null; + var structs = structDatas.Where(sData => { + var equals = + sData != null && sData.PacketStructFullIdentifier.Equals(handlerData.PacketStructHandlerData + ?.PacketStructFullIdentifier); + + return equals; + }).FirstOrDefault(); + if (structs == null) + return null; + var intermediateHandlerAndStructTuple = new IntermediateHandlerAndStructTuple(handlerData, structs); + return intermediateHandlerAndStructTuple; + }); + + var intermediateHandlerAndStructTuples = matchingData.Where(x => x != null); + return intermediateHandlerAndStructTuples; + }); // Collect and generate the dictionary context.RegisterSourceOutput(combinedResults, (ctx, result) => { - var (packetStructs, packetHandlerDatas) = result; + var combinedInfo = result.ToList(); - var combinedInfo = packetHandlerDatas.Select(handler => { - var respectiveStruct = packetStructs.FirstOrDefault(pStruct => - pStruct.structName.SequenceEqual(handler?.PacketStructHandlerData?.PacketStructSymbol.Name ?? - string.Empty)); - return (handler, respectiveStruct); - }); - var packetHandlerData = packetHandlerDatas.FirstOrDefault(); + if (combinedInfo.Count <= 0) + return; + + var packetHandlerData = combinedInfo.First()?.HandlerData; var usedValues = new List(); - var highestValue = long.Parse(packetStructs.FirstOrDefault().enumMaxValue?.ToString()); + var intermediatePacketStructData = combinedInfo.First()?.StructData; + if (intermediatePacketStructData?.EnumMaxValue == null) + return; + var highestValue = long.Parse(intermediatePacketStructData.EnumMaxValue); var ms = new MemoryStream(); var sw = new StreamWriter(ms, Encoding.UTF8); sw.AutoFlush = true; - var enumTypeString = packetStructs.FirstOrDefault().enumType?.ToDisplayString(); - var sessionTypeString = packetHandlerData?.PacketStructHandlerData?.SessionSymbol.ToDisplayString(); + var enumTypeString = intermediatePacketStructData?.EnumTypeFullIdentifier; + var sessionTypeString = packetHandlerData?.PacketStructHandlerData?.SessionFullIdentifier; sw.WriteLine($$""" - using System; - using System.Threading; - using System.Threading.Tasks; - public static class PacketHandlerMediator - { - public async static Task Handle(IServiceProvider serviceProvider, byte[] data,{{enumTypeString}} opcode ,{{packetHandlerData?.PacketStructHandlerData?.SessionSymbol}} session, CancellationToken cancellationToken){ - switch(opcode) - { - """); + using System; + using System.Threading; + using System.Threading.Tasks; + public static class PacketHandlerMediator + { + public async static Task Handle(IServiceProvider serviceProvider, byte[] data,{{enumTypeString}} opcode ,{{packetHandlerData?.PacketStructHandlerData?.SessionFullIdentifier}} session, CancellationToken cancellationToken){ + switch(opcode) + { + """); - foreach (var ((handler, (symbol, structName, value, _, enumMember, _, implementsInterface)), i) in - combinedInfo.Select((value, i)=> (value, i))) { - if (!implementsInterface) { - var diagnostic = Diagnostic.Create(_rpmGen001Diagnostic, symbol?.Locations.FirstOrDefault(), - structName); - - ctx.ReportDiagnostic(diagnostic); - continue; - } - - var tempVal = long.Parse(value?.ToString()!, + var valueTuples = combinedInfo.Select((value, i) => (value, i)); + foreach (var ((handlerData, packetStructData), i) in valueTuples) { + var tempVal = long.Parse(packetStructData.EnumValue, new NumberFormatInfo()); usedValues.Add(tempVal); sw.WriteLine($""" - case {enumMember}: - var packet = new {handler?.PacketStructHandlerData?.PacketStructSymbol.ToDisplayString()}(); + case {packetStructData.EnumMemberIdentifier}: + var packet = new {handlerData.PacketHandlerIdentifier}(); packet.Deserialize(data); - _ = {handler?.Symbol?.ToDisplayString()}.HandleAsync(packet, session, cancellationToken); + _ = {handlerData.PacketHandlerIdentifier}.HandleAsync(packet, session, cancellationToken); return; """); } @@ -247,7 +254,8 @@ public interface IPacketHandler """); sw.Flush(); - ctx.AddSource("PacketHandlerMediator.g.cs", SourceText.From(sw.BaseStream, Encoding.UTF8, canBeEmbedded: true)); + ctx.AddSource("PacketHandlerMediator.g.cs", + SourceText.From(sw.BaseStream, Encoding.UTF8, canBeEmbedded: true)); sw.Close(); ms.Close(); var stringWriter = new StringWriter(); @@ -255,31 +263,63 @@ public interface IPacketHandler idWriter.WriteLine("using Microsoft.Extensions.DependencyInjection;"); idWriter.WriteLine("public static class ServiceExtensions{"); idWriter.Indent++; - idWriter.WriteLine("public static void AddPacketHandlerServices(this IServiceCollection serviceCollection){"); + idWriter.WriteLine( + "public static void AddPacketHandlerServices(this IServiceCollection serviceCollection){"); idWriter.Indent++; idWriter.WriteLine("// PacketHandler Service Generation"); - foreach (var handlerData in packetHandlerDatas) { - idWriter.WriteLine($"serviceCollection.AddScoped<{handlerData?.Symbol?.ToDisplayString()}>();"); + foreach (var (handlerData, _) in combinedInfo) { + idWriter.WriteLine($"serviceCollection.AddScoped<{handlerData.PacketHandlerIdentifier}>();"); } + idWriter.Indent--; idWriter.WriteLine("}"); idWriter.Indent--; idWriter.WriteLine("}"); - ctx.AddSource("ServiceExtensions.g.cs", SourceText.From(stringWriter.ToString(),Encoding.UTF8)); + ctx.AddSource("ServiceExtensions.g.cs", SourceText.From(stringWriter.ToString(), Encoding.UTF8)); }); } + private static IntermediatePacketHandlerData? TransformPacketHandlers(GeneratorSyntaxContext syntaxContext, + CancellationToken cancellationToken) { + var classDeclaration = (ClassDeclarationSyntax)syntaxContext.Node; + var model = syntaxContext.SemanticModel; + var symbol = + ModelExtensions.GetDeclaredSymbol(model, classDeclaration, cancellationToken: cancellationToken) as + INamedTypeSymbol; + var packetStruct = (symbol.Interfaces.Select(interfaceSyntax => { + if (!interfaceSyntax.Name.Equals("IPacketHandler")) { + return null; + } + + if (interfaceSyntax.TypeArguments.Length != 2) return null; + + var genericStructArgument = interfaceSyntax.TypeArguments[0]; + var genericSessionArgument = interfaceSyntax.TypeArguments[1]; + + return new IntermediatePacketStructHandlerData(genericStructArgument.ToDisplayString(), + genericSessionArgument.ToDisplayString()); + }) ?? throw new InvalidOperationException("1")).FirstOrDefault(x => x != null); + + if (packetStruct == null) return null; + var intermediatePacketHandlerData = new IntermediatePacketHandlerData(symbol.ToDisplayString(), packetStruct); + + return intermediatePacketHandlerData; + } + // https://andrewlock.net/creating-a-source-generator-part-11-implementing-an-interceptor-with-a-source-generator/ - private static CandidateInvocation? InterceptorTransform(GeneratorSyntaxContext context, CancellationToken cancellationToken) { + private static CandidateInvocation? InterceptorTransform(GeneratorSyntaxContext context, + CancellationToken cancellationToken) { if (context.Node is InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax { Name: { } nameSyntax } } invocation && - context.SemanticModel.GetOperation(context.Node, cancellationToken) is IInvocationOperation targetOperation && + context.SemanticModel.GetOperation(context.Node, + cancellationToken) is IInvocationOperation targetOperation && targetOperation.TargetMethod is { Name: "AddPacketHandlerServices", ContainingNamespace: { Name: "RaiNote.PacketMediator" } } ) { #pragma warning disable RSEXPERIMENTAL002 // / Experimental interceptable location API - if (context.SemanticModel.GetInterceptableLocation(invocation, cancellationToken: cancellationToken) is { } location) { + if (context.SemanticModel.GetInterceptableLocation(invocation, cancellationToken: cancellationToken) is + { } location) { // Return the location details and the full type details return new CandidateInvocation(location); }