diff --git a/Continuity.AuthServer/PacketHandlers/ChannelSelectionHandler.cs b/Continuity.AuthServer/PacketHandlers/ChannelSelectionHandler.cs index e5e121d..ad13ac4 100644 --- a/Continuity.AuthServer/PacketHandlers/ChannelSelectionHandler.cs +++ b/Continuity.AuthServer/PacketHandlers/ChannelSelectionHandler.cs @@ -21,7 +21,8 @@ public partial class ChannelSelectionHandler : IPacketHandler() }; var accountExists = - await _wonderkingContext.Accounts.AsNoTracking().AnyAsync(a => a.Id == authSession.AccountId); + await _wonderkingContext.Accounts.AsNoTracking().AnyAsync(a => a.Id == authSession.AccountId, cancellationToken: cancellationToken); var amountOfCharacter = await _wonderkingContext.Characters.AsNoTracking().Include(c => c.Account) .Where(c => c.Account.Id == authSession.AccountId).Take(3) - .CountAsync(); + .CountAsync(cancellationToken: cancellationToken); if (!accountExists) { @@ -50,14 +51,14 @@ public partial class ChannelSelectionHandler : IPacketHandler c.Account).Include(c => c.GuildMember) .ThenInclude(gm => gm.Guild) .Where(c => c.Account.Id == authSession.AccountId && c.GuildMember.Guild != null) - .Select(c => c.GuildMember.Guild.Name).Take(3).ToArrayAsync(); + .Select(c => c.GuildMember.Guild.Name).Take(3).ToArrayAsync(cancellationToken: cancellationToken); } else { diff --git a/Continuity.AuthServer/PacketHandlers/CharacterCreationHandler.cs b/Continuity.AuthServer/PacketHandlers/CharacterCreationHandler.cs index 169249f..8c821e4 100644 --- a/Continuity.AuthServer/PacketHandlers/CharacterCreationHandler.cs +++ b/Continuity.AuthServer/PacketHandlers/CharacterCreationHandler.cs @@ -28,14 +28,15 @@ public class CharacterCreationHandler : IPacketHandler a.Id == authSession.AccountId); + var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Id == authSession.AccountId, cancellationToken: cancellationToken); if (account is null) { @@ -47,7 +48,7 @@ public class CharacterCreationHandler : IPacketHandler x.Name == packet.Name && - x.Account.Id == authSession.AccountId); + x.Account.Id == authSession.AccountId, cancellationToken: cancellationToken); var response = new CharacterDeleteResponsePacket { HasToBeZero = 0 }; if (character == null) @@ -37,7 +37,7 @@ public class CharacterDeletionHandler : IPacketHandler c.Name == packet.Name); + var isTaken = await _wonderkingContext.Characters.AnyAsync(c => c.Name == packet.Name, cancellationToken: cancellationToken); var responsePacket = new CharacterNameCheckPacketResponse { IsTaken = isTaken }; if (session is AuthSession authSession) { diff --git a/Continuity.AuthServer/PacketHandlers/LoginHandler.cs b/Continuity.AuthServer/PacketHandlers/LoginHandler.cs index 647a07e..6185441 100644 --- a/Continuity.AuthServer/PacketHandlers/LoginHandler.cs +++ b/Continuity.AuthServer/PacketHandlers/LoginHandler.cs @@ -32,18 +32,18 @@ public class LoginHandler : IPacketHandler _configuration = configuration; } - public async Task HandleAsync(LoginInfoPacket packet, TcpSession session) + public async Task HandleAsync(LoginInfoPacket packet, TcpSession session, CancellationToken cancellationToken) { LoginResponseReason loginResponseReason; _logger.LoginData(packet.Username, packet.Password); - var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username); + var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username, cancellationToken: cancellationToken); if (account == null) { if (_configuration.GetSection("Testing").GetValue("CreateAccountOnLogin")) { loginResponseReason = await CreateAccountOnLoginAsync(packet.Username, packet.Password); - account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username); + account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username, cancellationToken: cancellationToken); } else { diff --git a/Rai.PacketMediator/IPacketHandler.cs b/Rai.PacketMediator/IPacketHandler.cs index d2ad071..8bcbb33 100644 --- a/Rai.PacketMediator/IPacketHandler.cs +++ b/Rai.PacketMediator/IPacketHandler.cs @@ -9,9 +9,10 @@ namespace Rai.PacketMediator; public interface IPacketHandler : IPacketHandler where TIncomingPacket : IIncomingPacket { [UsedImplicitly(ImplicitUseTargetFlags.WithInheritors)] - public Task HandleAsync(TIncomingPacket packet, TSession session); + public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken); - async Task IPacketHandler.TryHandleAsync(IIncomingPacket packet, TSession session) + async Task IPacketHandler.TryHandleAsync(IIncomingPacket packet, TSession session, + CancellationToken cancellationToken) { if (packet is not TIncomingPacket tPacket) { @@ -21,7 +22,7 @@ public interface IPacketHandler : IPacketHandle using var activity = new ActivitySource(nameof(PacketMediator)).StartActivity(nameof(HandleAsync)); activity?.AddTag("Handler", this.ToString()); activity?.AddTag("Packet", packet.ToString()); - await HandleAsync(tPacket, session); + await HandleAsync(tPacket, session, cancellationToken); return true; } @@ -29,5 +30,5 @@ public interface IPacketHandler : IPacketHandle public interface IPacketHandler { - Task TryHandleAsync(IIncomingPacket packet, TSession session); + Task TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken); } diff --git a/Rai.PacketMediator/PacketDistributor.cs b/Rai.PacketMediator/PacketDistributor.cs new file mode 100644 index 0000000..e1861b3 --- /dev/null +++ b/Rai.PacketMediator/PacketDistributor.cs @@ -0,0 +1,146 @@ +// Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License. + +using System.Collections.Concurrent; +using System.Collections.Immutable; +using System.Reflection; +using System.Threading.Channels; +using DotNext.Collections.Generic; +using DotNext.Linq.Expressions; +using DotNext.Metaprogramming; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualBasic.CompilerServices; + +namespace Rai.PacketMediator; + +public class PacketDistributor where TPacketIdEnum : Enum +{ + private readonly Channel> _channel; + + private readonly Assembly[] _sourcesContainingPackets; + private readonly ConcurrentDictionary?> _packetHandlersInstantiation; + + private readonly ImmutableDictionary> _deserializationMap; + + public PacketDistributor(IServiceProvider serviceProvider, + Assembly[] sourcesContainingPackets) + { + _channel = Channel.CreateUnbounded>(new UnboundedChannelOptions + { + AllowSynchronousContinuations = false, + SingleReader = false, + SingleWriter = false + }); + _sourcesContainingPackets = sourcesContainingPackets; + var packetDictionary = GetAllPackets(); + + if (packetDictionary is { Count: 0 }) + { + throw new IncompleteInitialization(); + } + + var packetHandlers = GetAllPacketHandlersWithId(); + + if (packetHandlers is { Count: 0 }) + { + throw new IncompleteInitialization(); + } + + var tempDeserializationMap = + new Dictionary>(); + _packetHandlersInstantiation = new ConcurrentDictionary?>(); + packetHandlers.ForEach(x => + { + var packetHandler = + ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider, + x.Value); + _packetHandlersInstantiation.TryAdd(x.Key, packetHandler as IPacketHandler); + }); + foreach (var packetsType in packetDictionary) + { + var lambda = CodeGenerator.Lambda>(fun => + { + var argPacketData = fun[0]; + var newPacket = packetsType.Value.New(); + + var packetVariable = CodeGenerator.DeclareVariable(packetsType.Value, "packet"); + CodeGenerator.Assign(packetVariable, newPacket); + CodeGenerator.Call(packetVariable, nameof(IIncomingPacket.Deserialize), argPacketData); + + CodeGenerator.Return(packetVariable); + }).Compile(); + tempDeserializationMap.Add(packetsType.Key, lambda); + } + + _deserializationMap = tempDeserializationMap.ToImmutableDictionary(); + } + + private Dictionary GetAllPackets() + { + // ! : item.Attribute cannot be null due to previous Where check + var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes() + .Where(type => type is { IsInterface: false, IsAbstract: false } && + type.GetInterfaces().Contains(typeof(IIncomingPacket))) + .Select(type => + new { Type = type, Attribute = type.GetCustomAttribute>() }) + .Where(item => item.Attribute is not null) + .ToDictionary(item => item.Attribute!.Code, item => item.Type)).ToDictionary(); + + return packetsWithId; + } + + private Dictionary GetAllPacketHandlersWithId() + { + var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes() + .Where(t => + t is { IsClass: true, IsAbstract: false } && Array.Exists(t + .GetInterfaces(), i => + i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IPacketHandler))) + .Select(type => new + { + Type = type, + PacketId = type + .GetInterfaces().First(t1 => + t1 is { IsGenericType: true } && + t1.GetGenericTypeDefinition() == typeof(IPacketHandler)) + .GetGenericArguments().First(t => t.GetCustomAttributes>().Any()) + .GetCustomAttributes>().First().Code + }) + .ToDictionary( + x => x.PacketId, x => x.Type + )).ToDictionary(); + + return packetHandlersWithId; + } + + public async Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session) + { + await _channel.Writer.WriteAsync((packetData, operationCode, session)); + } + + public async Task DequeueRawPacketAsync(CancellationToken cancellationToken) + { + while (await _channel.Reader.WaitToReadAsync(cancellationToken)) + { + while (_channel.Reader.TryRead(out var item)) + { + await InvokePacketHandlerAsync(item, cancellationToken); + } + } + } + + private async Task InvokePacketHandlerAsync((byte[], TPacketIdEnum, TSession) valueTuple, + CancellationToken cancellationToken) + { + var (packetData, operationCode, session) = valueTuple; + if (!_deserializationMap.TryGetValue(operationCode, out var func)) + { + return; + } + + var packet = func(packetData); + + // ! I don't see how it's possibly null here. + await _packetHandlersInstantiation[operationCode]?.TryHandleAsync(packet, session, cancellationToken)!; + } +} diff --git a/Rai.PacketMediator/PacketDistributorService.cs b/Rai.PacketMediator/PacketDistributorService.cs index 8a87d95..e7f2a48 100644 --- a/Rai.PacketMediator/PacketDistributorService.cs +++ b/Rai.PacketMediator/PacketDistributorService.cs @@ -1,158 +1,28 @@ // Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License. -using System.Collections.Concurrent; -using System.Collections.Immutable; using System.Reflection; -using System.Threading.Channels; -using DotNext.Collections.Generic; -using DotNext.Linq.Expressions; -using DotNext.Metaprogramming; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.VisualBasic.CompilerServices; namespace Rai.PacketMediator; -using static CodeGenerator; - public class PacketDistributorService : Microsoft.Extensions.Hosting.IHostedService where TPacketIdEnum : Enum { - private readonly Channel> _channel; - - private readonly IServiceProvider _serviceProvider; - private readonly Assembly[] _sourcesContainingPackets; - - private ImmutableDictionary> _deserializationMap; - - private ConcurrentDictionary?> _packetHandlersInstantiation; + private readonly PacketDistributor _packetDistributor; public PacketDistributorService(IServiceProvider serviceProvider, Assembly[] sourcesContainingPackets) { - _channel = Channel.CreateUnbounded>(new UnboundedChannelOptions - { - AllowSynchronousContinuations = false, - SingleReader = false, - SingleWriter = false - }); - _serviceProvider = serviceProvider; - _sourcesContainingPackets = sourcesContainingPackets; - } - - private Dictionary RetrievePacketsDictionary() - { - var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes() - .Where(type => type is { IsInterface: false, IsAbstract: false } && - type.GetInterfaces().Contains(typeof(IIncomingPacket))) - .Select(type => - new { Type = type, Attribute = type.GetCustomAttribute>() }) - .Where(item => item.Attribute is not null) - .ToDictionary(item => item.Attribute!.Code, item => item.Type)).ToDictionary(); - - return packetsWithId; - } - - private Dictionary GetAllPacketHandlersWithId() - { - var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes() - .Where(t => - t is { IsClass: true, IsAbstract: false } && Array.Exists(t - .GetInterfaces(), i => - i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IPacketHandler))) - .Select(type => new - { - Type = type, - PacketId = type - .GetInterfaces().First(t1 => - t1 is { IsGenericType: true } && - t1.GetGenericTypeDefinition() == typeof(IPacketHandler)) - .GetGenericArguments().First(t => t.GetCustomAttributes>().Any()) - .GetCustomAttributes>().First().Code - }) - .ToDictionary( - x => x.PacketId, x => x.Type - )).ToDictionary(); - - return packetHandlersWithId; + _packetDistributor = new PacketDistributor(serviceProvider, sourcesContainingPackets); } public Task StartAsync(CancellationToken cancellationToken) { - var packetDictionary = RetrievePacketsDictionary(); - - if (packetDictionary is { Count: 0 }) - { - throw new IncompleteInitialization(); - } - - var packetHandlers = GetAllPacketHandlersWithId(); - - if (packetHandlers is { Count: 0 }) - { - throw new IncompleteInitialization(); - } - - var tempDeserializationMap = - new Dictionary>(); - _packetHandlersInstantiation = new ConcurrentDictionary?>(); - packetHandlers.ForEach(x => - { - var packetHandler = - ActivatorUtilities.GetServiceOrCreateInstance(_serviceProvider, - x.Value); - _packetHandlersInstantiation.TryAdd(x.Key, packetHandler as IPacketHandler); - }); - foreach (var packetsType in packetDictionary) - { - var lambda = Lambda>(fun => - { - var argPacketData = fun[0]; - var newPacket = packetsType.Value.New(); - - var packetVariable = DeclareVariable(packetsType.Value, "packet"); - Assign(packetVariable, newPacket); - Call(packetVariable, nameof(IIncomingPacket.Deserialize), argPacketData); - - Return(packetVariable); - }).Compile(); - tempDeserializationMap.Add(packetsType.Key, lambda); - } - - _deserializationMap = tempDeserializationMap.ToImmutableDictionary(); - - _ = this.DequeueRawPacketAsync(); - return Task.CompletedTask; + return _packetDistributor.DequeueRawPacketAsync(cancellationToken); } - public async Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session) + public Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session) { - await _channel.Writer.WriteAsync((packetData, operationCode, session)); - } - - private async Task DequeueRawPacketAsync() - { - while (await _channel.Reader.WaitToReadAsync()) - { - while (_channel.Reader.TryRead(out var item)) - { - await InvokePacketHandlerAsync(item); - } - } - } - - private async Task InvokePacketHandlerAsync((byte[], TPacketIdEnum, TSession) valueTuple) - { - var (packetData, operationCode, session) = valueTuple; - if (!_deserializationMap.TryGetValue(operationCode, out var func)) - { - return; - } - - var packet = func(packetData); - - // ! I don't see how it's possibly null here. - await _packetHandlersInstantiation[operationCode]?.TryHandleAsync(packet, session)!; + return this._packetDistributor.AddPacketAsync(packetData, operationCode, session); } public Task StopAsync(CancellationToken cancellationToken)