rough outline

This commit is contained in:
Timothy Schenk 2025-01-18 21:32:43 +01:00
parent 6386a9ec16
commit b6cf60e909
Signed by: rainote
SSH key fingerprint: SHA256:pnkNSDwpAnaip00xaZlVFHKKsS7T8UtOomMzvs0yITE
4 changed files with 189 additions and 136 deletions

View file

@ -63,7 +63,8 @@ public class HandlerA : IPacketHandler<StructA, RandomSession> {
} }
public class TestHandler : IPacketHandler<StructB, RandomSession> { public class TestHandler : IPacketHandler<StructB, RandomSession> {
public Task HandleAsync(StructB packet, RandomSession session, CancellationToken cancellationToken) { public async Task HandleAsync(StructB packet, RandomSession session, CancellationToken cancellationToken) {
await Task.Delay(2000, cancellationToken);
throw new NotImplementedException(); throw new NotImplementedException();
} }
} }

View file

@ -4,14 +4,26 @@ using Microsoft.CodeAnalysis;
namespace RaiNote.PacketMediator; namespace RaiNote.PacketMediator;
class IntermediatePacketHandlerData { internal class IntermediatePacketHandlerData {
public IntermediatePacketHandlerData(INamedTypeSymbol? symbol, string packetHandlerIdentifier, IntermediatePacketStructHandlerData? packetStructHandlerData) { public IntermediatePacketHandlerData(string packetHandlerIdentifier,
Symbol = symbol; IntermediatePacketStructHandlerData? packetStructHandlerData) {
PacketHandlerIdentifier = packetHandlerIdentifier; PacketHandlerIdentifier = packetHandlerIdentifier;
PacketStructHandlerData = packetStructHandlerData; PacketStructHandlerData = packetStructHandlerData;
} }
public INamedTypeSymbol? Symbol { get; set; }
public string PacketHandlerIdentifier { get; set; } public string PacketHandlerIdentifier { get; set; }
public IntermediatePacketStructHandlerData? PacketStructHandlerData { get; set; } public IntermediatePacketStructHandlerData? PacketStructHandlerData { get; set; }
} }
internal record IntermediatePacketStructData(
Location SymbolLocation,
string PacketStructFullIdentifier,
string EnumValue,
string EnumTypeFullIdentifier,
string EnumMemberIdentifier,
string EnumMaxValue,
bool ImplementsInterface);
internal record IntermediateHandlerAndStructTuple(
IntermediatePacketHandlerData HandlerData,
IntermediatePacketStructData StructData);

View file

@ -4,12 +4,12 @@ using Microsoft.CodeAnalysis;
namespace RaiNote.PacketMediator; namespace RaiNote.PacketMediator;
class IntermediatePacketStructHandlerData { internal class IntermediatePacketStructHandlerData {
public IntermediatePacketStructHandlerData(ITypeSymbol packetStructSymbol, ITypeSymbol sessionSymbol) { public IntermediatePacketStructHandlerData(string packetStructFullIdentifier, string sessionFullIdentifier) {
PacketStructSymbol = packetStructSymbol; PacketStructFullIdentifier = packetStructFullIdentifier;
SessionSymbol = sessionSymbol; SessionFullIdentifier = sessionFullIdentifier;
} }
public ITypeSymbol PacketStructSymbol { get; set; } public string PacketStructFullIdentifier { get; set; }
public ITypeSymbol SessionSymbol { get; set; } public string SessionFullIdentifier { get; set; }
} }

View file

@ -24,38 +24,40 @@ public class PacketMediatorGenerator : IIncrementalGenerator {
public void Initialize(IncrementalGeneratorInitializationContext context) { public void Initialize(IncrementalGeneratorInitializationContext context) {
context.RegisterPostInitializationOutput(ctx => { context.RegisterPostInitializationOutput(ctx => {
ctx.AddSource("PacketMediatorStatic.g.cs", SourceText.From(@" ctx.AddSource("PacketMediatorStatic.g.cs",
using System; SourceText.From("""
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace RaiNote.PacketMediator; using System;
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, Inherited = false)] using System.Diagnostics;
public abstract class PacketIdAttribute<TPacketIdEnum> : Attribute where TPacketIdEnum : Enum using System.Threading;
{ using System.Threading.Tasks;
namespace RaiNote.PacketMediator;
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, Inherited = false)]
public abstract class PacketIdAttribute<TPacketIdEnum> : Attribute where TPacketIdEnum : Enum
{
protected PacketIdAttribute(TPacketIdEnum code) protected PacketIdAttribute(TPacketIdEnum code)
{ {
Code = code; Code = code;
} }
public TPacketIdEnum Code { get; } public TPacketIdEnum Code { get; }
} }
public interface IPacket; public interface IPacket;
public interface IIncomingPacket : IPacket public interface IIncomingPacket : IPacket
{ {
public void Deserialize(byte[] data); public void Deserialize(byte[] data);
} }
public interface IOutgoingPacket : IPacket public interface IOutgoingPacket : IPacket
{ {
public byte[] Serialize(); public byte[] Serialize();
} }
public interface IBidirectionalPacket : IOutgoingPacket, IIncomingPacket; public interface IBidirectionalPacket : IOutgoingPacket, IIncomingPacket;
public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandler<TSession> public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandler<TSession>
where TIncomingPacket : IIncomingPacket where TIncomingPacket : IIncomingPacket
{ {
async Task<bool> IPacketHandler<TSession>.TryHandleAsync(IIncomingPacket packet, TSession session, async Task<bool> IPacketHandler<TSession>.TryHandleAsync(IIncomingPacket packet, TSession session,
CancellationToken cancellationToken) CancellationToken cancellationToken)
{ {
@ -65,26 +67,27 @@ public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandle
} }
using var activity = new ActivitySource(nameof(PacketMediator)).StartActivity(nameof(HandleAsync)); using var activity = new ActivitySource(nameof(PacketMediator)).StartActivity(nameof(HandleAsync));
activity?.AddTag(""Handler"", ToString()); activity?.AddTag("Handler", ToString());
activity?.AddTag(""Packet"", packet.ToString()); activity?.AddTag("Packet", packet.ToString());
await HandleAsync(tPacket, session, cancellationToken); await HandleAsync(tPacket, session, cancellationToken);
return true; return true;
} }
public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken); public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken);
} }
public interface IPacketHandler<in TSession> public interface IPacketHandler<in TSession>
{ {
Task<bool> TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken); Task<bool> TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken);
} }
", Encoding.UTF8));
""", Encoding.UTF8));
}); });
// Find all struct declarations // Find all struct declarations
var structsWithAttributes = context.SyntaxProvider var structsWithAttributes = context.SyntaxProvider
.CreateSyntaxProvider( .CreateSyntaxProvider(
predicate: (node, _) => node is StructDeclarationSyntax { AttributeLists.Count: > 0}, predicate: (node, _) => node is StructDeclarationSyntax { AttributeLists.Count: > 0 },
transform: (syntaxContext, cancellationToken) => { transform: (syntaxContext, cancellationToken) => {
var structDeclaration = (StructDeclarationSyntax)syntaxContext.Node; var structDeclaration = (StructDeclarationSyntax)syntaxContext.Node;
var model = syntaxContext.SemanticModel; var model = syntaxContext.SemanticModel;
@ -97,8 +100,13 @@ public interface IPacketHandler<in TSession>
}; };
var implementsInterface = symbol != null && symbol.AllInterfaces var implementsInterface = symbol != null && symbol.AllInterfaces
.Any(i => requiredInterfaces.Contains(i.Name, StringComparer.Ordinal)); .Any(i => requiredInterfaces.Contains(i.Name, StringComparer.Ordinal));
if (!implementsInterface) if (!implementsInterface) {
return default; // TODO: https://github.com/dotnet/roslyn/blob/main/docs/features/incremental-generators.cookbook.md#issue-diagnostics
// or Analyzer
var diagnostic = Diagnostic.Create(_rpmGen001Diagnostic, symbol?.Locations.First(),
symbol?.ToDisplayString());
}
// Check for the marker attribute // Check for the marker attribute
var attribute = symbol?.GetAttributes() var attribute = symbol?.GetAttributes()
.FirstOrDefault(attr => { .FirstOrDefault(attr => {
@ -116,7 +124,7 @@ public interface IPacketHandler<in TSession>
return false; return false;
}); });
if (attribute == null) { if (attribute == null) {
return default; return null;
} }
var attributeConstructorArgument = attribute.ConstructorArguments[0]; var attributeConstructorArgument = attribute.ConstructorArguments[0];
@ -130,95 +138,94 @@ public interface IPacketHandler<in TSession>
var enumMaxValue = enumType?.GetMembers() var enumMaxValue = enumType?.GetMembers()
.OfType<IFieldSymbol>().Max(x => x.ConstantValue); .OfType<IFieldSymbol>().Max(x => x.ConstantValue);
return (symbol, structName: structDeclaration.Identifier.Text, value: enumValue, enumType, if (symbol == null || enumMember == null || enumMaxValue == null || enumType == null ||
enumMember, enumValue == null)
enumMaxValue, implementsInterface); return null;
var intermediatePacketStructData = new IntermediatePacketStructData(symbol.Locations.First(),
symbol.ToDisplayString(),
enumValue.ToString(),
enumType.ToDisplayString(),
enumMember.ToDisplayString(),
enumMaxValue.ToString(), implementsInterface);
return intermediatePacketStructData;
}) })
.Where(result => result != default); .Where(result => result != null);
var packetHandlers = context.SyntaxProvider.CreateSyntaxProvider( var packetHandlerValues = context.SyntaxProvider.CreateSyntaxProvider(
predicate: (node, _) => node is ClassDeclarationSyntax, predicate: (node, _) => node is ClassDeclarationSyntax,
(syntaxContext, cancellationToken) => { TransformPacketHandlers).Where(result => result != null);
var classDeclaration = (ClassDeclarationSyntax)syntaxContext.Node;
var model = syntaxContext.SemanticModel;
var symbol =
ModelExtensions.GetDeclaredSymbol(model, classDeclaration, cancellationToken: cancellationToken) as
INamedTypeSymbol;
var packetStruct = (symbol?.Interfaces.Select(interfaceSyntax => {
if (!interfaceSyntax.Name.SequenceEqual("IPacketHandler")) {
return null;
}
if (interfaceSyntax.TypeArguments.Length != 2)
return null;
var genericStructArgument = interfaceSyntax.TypeArguments[0];
var genericSessionArgument = interfaceSyntax.TypeArguments[1];
return new IntermediatePacketStructHandlerData(genericStructArgument, genericSessionArgument);
}) ?? throw new InvalidOperationException("1")).FirstOrDefault(x => x != null);
if (packetStruct == null)
return null;
return new IntermediatePacketHandlerData(symbol, classDeclaration.Identifier.Text, packetStruct);
}).Where(result => result != null);
var location = var location =
context.SyntaxProvider.CreateSyntaxProvider(predicate: static (node, _) => InterceptorPredicate(node), context.SyntaxProvider.CreateSyntaxProvider(predicate: static (node, _) => InterceptorPredicate(node),
transform: static (context, ct) => InterceptorTransform(context, ct)) transform: static (context, ct) => InterceptorTransform(context, ct))
.Where(candidate => candidate is not null); .Where(candidate => candidate is not null);
var combinedResults = structsWithAttributes.Collect().Combine(packetHandlers.Collect()); var combinedResults = structsWithAttributes.Collect().Combine(packetHandlerValues.Collect()).Select(
static (tuple, cancellationToken) => {
var (structDatas, handlerDatas) = tuple;
var matchingData = handlerDatas.Select(handlerData => {
if (handlerData == null)
return null;
var structs = structDatas.Where(sData => {
var equals =
sData != null && sData.PacketStructFullIdentifier.Equals(handlerData.PacketStructHandlerData
?.PacketStructFullIdentifier);
return equals;
}).FirstOrDefault();
if (structs == null)
return null;
var intermediateHandlerAndStructTuple = new IntermediateHandlerAndStructTuple(handlerData, structs);
return intermediateHandlerAndStructTuple;
});
var intermediateHandlerAndStructTuples = matchingData.Where(x => x != null);
return intermediateHandlerAndStructTuples;
});
// Collect and generate the dictionary // Collect and generate the dictionary
context.RegisterSourceOutput(combinedResults, (ctx, result) => { context.RegisterSourceOutput(combinedResults, (ctx, result) => {
var (packetStructs, packetHandlerDatas) = result; var combinedInfo = result.ToList();
var combinedInfo = packetHandlerDatas.Select(handler => { if (combinedInfo.Count <= 0)
var respectiveStruct = packetStructs.FirstOrDefault(pStruct => return;
pStruct.structName.SequenceEqual(handler?.PacketStructHandlerData?.PacketStructSymbol.Name ??
string.Empty)); var packetHandlerData = combinedInfo.First()?.HandlerData;
return (handler, respectiveStruct);
});
var packetHandlerData = packetHandlerDatas.FirstOrDefault();
var usedValues = new List<long>(); var usedValues = new List<long>();
var highestValue = long.Parse(packetStructs.FirstOrDefault().enumMaxValue?.ToString()); var intermediatePacketStructData = combinedInfo.First()?.StructData;
if (intermediatePacketStructData?.EnumMaxValue == null)
return;
var highestValue = long.Parse(intermediatePacketStructData.EnumMaxValue);
var ms = new MemoryStream(); var ms = new MemoryStream();
var sw = new StreamWriter(ms, Encoding.UTF8); var sw = new StreamWriter(ms, Encoding.UTF8);
sw.AutoFlush = true; sw.AutoFlush = true;
var enumTypeString = packetStructs.FirstOrDefault().enumType?.ToDisplayString(); var enumTypeString = intermediatePacketStructData?.EnumTypeFullIdentifier;
var sessionTypeString = packetHandlerData?.PacketStructHandlerData?.SessionSymbol.ToDisplayString(); var sessionTypeString = packetHandlerData?.PacketStructHandlerData?.SessionFullIdentifier;
sw.WriteLine($$""" sw.WriteLine($$"""
using System; using System;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
public static class PacketHandlerMediator public static class PacketHandlerMediator
{ {
public async static Task Handle(IServiceProvider serviceProvider, byte[] data,{{enumTypeString}} opcode ,{{packetHandlerData?.PacketStructHandlerData?.SessionSymbol}} session, CancellationToken cancellationToken){ public async static Task Handle(IServiceProvider serviceProvider, byte[] data,{{enumTypeString}} opcode ,{{packetHandlerData?.PacketStructHandlerData?.SessionFullIdentifier}} session, CancellationToken cancellationToken){
switch(opcode) switch(opcode)
{ {
"""); """);
foreach (var ((handler, (symbol, structName, value, _, enumMember, _, implementsInterface)), i) in var valueTuples = combinedInfo.Select((value, i) => (value, i));
combinedInfo.Select((value, i)=> (value, i))) { foreach (var ((handlerData, packetStructData), i) in valueTuples) {
if (!implementsInterface) { var tempVal = long.Parse(packetStructData.EnumValue,
var diagnostic = Diagnostic.Create(_rpmGen001Diagnostic, symbol?.Locations.FirstOrDefault(),
structName);
ctx.ReportDiagnostic(diagnostic);
continue;
}
var tempVal = long.Parse(value?.ToString()!,
new NumberFormatInfo()); new NumberFormatInfo());
usedValues.Add(tempVal); usedValues.Add(tempVal);
sw.WriteLine($""" sw.WriteLine($"""
case {enumMember}: case {packetStructData.EnumMemberIdentifier}:
var packet = new {handler?.PacketStructHandlerData?.PacketStructSymbol.ToDisplayString()}(); var packet = new {handlerData.PacketHandlerIdentifier}();
packet.Deserialize(data); packet.Deserialize(data);
_ = {handler?.Symbol?.ToDisplayString()}.HandleAsync(packet, session, cancellationToken); _ = {handlerData.PacketHandlerIdentifier}.HandleAsync(packet, session, cancellationToken);
return; return;
"""); """);
} }
@ -247,7 +254,8 @@ public interface IPacketHandler<in TSession>
"""); """);
sw.Flush(); sw.Flush();
ctx.AddSource("PacketHandlerMediator.g.cs", SourceText.From(sw.BaseStream, Encoding.UTF8, canBeEmbedded: true)); ctx.AddSource("PacketHandlerMediator.g.cs",
SourceText.From(sw.BaseStream, Encoding.UTF8, canBeEmbedded: true));
sw.Close(); sw.Close();
ms.Close(); ms.Close();
var stringWriter = new StringWriter(); var stringWriter = new StringWriter();
@ -255,31 +263,63 @@ public interface IPacketHandler<in TSession>
idWriter.WriteLine("using Microsoft.Extensions.DependencyInjection;"); idWriter.WriteLine("using Microsoft.Extensions.DependencyInjection;");
idWriter.WriteLine("public static class ServiceExtensions{"); idWriter.WriteLine("public static class ServiceExtensions{");
idWriter.Indent++; idWriter.Indent++;
idWriter.WriteLine("public static void AddPacketHandlerServices(this IServiceCollection serviceCollection){"); idWriter.WriteLine(
"public static void AddPacketHandlerServices(this IServiceCollection serviceCollection){");
idWriter.Indent++; idWriter.Indent++;
idWriter.WriteLine("// PacketHandler Service Generation"); idWriter.WriteLine("// PacketHandler Service Generation");
foreach (var handlerData in packetHandlerDatas) { foreach (var (handlerData, _) in combinedInfo) {
idWriter.WriteLine($"serviceCollection.AddScoped<{handlerData?.Symbol?.ToDisplayString()}>();"); idWriter.WriteLine($"serviceCollection.AddScoped<{handlerData.PacketHandlerIdentifier}>();");
} }
idWriter.Indent--; idWriter.Indent--;
idWriter.WriteLine("}"); idWriter.WriteLine("}");
idWriter.Indent--; idWriter.Indent--;
idWriter.WriteLine("}"); idWriter.WriteLine("}");
ctx.AddSource("ServiceExtensions.g.cs", SourceText.From(stringWriter.ToString(),Encoding.UTF8)); ctx.AddSource("ServiceExtensions.g.cs", SourceText.From(stringWriter.ToString(), Encoding.UTF8));
}); });
} }
private static IntermediatePacketHandlerData? TransformPacketHandlers(GeneratorSyntaxContext syntaxContext,
CancellationToken cancellationToken) {
var classDeclaration = (ClassDeclarationSyntax)syntaxContext.Node;
var model = syntaxContext.SemanticModel;
var symbol =
ModelExtensions.GetDeclaredSymbol(model, classDeclaration, cancellationToken: cancellationToken) as
INamedTypeSymbol;
var packetStruct = (symbol.Interfaces.Select(interfaceSyntax => {
if (!interfaceSyntax.Name.Equals("IPacketHandler")) {
return null;
}
if (interfaceSyntax.TypeArguments.Length != 2) return null;
var genericStructArgument = interfaceSyntax.TypeArguments[0];
var genericSessionArgument = interfaceSyntax.TypeArguments[1];
return new IntermediatePacketStructHandlerData(genericStructArgument.ToDisplayString(),
genericSessionArgument.ToDisplayString());
}) ?? throw new InvalidOperationException("1")).FirstOrDefault(x => x != null);
if (packetStruct == null) return null;
var intermediatePacketHandlerData = new IntermediatePacketHandlerData(symbol.ToDisplayString(), packetStruct);
return intermediatePacketHandlerData;
}
// https://andrewlock.net/creating-a-source-generator-part-11-implementing-an-interceptor-with-a-source-generator/ // https://andrewlock.net/creating-a-source-generator-part-11-implementing-an-interceptor-with-a-source-generator/
private static CandidateInvocation? InterceptorTransform(GeneratorSyntaxContext context, CancellationToken cancellationToken) { private static CandidateInvocation? InterceptorTransform(GeneratorSyntaxContext context,
CancellationToken cancellationToken) {
if (context.Node is InvocationExpressionSyntax { if (context.Node is InvocationExpressionSyntax {
Expression: MemberAccessExpressionSyntax { Name: { } nameSyntax } Expression: MemberAccessExpressionSyntax { Name: { } nameSyntax }
} invocation && } invocation &&
context.SemanticModel.GetOperation(context.Node, cancellationToken) is IInvocationOperation targetOperation && context.SemanticModel.GetOperation(context.Node,
cancellationToken) is IInvocationOperation targetOperation &&
targetOperation.TargetMethod is targetOperation.TargetMethod is
{ Name: "AddPacketHandlerServices", ContainingNamespace: { Name: "RaiNote.PacketMediator" } } { Name: "AddPacketHandlerServices", ContainingNamespace: { Name: "RaiNote.PacketMediator" } }
) { ) {
#pragma warning disable RSEXPERIMENTAL002 // / Experimental interceptable location API #pragma warning disable RSEXPERIMENTAL002 // / Experimental interceptable location API
if (context.SemanticModel.GetInterceptableLocation(invocation, cancellationToken: cancellationToken) is { } location) { if (context.SemanticModel.GetInterceptableLocation(invocation, cancellationToken: cancellationToken) is
{ } location) {
// Return the location details and the full type details // Return the location details and the full type details
return new CandidateInvocation(location); return new CandidateInvocation(location);
} }