diff --git a/Continuity.AuthServer/AuthSession.cs b/Continuity.AuthServer/AuthSession.cs index 164b319..1af7f8c 100644 --- a/Continuity.AuthServer/AuthSession.cs +++ b/Continuity.AuthServer/AuthSession.cs @@ -1,7 +1,6 @@ // Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License. using System.Net.Sockets; -using System.Reflection; using Continuity.AuthServer.Packets; using MassTransit.Mediator; using Microsoft.Extensions.Logging; @@ -14,13 +13,16 @@ namespace Continuity.AuthServer; public class AuthSession : TcpSession { private readonly ILogger _logger; + private readonly PacketDistributorService _distributorService; private readonly IMediator _mediator; public AuthSession(TcpServer - server, IMediator mediator, ILogger logger) : base(server) + server, IMediator mediator, ILogger logger, + PacketDistributorService distributorService) : base(server) { _mediator = mediator; _logger = logger; + _distributorService = distributorService; } public Guid AccountId { get; set; } @@ -33,15 +35,7 @@ public class AuthSession : TcpSession public Task SendAsync(IOutgoingPacket packet) { - var type = packet.GetType(); - _logger.LogInformation("Packet of type {Type} is being serialized", type.Name); - var packetIdAttribute = type.GetCustomAttribute(); - if (packetIdAttribute == null) - { - return Task.CompletedTask; - } - - var opcode = packetIdAttribute.Code; + var opcode = _distributorService.GetOperationCodeByPacketType(packet); Span packetData = packet.Serialize(); var length = (ushort)(packetData.Length + 8); diff --git a/Continuity.AuthServer/PacketHandlers/LoginHandler.cs b/Continuity.AuthServer/PacketHandlers/LoginHandler.cs index a8da16d..f0bef15 100644 --- a/Continuity.AuthServer/PacketHandlers/LoginHandler.cs +++ b/Continuity.AuthServer/PacketHandlers/LoginHandler.cs @@ -75,7 +75,7 @@ public class LoginHandler : IPacketHandler } _logger.LogInformation("LoginResponsePacket: {@LoginResponsePacket}", loginResponsePacket); - _ = session?.SendAsync(loginResponsePacket); + _ = session.SendAsync(loginResponsePacket); } private static async Task GetPasswordHashAsync(string password, byte[] salt, Guid userId) diff --git a/Continuity.AuthServer/Program.cs b/Continuity.AuthServer/Program.cs index a53a11f..3aa856e 100644 --- a/Continuity.AuthServer/Program.cs +++ b/Continuity.AuthServer/Program.cs @@ -110,10 +110,8 @@ builder.Services.AddSingleton(loggerFactory); builder.Services.AddSingleton(provider => new PacketDistributorService( provider.GetRequiredService(), - [ - Assembly.GetAssembly(typeof(LoginHandler)), - Assembly.GetAssembly(typeof(OperationCode)), - ] + new List { Assembly.GetAssembly(typeof(OperationCode)) }.AsReadOnly(), + new List { Assembly.GetAssembly(typeof(LoginHandler)) }.AsReadOnly() )); builder.Services.AddSingleton(); @@ -125,7 +123,8 @@ builder.Services.AddHostedService(provider => provider.GetService() ?? throw new InvalidOperationException()); builder.Services.AddHostedService(provider => - provider.GetService>() ?? throw new InvalidOperationException()); + provider.GetService>() ?? + throw new InvalidOperationException()); builder.Services.AddMassTransit(x => { diff --git a/Rai.PacketMediator/PacketDistributor.cs b/Rai.PacketMediator/PacketDistributor.cs index 3c86da6..d827797 100644 --- a/Rai.PacketMediator/PacketDistributor.cs +++ b/Rai.PacketMediator/PacketDistributor.cs @@ -8,7 +8,6 @@ using DotNext.Collections.Generic; using DotNext.Linq.Expressions; using DotNext.Metaprogramming; using Microsoft.Extensions.DependencyInjection; -using Microsoft.VisualBasic.CompilerServices; namespace Rai.PacketMediator; @@ -16,14 +15,15 @@ public class PacketDistributor where TPacketIdEnum : En { private readonly Channel> _channel; - private readonly Assembly[] _sourcesContainingPackets; private readonly ConcurrentDictionary?> _packetHandlersInstantiation; private readonly ImmutableDictionary> _deserializationMap; + public ImmutableDictionary PacketIdMap { get; } + public PacketDistributor(IServiceProvider serviceProvider, - Assembly[] sourcesContainingPackets) + IEnumerable sourcesContainingPackets, IEnumerable sourcesContainingPacketHandlers) { _channel = Channel.CreateUnbounded>(new UnboundedChannelOptions { @@ -31,20 +31,14 @@ public class PacketDistributor where TPacketIdEnum : En SingleReader = false, SingleWriter = false }); - _sourcesContainingPackets = sourcesContainingPackets; - var packetDictionary = GetAllPackets(); + var containingPackets = sourcesContainingPackets as Assembly[] ?? sourcesContainingPackets.ToArray(); + var allIncomingPackets = GetAllPackets(containingPackets, typeof(IIncomingPacket)); + var allOutgoingPackets = GetAllPackets(containingPackets, typeof(IOutgoingPacket)); - if (packetDictionary is { Count: 0 }) - { - throw new IncompleteInitialization(); - } + var packetHandlers = GetAllPacketHandlersWithId(sourcesContainingPacketHandlers); - var packetHandlers = GetAllPacketHandlersWithId(); - - if (packetHandlers is { Count: 0 }) - { - throw new IncompleteInitialization(); - } + this.PacketIdMap = allOutgoingPackets.Select(x => new { PacketId = x.Key, Type = x.Value }) + .ToImmutableDictionary(x => x.Type, x => x.PacketId); var tempDeserializationMap = new ConcurrentDictionary>(); @@ -56,7 +50,7 @@ public class PacketDistributor where TPacketIdEnum : En packetHandlerPair.Value); _packetHandlersInstantiation.TryAdd(packetHandlerPair.Key, packetHandler as IPacketHandler); }); - packetDictionary.ForEach(packetsType => + allIncomingPackets.ForEach(packetsType => { var lambda = CodeGenerator.Lambda>(fun => { @@ -75,23 +69,25 @@ public class PacketDistributor where TPacketIdEnum : En _deserializationMap = tempDeserializationMap.ToImmutableDictionary(); } - private Dictionary GetAllPackets() + private static IEnumerable> GetAllPackets( + IEnumerable sourcesContainingPackets, Type packetType) { - var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes() + var packetsWithId = sourcesContainingPackets.SelectMany(a => a.GetTypes() .Where(type => type is { IsInterface: false, IsAbstract: false } && - type.GetInterfaces().Contains(typeof(IIncomingPacket)) + type.GetInterfaces().Contains(packetType) && type.GetCustomAttributes>().Any() )) .Select(type => new { Type = type, Attribute = type.GetCustomAttribute>() }) - .ToDictionary(item => item.Attribute!.Code, item => item.Type); + .Select(x => new KeyValuePair(x.Attribute!.Code, x.Type)); return packetsWithId; } - private Dictionary GetAllPacketHandlersWithId() + private static IEnumerable> GetAllPacketHandlersWithId( + IEnumerable sourcesContainingPacketHandlers) { - var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes() + var packetHandlersWithId = sourcesContainingPacketHandlers.SelectMany(assembly => assembly.GetTypes() .Where(t => t is { IsClass: true, IsAbstract: false } && Array.Exists(t .GetInterfaces(), i => @@ -108,7 +104,7 @@ public class PacketDistributor where TPacketIdEnum : En .GetCustomAttribute>() })) .Where(x => x.PacketId != null) - .ToDictionary(x => x.PacketId!.Code, x => x.Type); + .Select(x => new KeyValuePair(x.PacketId!.Code, x.Type)); return packetHandlersWithId; } @@ -118,7 +114,7 @@ public class PacketDistributor where TPacketIdEnum : En await _channel.Writer.WriteAsync((packetData, operationCode, session)); } - public async Task DequeueRawPacketAsync(CancellationToken cancellationToken) + public async Task DequeuePacketAsync(CancellationToken cancellationToken) { while (await _channel.Reader.WaitToReadAsync(cancellationToken)) { diff --git a/Rai.PacketMediator/PacketDistributorService.cs b/Rai.PacketMediator/PacketDistributorService.cs index e7f2a48..bcc60c1 100644 --- a/Rai.PacketMediator/PacketDistributorService.cs +++ b/Rai.PacketMediator/PacketDistributorService.cs @@ -10,14 +10,15 @@ public class PacketDistributorService : Microsoft.Exten private readonly PacketDistributor _packetDistributor; public PacketDistributorService(IServiceProvider serviceProvider, - Assembly[] sourcesContainingPackets) + IEnumerable sourcesContainingPackets, IEnumerable sourcesContainingPacketHandlers) { - _packetDistributor = new PacketDistributor(serviceProvider, sourcesContainingPackets); + _packetDistributor = new PacketDistributor(serviceProvider, sourcesContainingPackets, + sourcesContainingPacketHandlers); } public Task StartAsync(CancellationToken cancellationToken) { - return _packetDistributor.DequeueRawPacketAsync(cancellationToken); + return _packetDistributor.DequeuePacketAsync(cancellationToken); } public Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session) @@ -29,4 +30,16 @@ public class PacketDistributorService : Microsoft.Exten { return Task.CompletedTask; } + + public TPacketIdEnum GetOperationCodeByPacketType(IPacket packet) + { + var type = packet.GetType(); + this._packetDistributor.PacketIdMap.TryGetValue(type, out var value); + if (value is null) + { + throw new ArgumentOutOfRangeException(type.Name); + } + + return value; + } } diff --git a/build-image.ps1 b/build-image.ps1 index 982761d..2b6c1e0 100644 --- a/build-image.ps1 +++ b/build-image.ps1 @@ -1,2 +1,2 @@ #!sh -docker build --platform linux/arm64,linux/amd64 -f Continuity.AuthServer/Dockerfile -t continuity . +docker build --platform linux/arm64,linux/amd64 -f Continuity.AuthServer/Dockerfile -t continuity-auth . diff --git a/docker-compose.yml b/docker-compose.yml index 0e3de18..6c5f870 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ services: server: - container_name: continuity-server - image: continuity:latest + container_name: continuity-auth + image: continuity-auth:latest restart: always depends_on: - db