refactor: packet differentiation + packet id mapping

This commit is contained in:
Timothy Schenk 2024-02-07 09:22:04 +01:00
parent 4982db6609
commit 60247ce42c
Signed by: rainote
SSH key fingerprint: SHA256:pnkNSDwpAnaip00xaZlVFHKKsS7T8UtOomMzvs0yITE
7 changed files with 49 additions and 47 deletions

View file

@ -1,7 +1,6 @@
// Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License. // Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License.
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection;
using Continuity.AuthServer.Packets; using Continuity.AuthServer.Packets;
using MassTransit.Mediator; using MassTransit.Mediator;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
@ -14,13 +13,16 @@ namespace Continuity.AuthServer;
public class AuthSession : TcpSession public class AuthSession : TcpSession
{ {
private readonly ILogger<AuthSession> _logger; private readonly ILogger<AuthSession> _logger;
private readonly PacketDistributorService<OperationCode, AuthSession> _distributorService;
private readonly IMediator _mediator; private readonly IMediator _mediator;
public AuthSession(TcpServer public AuthSession(TcpServer
server, IMediator mediator, ILogger<AuthSession> logger) : base(server) server, IMediator mediator, ILogger<AuthSession> logger,
PacketDistributorService<OperationCode, AuthSession> distributorService) : base(server)
{ {
_mediator = mediator; _mediator = mediator;
_logger = logger; _logger = logger;
_distributorService = distributorService;
} }
public Guid AccountId { get; set; } public Guid AccountId { get; set; }
@ -33,15 +35,7 @@ public class AuthSession : TcpSession
public Task SendAsync(IOutgoingPacket packet) public Task SendAsync(IOutgoingPacket packet)
{ {
var type = packet.GetType(); var opcode = _distributorService.GetOperationCodeByPacketType(packet);
_logger.LogInformation("Packet of type {Type} is being serialized", type.Name);
var packetIdAttribute = type.GetCustomAttribute<WonderkingPacketIdAttribute>();
if (packetIdAttribute == null)
{
return Task.CompletedTask;
}
var opcode = packetIdAttribute.Code;
Span<byte> packetData = packet.Serialize(); Span<byte> packetData = packet.Serialize();
var length = (ushort)(packetData.Length + 8); var length = (ushort)(packetData.Length + 8);

View file

@ -75,7 +75,7 @@ public class LoginHandler : IPacketHandler<LoginInfoPacket, AuthSession>
} }
_logger.LogInformation("LoginResponsePacket: {@LoginResponsePacket}", loginResponsePacket); _logger.LogInformation("LoginResponsePacket: {@LoginResponsePacket}", loginResponsePacket);
_ = session?.SendAsync(loginResponsePacket); _ = session.SendAsync(loginResponsePacket);
} }
private static async Task<byte[]> GetPasswordHashAsync(string password, byte[] salt, Guid userId) private static async Task<byte[]> GetPasswordHashAsync(string password, byte[] salt, Guid userId)

View file

@ -110,10 +110,8 @@ builder.Services.AddSingleton<ILoggerFactory>(loggerFactory);
builder.Services.AddSingleton(provider => builder.Services.AddSingleton(provider =>
new PacketDistributorService<OperationCode, AuthSession>( new PacketDistributorService<OperationCode, AuthSession>(
provider.GetRequiredService<IServiceProvider>(), provider.GetRequiredService<IServiceProvider>(),
[ new List<Assembly> { Assembly.GetAssembly(typeof(OperationCode)) }.AsReadOnly(),
Assembly.GetAssembly(typeof(LoginHandler)), new List<Assembly> { Assembly.GetAssembly(typeof(LoginHandler)) }.AsReadOnly()
Assembly.GetAssembly(typeof(OperationCode)),
]
)); ));
builder.Services.AddSingleton<ItemObjectPoolService>(); builder.Services.AddSingleton<ItemObjectPoolService>();
@ -125,7 +123,8 @@ builder.Services.AddHostedService(provider =>
provider.GetService<ItemObjectPoolService>() ?? throw new InvalidOperationException()); provider.GetService<ItemObjectPoolService>() ?? throw new InvalidOperationException());
builder.Services.AddHostedService(provider => builder.Services.AddHostedService(provider =>
provider.GetService<PacketDistributorService<OperationCode, AuthSession>>() ?? throw new InvalidOperationException()); provider.GetService<PacketDistributorService<OperationCode, AuthSession>>() ??
throw new InvalidOperationException());
builder.Services.AddMassTransit(x => builder.Services.AddMassTransit(x =>
{ {

View file

@ -8,7 +8,6 @@ using DotNext.Collections.Generic;
using DotNext.Linq.Expressions; using DotNext.Linq.Expressions;
using DotNext.Metaprogramming; using DotNext.Metaprogramming;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualBasic.CompilerServices;
namespace Rai.PacketMediator; namespace Rai.PacketMediator;
@ -16,14 +15,15 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
{ {
private readonly Channel<ValueTuple<byte[], TPacketIdEnum, TSession>> _channel; private readonly Channel<ValueTuple<byte[], TPacketIdEnum, TSession>> _channel;
private readonly Assembly[] _sourcesContainingPackets;
private readonly ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?> _packetHandlersInstantiation; private readonly ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?> _packetHandlersInstantiation;
private readonly ImmutableDictionary<TPacketIdEnum, private readonly ImmutableDictionary<TPacketIdEnum,
Func<byte[], IIncomingPacket>> _deserializationMap; Func<byte[], IIncomingPacket>> _deserializationMap;
public ImmutableDictionary<Type, TPacketIdEnum> PacketIdMap { get; }
public PacketDistributor(IServiceProvider serviceProvider, public PacketDistributor(IServiceProvider serviceProvider,
Assembly[] sourcesContainingPackets) IEnumerable<Assembly> sourcesContainingPackets, IEnumerable<Assembly> sourcesContainingPacketHandlers)
{ {
_channel = Channel.CreateUnbounded<ValueTuple<byte[], TPacketIdEnum, TSession>>(new UnboundedChannelOptions _channel = Channel.CreateUnbounded<ValueTuple<byte[], TPacketIdEnum, TSession>>(new UnboundedChannelOptions
{ {
@ -31,20 +31,14 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
SingleReader = false, SingleReader = false,
SingleWriter = false SingleWriter = false
}); });
_sourcesContainingPackets = sourcesContainingPackets; var containingPackets = sourcesContainingPackets as Assembly[] ?? sourcesContainingPackets.ToArray();
var packetDictionary = GetAllPackets(); var allIncomingPackets = GetAllPackets(containingPackets, typeof(IIncomingPacket));
var allOutgoingPackets = GetAllPackets(containingPackets, typeof(IOutgoingPacket));
if (packetDictionary is { Count: 0 }) var packetHandlers = GetAllPacketHandlersWithId(sourcesContainingPacketHandlers);
{
throw new IncompleteInitialization();
}
var packetHandlers = GetAllPacketHandlersWithId(); this.PacketIdMap = allOutgoingPackets.Select(x => new { PacketId = x.Key, Type = x.Value })
.ToImmutableDictionary(x => x.Type, x => x.PacketId);
if (packetHandlers is { Count: 0 })
{
throw new IncompleteInitialization();
}
var tempDeserializationMap = var tempDeserializationMap =
new ConcurrentDictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>(); new ConcurrentDictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>();
@ -56,7 +50,7 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
packetHandlerPair.Value); packetHandlerPair.Value);
_packetHandlersInstantiation.TryAdd(packetHandlerPair.Key, packetHandler as IPacketHandler<TSession>); _packetHandlersInstantiation.TryAdd(packetHandlerPair.Key, packetHandler as IPacketHandler<TSession>);
}); });
packetDictionary.ForEach(packetsType => allIncomingPackets.ForEach(packetsType =>
{ {
var lambda = CodeGenerator.Lambda<Func<byte[], IIncomingPacket>>(fun => var lambda = CodeGenerator.Lambda<Func<byte[], IIncomingPacket>>(fun =>
{ {
@ -75,23 +69,25 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
_deserializationMap = tempDeserializationMap.ToImmutableDictionary(); _deserializationMap = tempDeserializationMap.ToImmutableDictionary();
} }
private Dictionary<TPacketIdEnum, Type> GetAllPackets() private static IEnumerable<KeyValuePair<TPacketIdEnum, Type>> GetAllPackets(
IEnumerable<Assembly> sourcesContainingPackets, Type packetType)
{ {
var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes() var packetsWithId = sourcesContainingPackets.SelectMany(a => a.GetTypes()
.Where(type => type is { IsInterface: false, IsAbstract: false } && .Where(type => type is { IsInterface: false, IsAbstract: false } &&
type.GetInterfaces().Contains(typeof(IIncomingPacket)) type.GetInterfaces().Contains(packetType)
&& type.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().Any() && type.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().Any()
)) ))
.Select(type => .Select(type =>
new { Type = type, Attribute = type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() }) new { Type = type, Attribute = type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() })
.ToDictionary(item => item.Attribute!.Code, item => item.Type); .Select(x => new KeyValuePair<TPacketIdEnum, Type>(x.Attribute!.Code, x.Type));
return packetsWithId; return packetsWithId;
} }
private Dictionary<TPacketIdEnum, Type> GetAllPacketHandlersWithId() private static IEnumerable<KeyValuePair<TPacketIdEnum, Type>> GetAllPacketHandlersWithId(
IEnumerable<Assembly> sourcesContainingPacketHandlers)
{ {
var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes() var packetHandlersWithId = sourcesContainingPacketHandlers.SelectMany(assembly => assembly.GetTypes()
.Where(t => .Where(t =>
t is { IsClass: true, IsAbstract: false } && Array.Exists(t t is { IsClass: true, IsAbstract: false } && Array.Exists(t
.GetInterfaces(), i => .GetInterfaces(), i =>
@ -108,7 +104,7 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() .GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>()
})) }))
.Where(x => x.PacketId != null) .Where(x => x.PacketId != null)
.ToDictionary(x => x.PacketId!.Code, x => x.Type); .Select(x => new KeyValuePair<TPacketIdEnum, Type>(x.PacketId!.Code, x.Type));
return packetHandlersWithId; return packetHandlersWithId;
} }
@ -118,7 +114,7 @@ public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : En
await _channel.Writer.WriteAsync((packetData, operationCode, session)); await _channel.Writer.WriteAsync((packetData, operationCode, session));
} }
public async Task DequeueRawPacketAsync(CancellationToken cancellationToken) public async Task DequeuePacketAsync(CancellationToken cancellationToken)
{ {
while (await _channel.Reader.WaitToReadAsync(cancellationToken)) while (await _channel.Reader.WaitToReadAsync(cancellationToken))
{ {

View file

@ -10,14 +10,15 @@ public class PacketDistributorService<TPacketIdEnum, TSession> : Microsoft.Exten
private readonly PacketDistributor<TPacketIdEnum, TSession> _packetDistributor; private readonly PacketDistributor<TPacketIdEnum, TSession> _packetDistributor;
public PacketDistributorService(IServiceProvider serviceProvider, public PacketDistributorService(IServiceProvider serviceProvider,
Assembly[] sourcesContainingPackets) IEnumerable<Assembly> sourcesContainingPackets, IEnumerable<Assembly> sourcesContainingPacketHandlers)
{ {
_packetDistributor = new PacketDistributor<TPacketIdEnum, TSession>(serviceProvider, sourcesContainingPackets); _packetDistributor = new PacketDistributor<TPacketIdEnum, TSession>(serviceProvider, sourcesContainingPackets,
sourcesContainingPacketHandlers);
} }
public Task StartAsync(CancellationToken cancellationToken) public Task StartAsync(CancellationToken cancellationToken)
{ {
return _packetDistributor.DequeueRawPacketAsync(cancellationToken); return _packetDistributor.DequeuePacketAsync(cancellationToken);
} }
public Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session) public Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
@ -29,4 +30,16 @@ public class PacketDistributorService<TPacketIdEnum, TSession> : Microsoft.Exten
{ {
return Task.CompletedTask; return Task.CompletedTask;
} }
public TPacketIdEnum GetOperationCodeByPacketType(IPacket packet)
{
var type = packet.GetType();
this._packetDistributor.PacketIdMap.TryGetValue(type, out var value);
if (value is null)
{
throw new ArgumentOutOfRangeException(type.Name);
}
return value;
}
} }

View file

@ -1,2 +1,2 @@
#!sh #!sh
docker build --platform linux/arm64,linux/amd64 -f Continuity.AuthServer/Dockerfile -t continuity . docker build --platform linux/arm64,linux/amd64 -f Continuity.AuthServer/Dockerfile -t continuity-auth .

View file

@ -1,7 +1,7 @@
services: services:
server: server:
container_name: continuity-server container_name: continuity-auth
image: continuity:latest image: continuity-auth:latest
restart: always restart: always
depends_on: depends_on:
- db - db