refactor: packet differentiation + packet id mapping

This commit is contained in:
Timothy Schenk 2024-02-07 09:22:04 +01:00
parent 4982db6609
commit 60247ce42c
Signed by: rainote
SSH key fingerprint: SHA256:pnkNSDwpAnaip00xaZlVFHKKsS7T8UtOomMzvs0yITE
7 changed files with 49 additions and 47 deletions

View file

@ -1,7 +1,6 @@
// Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License.
using System.Net.Sockets;
using System.Reflection;
using Continuity.AuthServer.Packets;
using MassTransit.Mediator;
using Microsoft.Extensions.Logging;
@ -14,13 +13,16 @@ namespace Continuity.AuthServer;
public class AuthSession : TcpSession
{
private readonly ILogger<AuthSession> _logger;
private readonly PacketDistributorService<OperationCode, AuthSession> _distributorService;
private readonly IMediator _mediator;
public AuthSession(TcpServer
server, IMediator mediator, ILogger<AuthSession> logger) : base(server)
server, IMediator mediator, ILogger<AuthSession> logger,
PacketDistributorService<OperationCode, AuthSession> distributorService) : base(server)
{
_mediator = mediator;
_logger = logger;
_distributorService = distributorService;
}
public Guid AccountId { get; set; }
@ -33,15 +35,7 @@ public class AuthSession : TcpSession
public Task SendAsync(IOutgoingPacket packet)
{
var type = packet.GetType();
_logger.LogInformation("Packet of type {Type} is being serialized", type.Name);
var packetIdAttribute = type.GetCustomAttribute<WonderkingPacketIdAttribute>();
if (packetIdAttribute == null)
{
return Task.CompletedTask;
}
var opcode = packetIdAttribute.Code;
var opcode = _distributorService.GetOperationCodeByPacketType(packet);
Span<byte> packetData = packet.Serialize();
var length = (ushort)(packetData.Length + 8);

View file

@ -75,7 +75,7 @@ public class LoginHandler : IPacketHandler<LoginInfoPacket, AuthSession>
}
_logger.LogInformation("LoginResponsePacket: {@LoginResponsePacket}", loginResponsePacket);
_ = session?.SendAsync(loginResponsePacket);
_ = session.SendAsync(loginResponsePacket);
}
private static async Task<byte[]> GetPasswordHashAsync(string password, byte[] salt, Guid userId)

View file

@ -110,10 +110,8 @@ builder.Services.AddSingleton<ILoggerFactory>(loggerFactory);
builder.Services.AddSingleton(provider =>
new PacketDistributorService<OperationCode, AuthSession>(
provider.GetRequiredService<IServiceProvider>(),
[
Assembly.GetAssembly(typeof(LoginHandler)),
Assembly.GetAssembly(typeof(OperationCode)),
]
new List<Assembly> { Assembly.GetAssembly(typeof(OperationCode)) }.AsReadOnly(),
new List<Assembly> { Assembly.GetAssembly(typeof(LoginHandler)) }.AsReadOnly()
));
builder.Services.AddSingleton<ItemObjectPoolService>();
@ -125,7 +123,8 @@ builder.Services.AddHostedService(provider =>
provider.GetService<ItemObjectPoolService>() ?? throw new InvalidOperationException());
builder.Services.AddHostedService(provider =>
provider.GetService<PacketDistributorService<OperationCode, AuthSession>>() ?? throw new InvalidOperationException());
provider.GetService<PacketDistributorService<OperationCode, AuthSession>>() ??
throw new InvalidOperationException());
builder.Services.AddMassTransit(x =>
{

View file

@ -8,7 +8,6 @@ using DotNext.Collections.Generic;
using DotNext.Linq.Expressions;
using DotNext.Metaprogramming;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualBasic.CompilerServices;
namespace Rai.PacketMediator;
@ -16,14 +15,15 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
{
private readonly Channel<ValueTuple<byte[], TPacketIdEnum, TSession>> _channel;
private readonly Assembly[] _sourcesContainingPackets;
private readonly ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?> _packetHandlersInstantiation;
private readonly ImmutableDictionary<TPacketIdEnum,
Func<byte[], IIncomingPacket>> _deserializationMap;
public ImmutableDictionary<Type, TPacketIdEnum> PacketIdMap { get; }
public PacketDistributor(IServiceProvider serviceProvider,
Assembly[] sourcesContainingPackets)
IEnumerable<Assembly> sourcesContainingPackets, IEnumerable<Assembly> sourcesContainingPacketHandlers)
{
_channel = Channel.CreateUnbounded<ValueTuple<byte[], TPacketIdEnum, TSession>>(new UnboundedChannelOptions
{
@ -31,20 +31,14 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
SingleReader = false,
SingleWriter = false
});
_sourcesContainingPackets = sourcesContainingPackets;
var packetDictionary = GetAllPackets();
var containingPackets = sourcesContainingPackets as Assembly[] ?? sourcesContainingPackets.ToArray();
var allIncomingPackets = GetAllPackets(containingPackets, typeof(IIncomingPacket));
var allOutgoingPackets = GetAllPackets(containingPackets, typeof(IOutgoingPacket));
if (packetDictionary is { Count: 0 })
{
throw new IncompleteInitialization();
}
var packetHandlers = GetAllPacketHandlersWithId(sourcesContainingPacketHandlers);
var packetHandlers = GetAllPacketHandlersWithId();
if (packetHandlers is { Count: 0 })
{
throw new IncompleteInitialization();
}
this.PacketIdMap = allOutgoingPackets.Select(x => new { PacketId = x.Key, Type = x.Value })
.ToImmutableDictionary(x => x.Type, x => x.PacketId);
var tempDeserializationMap =
new ConcurrentDictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>();
@ -56,7 +50,7 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
packetHandlerPair.Value);
_packetHandlersInstantiation.TryAdd(packetHandlerPair.Key, packetHandler as IPacketHandler<TSession>);
});
packetDictionary.ForEach(packetsType =>
allIncomingPackets.ForEach(packetsType =>
{
var lambda = CodeGenerator.Lambda<Func<byte[], IIncomingPacket>>(fun =>
{
@ -75,23 +69,25 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
_deserializationMap = tempDeserializationMap.ToImmutableDictionary();
}
private Dictionary<TPacketIdEnum, Type> GetAllPackets()
private static IEnumerable<KeyValuePair<TPacketIdEnum, Type>> GetAllPackets(
IEnumerable<Assembly> sourcesContainingPackets, Type packetType)
{
var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes()
var packetsWithId = sourcesContainingPackets.SelectMany(a => a.GetTypes()
.Where(type => type is { IsInterface: false, IsAbstract: false } &&
type.GetInterfaces().Contains(typeof(IIncomingPacket))
type.GetInterfaces().Contains(packetType)
&& type.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().Any()
))
.Select(type =>
new { Type = type, Attribute = type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() })
.ToDictionary(item => item.Attribute!.Code, item => item.Type);
.Select(x => new KeyValuePair<TPacketIdEnum, Type>(x.Attribute!.Code, x.Type));
return packetsWithId;
}
private Dictionary<TPacketIdEnum, Type> GetAllPacketHandlersWithId()
private static IEnumerable<KeyValuePair<TPacketIdEnum, Type>> GetAllPacketHandlersWithId(
IEnumerable<Assembly> sourcesContainingPacketHandlers)
{
var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes()
var packetHandlersWithId = sourcesContainingPacketHandlers.SelectMany(assembly => assembly.GetTypes()
.Where(t =>
t is { IsClass: true, IsAbstract: false } && Array.Exists(t
.GetInterfaces(), i =>
@ -108,7 +104,7 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>()
}))
.Where(x => x.PacketId != null)
.ToDictionary(x => x.PacketId!.Code, x => x.Type);
.Select(x => new KeyValuePair<TPacketIdEnum, Type>(x.PacketId!.Code, x.Type));
return packetHandlersWithId;
}
@ -118,7 +114,7 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
await _channel.Writer.WriteAsync((packetData, operationCode, session));
}
public async Task DequeueRawPacketAsync(CancellationToken cancellationToken)
public async Task DequeuePacketAsync(CancellationToken cancellationToken)
{
while (await _channel.Reader.WaitToReadAsync(cancellationToken))
{

View file

@ -10,14 +10,15 @@ public class PacketDistributorService<TPacketIdEnum, TSession> : Microsoft.Exten
private readonly PacketDistributor<TPacketIdEnum, TSession> _packetDistributor;
public PacketDistributorService(IServiceProvider serviceProvider,
Assembly[] sourcesContainingPackets)
IEnumerable<Assembly> sourcesContainingPackets, IEnumerable<Assembly> sourcesContainingPacketHandlers)
{
_packetDistributor = new PacketDistributor<TPacketIdEnum, TSession>(serviceProvider, sourcesContainingPackets);
_packetDistributor = new PacketDistributor<TPacketIdEnum, TSession>(serviceProvider, sourcesContainingPackets,
sourcesContainingPacketHandlers);
}
public Task StartAsync(CancellationToken cancellationToken)
{
return _packetDistributor.DequeueRawPacketAsync(cancellationToken);
return _packetDistributor.DequeuePacketAsync(cancellationToken);
}
public Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
@ -29,4 +30,16 @@ public class PacketDistributorService<TPacketIdEnum, TSession> : Microsoft.Exten
{
return Task.CompletedTask;
}
public TPacketIdEnum GetOperationCodeByPacketType(IPacket packet)
{
var type = packet.GetType();
this._packetDistributor.PacketIdMap.TryGetValue(type, out var value);
if (value is null)
{
throw new ArgumentOutOfRangeException(type.Name);
}
return value;
}
}

View file

@ -1,2 +1,2 @@
#!sh
docker build --platform linux/arm64,linux/amd64 -f Continuity.AuthServer/Dockerfile -t continuity .
docker build --platform linux/arm64,linux/amd64 -f Continuity.AuthServer/Dockerfile -t continuity-auth .

View file

@ -1,7 +1,7 @@
services:
server:
container_name: continuity-server
image: continuity:latest
container_name: continuity-auth
image: continuity-auth:latest
restart: always
depends_on:
- db