PacketMediator/RaiNote.PacketMediator/PacketMediatorGenerator.cs

344 lines
17 KiB
C#
Raw Normal View History

2025-01-17 08:53:44 +00:00
// Licensed to Timothy Schenk under the Apache 2.0 License.
2025-01-18 16:47:36 +00:00
using System.CodeDom.Compiler;
2025-01-17 08:53:44 +00:00
using System.Globalization;
using System.Text;
using Microsoft.CodeAnalysis;
2025-01-18 16:47:36 +00:00
using Microsoft.CodeAnalysis.CSharp;
2025-01-17 08:53:44 +00:00
using Microsoft.CodeAnalysis.CSharp.Syntax;
2025-01-18 16:47:36 +00:00
using Microsoft.CodeAnalysis.Operations;
2025-01-17 08:53:44 +00:00
using Microsoft.CodeAnalysis.Text;
namespace RaiNote.PacketMediator;
[Generator]
public class PacketMediatorGenerator : IIncrementalGenerator {
    /// <summary>Namespace in which the post-initialization source declares the packet contracts.</summary>
    private const string PacketMediatorNamespace = "RaiNote.PacketMediator";

    /// <summary>
    /// Simple names of the marker interfaces a PacketId-annotated struct must implement at least one of.
    /// Matched by name only; the containing namespace is not checked.
    /// </summary>
    private static readonly string[] RequiredPacketInterfaces = {
        "IPacket", "IIncomingPacket", "IOutgoingPacket", "IBidirectionalPacket"
    };

    /// <summary>
    /// RPMGen001: a struct carries a <c>PacketIdAttribute</c>-derived attribute but implements none of
    /// <see cref="RequiredPacketInterfaces"/>.
    /// </summary>
    private static readonly DiagnosticDescriptor RpmGen001Diagnostic = new(
        id: "RPMGen001",
        title: "Struct does not implement required interface",
        messageFormat:
        "The struct '{0}' must implement at least one of: 'IPacket', 'IIncomingPacket', 'IOutgoingPacket', 'IBidirectionalPacket'",
        category: "SourceGenerator",
        defaultSeverity: DiagnosticSeverity.Error,
        isEnabledByDefault: true);

    /// <summary>
    /// Source added once per compilation: the <c>PacketIdAttribute</c> base class, the packet marker
    /// interfaces and the handler interfaces that user code and the generated mediator build on.
    /// </summary>
    private const string StaticSource = """
        using System;
        using System.Diagnostics;
        using System.Threading;
        using System.Threading.Tasks;
        namespace RaiNote.PacketMediator;
        [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, Inherited = false)]
        public abstract class PacketIdAttribute<TPacketIdEnum> : Attribute where TPacketIdEnum : Enum
        {
            protected PacketIdAttribute(TPacketIdEnum code)
            {
                Code = code;
            }
            public TPacketIdEnum Code { get; }
        }
        public interface IPacket;
        public interface IIncomingPacket : IPacket
        {
            public void Deserialize(byte[] data);
        }
        public interface IOutgoingPacket : IPacket
        {
            public byte[] Serialize();
        }

        public interface IBidirectionalPacket : IOutgoingPacket, IIncomingPacket;

        internal static class PacketMediatorTelemetry
        {
            // Created once and reused for the process lifetime (ActivitySource guidance: store it in a static);
            // constructing one per handled packet would create and register a new IDisposable source each time.
            internal static readonly ActivitySource Source = new ActivitySource(nameof(PacketMediator));
        }

        public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandler<TSession>
            where TIncomingPacket : IIncomingPacket
        {
            async Task<bool> IPacketHandler<TSession>.TryHandleAsync(IIncomingPacket packet, TSession session,
                CancellationToken cancellationToken)
            {
                if (packet is not TIncomingPacket tPacket)
                {
                    return false;
                }

                using var activity = PacketMediatorTelemetry.Source.StartActivity(nameof(HandleAsync));
                activity?.AddTag("Handler", ToString());
                activity?.AddTag("Packet", packet.ToString());
                await HandleAsync(tPacket, session, cancellationToken);

                return true;
            }

            public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken);
        }

        public interface IPacketHandler<in TSession>
        {
            Task<bool> TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken);
        }
        """;

    /// <summary>
    /// Wires the generator: adds the static contract source, discovers PacketId structs and packet handlers,
    /// reports RPMGen001, and emits the mediator and DI-registration sources for matched handler/struct pairs.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context) {
        context.RegisterPostInitializationOutput(static ctx =>
            ctx.AddSource("PacketMediatorStatic.g.cs", SourceText.From(StaticSource, Encoding.UTF8)));

        // Structs with at least one attribute list; the transform keeps only PacketId-annotated ones.
        var structsWithAttributes = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is StructDeclarationSyntax { AttributeLists.Count: > 0 },
                transform: static (syntaxContext, cancellationToken) =>
                    TransformPacketStruct(syntaxContext, cancellationToken))
            .Where(static result => result != null);

        // RPMGen001 is reported from a separate pipeline so the struct model stays free of Diagnostic objects.
        // TODO: https://github.com/dotnet/roslyn/blob/main/docs/features/incremental-generators.cookbook.md#issue-diagnostics
        // recommends moving this check into an analyzer.
        var missingInterfaceDiagnostics = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is StructDeclarationSyntax { AttributeLists.Count: > 0 },
                transform: static (syntaxContext, cancellationToken) =>
                    CreateMissingInterfaceDiagnostic(syntaxContext, cancellationToken))
            .Where(static diagnostic => diagnostic != null);
        context.RegisterSourceOutput(missingInterfaceDiagnostics,
            static (ctx, diagnostic) => ctx.ReportDiagnostic(diagnostic!));

        // Only classes with a base list can (on that declaration) implement IPacketHandler<TPacket, TSession>.
        var packetHandlerValues = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is ClassDeclarationSyntax { BaseList: not null },
                transform: TransformPacketHandlers)
            .Where(static result => result != null);

        // Interceptor scaffolding: this provider is not connected to any output yet,
        // so the generator driver has no reason to evaluate it.
        _ = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => InterceptorPredicate(node),
                transform: static (syntaxContext, ct) => InterceptorTransform(syntaxContext, ct))
            .Where(static candidate => candidate is not null);

        var combinedResults = structsWithAttributes.Collect()
            .Combine(packetHandlerValues.Collect())
            .Select(static (tuple, _) => MatchHandlersToStructs(tuple.Left, tuple.Right));

        context.RegisterSourceOutput(combinedResults, static (ctx, matches) => EmitSources(ctx, matches));
    }

    /// <summary>
    /// Extracts the opcode data of a struct annotated with a <c>PacketIdAttribute</c>-derived attribute.
    /// </summary>
    /// <returns>
    /// <c>null</c> when the struct has no such attribute, the attribute has no constructor argument, the argument
    /// is not an enum value, or the value matches no enum member.
    /// </returns>
    private static IntermediatePacketStructData? TransformPacketStruct(GeneratorSyntaxContext syntaxContext,
        CancellationToken cancellationToken) {
        var structDeclaration = (StructDeclarationSyntax)syntaxContext.Node;
        if (syntaxContext.SemanticModel.GetDeclaredSymbol(structDeclaration, cancellationToken) is not
            INamedTypeSymbol symbol)
            return null;

        var attribute = FindPacketIdAttribute(symbol);
        // A derived attribute may take no arguments; indexing [0] unguarded would crash the generator.
        if (attribute is not { ConstructorArguments.Length: > 0 })
            return null;

        var opcodeArgument = attribute.ConstructorArguments[0];
        // Only enums are valid opcodes; any other type (e.g. int) would make the padding loop below
        // run up to that type's largest constant field.
        if (opcodeArgument.Type is not INamedTypeSymbol { TypeKind: TypeKind.Enum } enumType ||
            opcodeArgument.Value is not { } enumValue)
            return null;

        var enumFields = enumType.GetMembers()
            .OfType<IFieldSymbol>()
            .Where(static field => field.HasConstantValue)
            .ToList();
        var enumMember = enumFields.FirstOrDefault(field => object.Equals(field.ConstantValue, enumValue));
        var enumMaxValue = enumFields.Max(static field => field.ConstantValue);
        if (enumMember == null || enumMaxValue == null)
            return null;

        return new IntermediatePacketStructData(symbol.Locations.First(),
            symbol.ToDisplayString(),
            ToInvariantString(enumValue),
            enumType.ToDisplayString(),
            enumMember.ToDisplayString(),
            ToInvariantString(enumMaxValue),
            ImplementsPacketInterface(symbol));
    }

    /// <summary>
    /// Builds RPMGen001 for a PacketId-annotated struct that implements none of the packet interfaces.
    /// </summary>
    /// <returns>The diagnostic, located at the struct's identifier; otherwise <c>null</c>.</returns>
    private static Diagnostic? CreateMissingInterfaceDiagnostic(GeneratorSyntaxContext syntaxContext,
        CancellationToken cancellationToken) {
        var structDeclaration = (StructDeclarationSyntax)syntaxContext.Node;
        if (syntaxContext.SemanticModel.GetDeclaredSymbol(structDeclaration, cancellationToken) is not
            INamedTypeSymbol symbol)
            return null;
        if (FindPacketIdAttribute(symbol) == null || ImplementsPacketInterface(symbol))
            return null;
        return Diagnostic.Create(RpmGen001Diagnostic, structDeclaration.Identifier.GetLocation(),
            symbol.ToDisplayString());
    }

    /// <summary>
    /// Returns the first attribute on <paramref name="symbol"/> whose class is, or derives from,
    /// <c>RaiNote.PacketMediator.PacketIdAttribute</c>; <c>null</c> if there is none.
    /// </summary>
    private static AttributeData? FindPacketIdAttribute(INamedTypeSymbol symbol) {
        foreach (var attribute in symbol.GetAttributes()) {
            for (var attrClass = attribute.AttributeClass; attrClass != null; attrClass = attrClass.BaseType) {
                if (string.Equals(attrClass.Name, "PacketIdAttribute", StringComparison.Ordinal) &&
                    string.Equals(attrClass.ContainingNamespace?.ToDisplayString(), PacketMediatorNamespace,
                        StringComparison.Ordinal))
                    return attribute;
            }
        }

        return null;
    }

    /// <summary>True when any interface of <paramref name="symbol"/> is named like a required packet interface.</summary>
    private static bool ImplementsPacketInterface(INamedTypeSymbol symbol) =>
        symbol.AllInterfaces.Any(static i => RequiredPacketInterfaces.Contains(i.Name, StringComparer.Ordinal));

    /// <summary>
    /// Pairs each handler with the packet struct named by its <c>IPacketHandler&lt;TPacket, TSession&gt;</c>
    /// type argument. Handlers without a matching struct are dropped; for duplicate struct or handler identifiers
    /// (e.g. partial declarations, which are discovered once per declaration) the first occurrence wins.
    /// </summary>
    private static IntermediateHandlerAndStructTuple[] MatchHandlersToStructs(
        IEnumerable<IntermediatePacketStructData?> structDatas,
        IEnumerable<IntermediatePacketHandlerData?> handlerDatas) {
        var structsByIdentifier = new Dictionary<string, IntermediatePacketStructData>(StringComparer.Ordinal);
        foreach (var structData in structDatas) {
            if (structData != null && !structsByIdentifier.ContainsKey(structData.PacketStructFullIdentifier))
                structsByIdentifier.Add(structData.PacketStructFullIdentifier, structData);
        }

        var seenHandlers = new HashSet<string>(StringComparer.Ordinal);
        var matches = new List<IntermediateHandlerAndStructTuple>();
        foreach (var handlerData in handlerDatas) {
            var structIdentifier = handlerData?.PacketStructHandlerData?.PacketStructFullIdentifier;
            if (handlerData == null || structIdentifier == null)
                continue;
            if (!structsByIdentifier.TryGetValue(structIdentifier, out var structData))
                continue;
            // Emitting the same handler twice would duplicate case labels and DI registrations.
            if (!seenHandlers.Add(handlerData.PacketHandlerIdentifier))
                continue;
            matches.Add(new IntermediateHandlerAndStructTuple(handlerData, structData));
        }

        return matches.ToArray();
    }

    /// <summary>
    /// Emits <c>PacketHandlerMediator.g.cs</c> and <c>ServiceExtensions.g.cs</c> for the matched handlers.
    /// Emits nothing when there are no matches or the opcode enum's maximum value does not parse as a long.
    /// </summary>
    private static void EmitSources(SourceProductionContext ctx, IntermediateHandlerAndStructTuple[] combinedInfo) {
        if (combinedInfo.Length == 0)
            return;

        // The opcode enum and the session type of the generated Handle signature are taken from the first match.
        var first = combinedInfo[0];
        if (!TryParseInvariant(first.StructData?.EnumMaxValue, out var highestValue))
            return;
        var enumTypeString = first.StructData?.EnumTypeFullIdentifier;
        var sessionTypeString = first.HandlerData?.PacketStructHandlerData?.SessionFullIdentifier;

        ctx.AddSource("PacketHandlerMediator.g.cs",
            SourceText.From(BuildMediatorSource(combinedInfo, enumTypeString, sessionTypeString, highestValue),
                Encoding.UTF8));
        ctx.AddSource("ServiceExtensions.g.cs",
            SourceText.From(BuildServiceExtensionsSource(combinedInfo), Encoding.UTF8));
    }

    /// <summary>
    /// Builds the <c>PacketHandlerMediator</c> class: one switch case per matched handler, plus stub cases for
    /// every unused opcode in [0, <paramref name="highestValue"/>].
    /// </summary>
    private static string BuildMediatorSource(IEnumerable<IntermediateHandlerAndStructTuple> combinedInfo,
        string? enumTypeString, string? sessionTypeString, long highestValue) {
        var builder = new StringBuilder();
        builder.AppendLine($$"""
            using System;
            using System.Threading;
            using System.Threading.Tasks;
            using Microsoft.Extensions.DependencyInjection;

            public static class PacketHandlerMediator
            {
                public static async Task Handle(IServiceProvider serviceProvider, byte[] data, {{enumTypeString}} opcode, {{sessionTypeString}} session, CancellationToken cancellationToken)
                {
                    switch (opcode)
                    {
            """);

        var usedValues = new HashSet<long>();
        foreach (var (handlerData, packetStructData) in combinedInfo) {
            if (TryParseInvariant(packetStructData.EnumValue, out var opcodeValue))
                usedValues.Add(opcodeValue);

            var packetType = packetStructData.PacketStructFullIdentifier;
            var handlerSessionType = handlerData.PacketStructHandlerData?.SessionFullIdentifier;
            // Each case body gets its own braces: switch sections share one scope, so repeated locals would clash.
            // The handler is resolved from DI (registered as scoped by ServiceExtensions) and called through its
            // interface so explicit interface implementations work too. It is awaited so handler exceptions reach
            // the caller; callers wanting fire-and-forget can discard the returned Task.
            // NOTE(review): with scope validation on, pass a scope's provider rather than the root provider.
            builder.AppendLine($$"""
                            case {{packetStructData.EnumMemberIdentifier}}:
                            {
                                var packet = new {{packetType}}();
                                packet.Deserialize(data);
                                global::RaiNote.PacketMediator.IPacketHandler<{{packetType}}, {{handlerSessionType}}> handler =
                                    serviceProvider.GetRequiredService<{{handlerData.PacketHandlerIdentifier}}>();
                                await handler.HandleAsync(packet, session, cancellationToken);
                                return;
                            }
                """);
        }

        // Pad every unused opcode in [0, highestValue] with a stub case so the switch stays dense,
        // intended to make the compiler emit a jump table.
        for (long value = 0; value <= highestValue; value++) {
            if (usedValues.Contains(value))
                continue;
            builder.AppendLine($$"""
                            case (({{enumTypeString}}){{value.ToString(CultureInfo.InvariantCulture)}}):
                                await StubHandler.HandleAsync(data, opcode, session, cancellationToken);
                                return;
                """);
        }

        // TODO: allow overriding
        builder.AppendLine($$"""
                        }
                    }
                }

                public static class StubHandler
                {
                    public static Task HandleAsync(byte[] data, {{enumTypeString}} opcode, {{sessionTypeString}} session, CancellationToken cancellationToken)
                    {
                        // Stub method
                        return Task.CompletedTask;
                    }
                }
            """);
        return builder.ToString();
    }

    /// <summary>
    /// Builds the <c>ServiceExtensions.AddPacketHandlerServices</c> method that registers each matched handler
    /// as a scoped service.
    /// </summary>
    private static string BuildServiceExtensionsSource(IEnumerable<IntermediateHandlerAndStructTuple> combinedInfo) {
        using var stringWriter = new StringWriter(CultureInfo.InvariantCulture);
        using var idWriter = new IndentedTextWriter(stringWriter);
        idWriter.WriteLine("using Microsoft.Extensions.DependencyInjection;");
        idWriter.WriteLine("public static class ServiceExtensions{");
        idWriter.Indent++;
        idWriter.WriteLine(
            "public static void AddPacketHandlerServices(this IServiceCollection serviceCollection){");
        idWriter.Indent++;
        idWriter.WriteLine("// PacketHandler Service Generation");
        foreach (var (handlerData, _) in combinedInfo) {
            idWriter.WriteLine($"serviceCollection.AddScoped<{handlerData.PacketHandlerIdentifier}>();");
        }

        idWriter.Indent--;
        idWriter.WriteLine("}");
        idWriter.Indent--;
        idWriter.WriteLine("}");
        idWriter.Flush();
        return stringWriter.ToString();
    }

    /// <summary>
    /// Maps a concrete, non-generic class implementing
    /// <c>RaiNote.PacketMediator.IPacketHandler&lt;TPacket, TSession&gt;</c> to its handler data.
    /// </summary>
    /// <returns>
    /// <c>null</c> for abstract or generic classes (they cannot be registered in DI or named in generated code)
    /// and for classes that do not directly implement the two-argument handler interface.
    /// </returns>
    private static IntermediatePacketHandlerData? TransformPacketHandlers(GeneratorSyntaxContext syntaxContext,
        CancellationToken cancellationToken) {
        var classDeclaration = (ClassDeclarationSyntax)syntaxContext.Node;
        if (syntaxContext.SemanticModel.GetDeclaredSymbol(classDeclaration, cancellationToken) is not
            INamedTypeSymbol { IsAbstract: false, IsGenericType: false } symbol)
            return null;

        foreach (var implemented in symbol.Interfaces) {
            if (implemented is not { Name: "IPacketHandler", TypeArguments.Length: 2 } ||
                !string.Equals(implemented.ContainingNamespace?.ToDisplayString(), PacketMediatorNamespace,
                    StringComparison.Ordinal))
                continue;

            var packetStruct = new IntermediatePacketStructHandlerData(
                implemented.TypeArguments[0].ToDisplayString(),
                implemented.TypeArguments[1].ToDisplayString());
            return new IntermediatePacketHandlerData(symbol.ToDisplayString(), packetStruct);
        }

        return null;
    }

    /// <summary>Formats a boxed constant (e.g. an enum's underlying value) with the invariant culture.</summary>
    private static string ToInvariantString(object value) =>
        Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty;

    /// <summary>Parses an integer produced by <see cref="ToInvariantString"/>; false for null or unparsable text.</summary>
    private static bool TryParseInvariant(string? text, out long value) =>
        long.TryParse(text, NumberStyles.Integer, CultureInfo.InvariantCulture, out value);

    // https://andrewlock.net/creating-a-source-generator-part-11-implementing-an-interceptor-with-a-source-generator/
    /// <summary>
    /// Resolves an <c>AddPacketHandlerServices</c> call to an interceptable location when the called method is
    /// declared in the <c>RaiNote.PacketMediator</c> namespace.
    /// NOTE(review): the generated <c>ServiceExtensions</c> class has no namespace, so calls to it do not match.
    /// </summary>
    private static CandidateInvocation? InterceptorTransform(GeneratorSyntaxContext context,
        CancellationToken cancellationToken) {
        // INamespaceSymbol.Name is only the last segment, so the full display string is compared.
        if (context.Node is InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax } invocation &&
            context.SemanticModel.GetOperation(invocation, cancellationToken) is IInvocationOperation {
                TargetMethod: { Name: "AddPacketHandlerServices" } targetMethod
            } &&
            string.Equals(targetMethod.ContainingNamespace?.ToDisplayString(), PacketMediatorNamespace,
                StringComparison.Ordinal)) {
#pragma warning disable RSEXPERIMENTAL002 // Experimental interceptable location API
            if (context.SemanticModel.GetInterceptableLocation(invocation, cancellationToken: cancellationToken) is
                { } location) {
                return new CandidateInvocation(location);
            }
#pragma warning restore RSEXPERIMENTAL002
        }

        return null;
    }

    /// <summary>Syntactic pre-filter: any member-access invocation named <c>AddPacketHandlerServices</c>.</summary>
    private static bool InterceptorPredicate(SyntaxNode node) =>
        node is InvocationExpressionSyntax {
            Expression: MemberAccessExpressionSyntax {
                Name.Identifier.ValueText: "AddPacketHandlerServices"
            }
        };

#pragma warning disable RSEXPERIMENTAL002 // Experimental interceptable location API
    /// <summary>An interceptable call site of <c>AddPacketHandlerServices</c>.</summary>
    public record CandidateInvocation(InterceptableLocation Location);
#pragma warning restore RSEXPERIMENTAL002
}