feat: cancellation token support

This commit is contained in:
Timothy Schenk 2024-02-05 12:08:00 +01:00
parent b9144a3ff0
commit 23aef12502
Signed by: rainote
SSH key fingerprint: SHA256:pnkNSDwpAnaip00xaZlVFHKKsS7T8UtOomMzvs0yITE
8 changed files with 175 additions and 155 deletions

View file

@ -21,7 +21,8 @@ public partial class ChannelSelectionHandler : IPacketHandler<ChannelSelectionPa
_wonderkingContext = wonderkingContext;
}
public async Task HandleAsync(ChannelSelectionPacket packet, TcpSession session)
public async Task HandleAsync(ChannelSelectionPacket packet, TcpSession session,
CancellationToken cancellationToken)
{
if (session is not AuthSession authSession)
{
@ -32,11 +33,11 @@ public partial class ChannelSelectionHandler : IPacketHandler<ChannelSelectionPa
var guildNameResponsePacket = new CharacterSelectionSetGuildNamePacket { GuildNames = Array.Empty<string>() };
var accountExists =
await _wonderkingContext.Accounts.AsNoTracking().AnyAsync(a => a.Id == authSession.AccountId);
await _wonderkingContext.Accounts.AsNoTracking().AnyAsync(a => a.Id == authSession.AccountId, cancellationToken: cancellationToken);
var amountOfCharacter = await _wonderkingContext.Characters.AsNoTracking().Include(c => c.Account)
.Where(c => c.Account.Id == authSession.AccountId).Take(3)
.CountAsync();
.CountAsync(cancellationToken: cancellationToken);
if (!accountExists)
{
@ -50,14 +51,14 @@ public partial class ChannelSelectionHandler : IPacketHandler<ChannelSelectionPa
ChannelIsFullFlag = 0,
Endpoint = "127.0.0.1",
Port = 2000,
Characters = await GetCharacterDataAsync(authSession.AccountId).ToArrayAsync()
Characters = await GetCharacterDataAsync(authSession.AccountId).ToArrayAsync(token: cancellationToken)
};
guildNameResponsePacket.GuildNames =
await _wonderkingContext.Characters.AsNoTracking().Include(c => c.Account).Include(c => c.GuildMember)
.ThenInclude(gm => gm.Guild)
.Where(c => c.Account.Id == authSession.AccountId && c.GuildMember.Guild != null)
.Select(c => c.GuildMember.Guild.Name).Take(3).ToArrayAsync();
.Select(c => c.GuildMember.Guild.Name).Take(3).ToArrayAsync(cancellationToken: cancellationToken);
}
else
{

View file

@ -28,14 +28,15 @@ public class CharacterCreationHandler : IPacketHandler<CharacterCreationPacket,
_characterStatsMapping = characterStatsMappingConfiguration;
}
public async Task HandleAsync(CharacterCreationPacket packet, TcpSession session)
public async Task HandleAsync(CharacterCreationPacket packet, TcpSession session,
CancellationToken cancellationToken)
{
if (session is not AuthSession authSession)
{
return;
}
var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Id == authSession.AccountId);
var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Id == authSession.AccountId, cancellationToken: cancellationToken);
if (account is null)
{
@ -47,7 +48,7 @@ public class CharacterCreationHandler : IPacketHandler<CharacterCreationPacket,
var toBeAddedCharacter = CreateDefaultCharacter(packet, account, items, firstJobConfig);
account.Characters.Add(toBeAddedCharacter);
await _wonderkingContext.SaveChangesAsync();
await _wonderkingContext.SaveChangesAsync(cancellationToken);
var character =
new CharacterData

View file

@ -18,7 +18,7 @@ public class CharacterDeletionHandler : IPacketHandler<CharacterDeletePacket, Tc
_wonderkingContext = wonderkingContext;
}
public async Task HandleAsync(CharacterDeletePacket packet, TcpSession session)
public async Task HandleAsync(CharacterDeletePacket packet, TcpSession session, CancellationToken cancellationToken)
{
if (session is not AuthSession authSession)
{
@ -27,7 +27,7 @@ public class CharacterDeletionHandler : IPacketHandler<CharacterDeletePacket, Tc
}
var character = await _wonderkingContext.Characters.FirstOrDefaultAsync(x => x.Name == packet.Name &&
x.Account.Id == authSession.AccountId);
x.Account.Id == authSession.AccountId, cancellationToken: cancellationToken);
var response = new CharacterDeleteResponsePacket { HasToBeZero = 0 };
if (character == null)
@ -37,7 +37,7 @@ public class CharacterDeletionHandler : IPacketHandler<CharacterDeletePacket, Tc
}
_wonderkingContext.Characters.Remove(character);
await _wonderkingContext.SaveChangesAsync();
await _wonderkingContext.SaveChangesAsync(cancellationToken);
await authSession.SendAsync(response);
}

View file

@ -18,9 +18,10 @@ public class CharacterNameCheckHandler : IPacketHandler<CharacterNameCheckPacket
_wonderkingContext = wonderkingContext;
}
public async Task HandleAsync(CharacterNameCheckPacket packet, TcpSession session)
public async Task HandleAsync(CharacterNameCheckPacket packet, TcpSession session,
CancellationToken cancellationToken)
{
var isTaken = await _wonderkingContext.Characters.AnyAsync(c => c.Name == packet.Name);
var isTaken = await _wonderkingContext.Characters.AnyAsync(c => c.Name == packet.Name, cancellationToken: cancellationToken);
var responsePacket = new CharacterNameCheckPacketResponse { IsTaken = isTaken };
if (session is AuthSession authSession)
{

View file

@ -32,18 +32,18 @@ public class LoginHandler : IPacketHandler<LoginInfoPacket, TcpSession>
_configuration = configuration;
}
public async Task HandleAsync(LoginInfoPacket packet, TcpSession session)
public async Task HandleAsync(LoginInfoPacket packet, TcpSession session, CancellationToken cancellationToken)
{
LoginResponseReason loginResponseReason;
_logger.LoginData(packet.Username, packet.Password);
var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username);
var account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username, cancellationToken: cancellationToken);
if (account == null)
{
if (_configuration.GetSection("Testing").GetValue<bool>("CreateAccountOnLogin"))
{
loginResponseReason = await CreateAccountOnLoginAsync(packet.Username, packet.Password);
account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username);
account = await _wonderkingContext.Accounts.FirstOrDefaultAsync(a => a.Username == packet.Username, cancellationToken: cancellationToken);
}
else
{

View file

@ -9,9 +9,10 @@ namespace Rai.PacketMediator;
public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandler<TSession> where TIncomingPacket : IIncomingPacket
{
[UsedImplicitly(ImplicitUseTargetFlags.WithInheritors)]
public Task HandleAsync(TIncomingPacket packet, TSession session);
public Task HandleAsync(TIncomingPacket packet, TSession session, CancellationToken cancellationToken);
async Task<bool> IPacketHandler<TSession>.TryHandleAsync(IIncomingPacket packet, TSession session)
async Task<bool> IPacketHandler<TSession>.TryHandleAsync(IIncomingPacket packet, TSession session,
CancellationToken cancellationToken)
{
if (packet is not TIncomingPacket tPacket)
{
@ -21,7 +22,7 @@ public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandle
using var activity = new ActivitySource(nameof(PacketMediator)).StartActivity(nameof(HandleAsync));
activity?.AddTag("Handler", this.ToString());
activity?.AddTag("Packet", packet.ToString());
await HandleAsync(tPacket, session);
await HandleAsync(tPacket, session, cancellationToken);
return true;
}
@ -29,5 +30,5 @@ public interface IPacketHandler<in TIncomingPacket, in TSession> : IPacketHandle
public interface IPacketHandler<in TSession>
{
Task<bool> TryHandleAsync(IIncomingPacket packet, TSession session);
Task<bool> TryHandleAsync(IIncomingPacket packet, TSession session, CancellationToken cancellationToken);
}

View file

@ -0,0 +1,146 @@
// Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License.
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Reflection;
using System.Threading.Channels;
using DotNext.Collections.Generic;
using DotNext.Linq.Expressions;
using DotNext.Metaprogramming;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualBasic.CompilerServices;
namespace Rai.PacketMediator;
public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : Enum
{
private readonly Channel<ValueTuple<byte[], TPacketIdEnum, TSession>> _channel;
private readonly Assembly[] _sourcesContainingPackets;
private readonly ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?> _packetHandlersInstantiation;
private readonly ImmutableDictionary<TPacketIdEnum,
Func<byte[], IIncomingPacket>> _deserializationMap;
public PacketDistributor(IServiceProvider serviceProvider,
Assembly[] sourcesContainingPackets)
{
_channel = Channel.CreateUnbounded<ValueTuple<byte[], TPacketIdEnum, TSession>>(new UnboundedChannelOptions
{
AllowSynchronousContinuations = false,
SingleReader = false,
SingleWriter = false
});
_sourcesContainingPackets = sourcesContainingPackets;
var packetDictionary = GetAllPackets();
if (packetDictionary is { Count: 0 })
{
throw new IncompleteInitialization();
}
var packetHandlers = GetAllPacketHandlersWithId();
if (packetHandlers is { Count: 0 })
{
throw new IncompleteInitialization();
}
var tempDeserializationMap =
new Dictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>();
_packetHandlersInstantiation = new ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?>();
packetHandlers.ForEach(x =>
{
var packetHandler =
ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider,
x.Value);
_packetHandlersInstantiation.TryAdd(x.Key, packetHandler as IPacketHandler<TSession>);
});
foreach (var packetsType in packetDictionary)
{
var lambda = CodeGenerator.Lambda<Func<byte[], IIncomingPacket>>(fun =>
{
var argPacketData = fun[0];
var newPacket = packetsType.Value.New();
var packetVariable = CodeGenerator.DeclareVariable(packetsType.Value, "packet");
CodeGenerator.Assign(packetVariable, newPacket);
CodeGenerator.Call(packetVariable, nameof(IIncomingPacket.Deserialize), argPacketData);
CodeGenerator.Return(packetVariable);
}).Compile();
tempDeserializationMap.Add(packetsType.Key, lambda);
}
_deserializationMap = tempDeserializationMap.ToImmutableDictionary();
}
private Dictionary<TPacketIdEnum, Type> GetAllPackets()
{
// ! : item.Attribute cannot be null due to previous Where check
var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes()
.Where(type => type is { IsInterface: false, IsAbstract: false } &&
type.GetInterfaces().Contains(typeof(IIncomingPacket)))
.Select(type =>
new { Type = type, Attribute = type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() })
.Where(item => item.Attribute is not null)
.ToDictionary(item => item.Attribute!.Code, item => item.Type)).ToDictionary();
return packetsWithId;
}
private Dictionary<TPacketIdEnum, Type> GetAllPacketHandlersWithId()
{
var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes()
.Where(t =>
t is { IsClass: true, IsAbstract: false } && Array.Exists(t
.GetInterfaces(), i =>
i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IPacketHandler<TSession>)))
.Select(type => new
{
Type = type,
PacketId = type
.GetInterfaces().First(t1 =>
t1 is { IsGenericType: true } &&
t1.GetGenericTypeDefinition() == typeof(IPacketHandler<TSession>))
.GetGenericArguments().First(t => t.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().Any())
.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().First().Code
})
.ToDictionary(
x => x.PacketId, x => x.Type
)).ToDictionary();
return packetHandlersWithId;
}
public async Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
{
await _channel.Writer.WriteAsync((packetData, operationCode, session));
}
public async Task DequeueRawPacketAsync(CancellationToken cancellationToken)
{
while (await _channel.Reader.WaitToReadAsync(cancellationToken))
{
while (_channel.Reader.TryRead(out var item))
{
await InvokePacketHandlerAsync(item, cancellationToken);
}
}
}
private async Task InvokePacketHandlerAsync((byte[], TPacketIdEnum, TSession) valueTuple,
CancellationToken cancellationToken)
{
var (packetData, operationCode, session) = valueTuple;
if (!_deserializationMap.TryGetValue(operationCode, out var func))
{
return;
}
var packet = func(packetData);
// ! I don't see how it's possibly null here.
await _packetHandlersInstantiation[operationCode]?.TryHandleAsync(packet, session, cancellationToken)!;
}
}

View file

@ -1,158 +1,28 @@
// Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License.
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Reflection;
using System.Threading.Channels;
using DotNext.Collections.Generic;
using DotNext.Linq.Expressions;
using DotNext.Metaprogramming;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualBasic.CompilerServices;
namespace Rai.PacketMediator;
using static CodeGenerator;
public class PacketDistributorService<TPacketIdEnum, TSession> : Microsoft.Extensions.Hosting.IHostedService
where TPacketIdEnum : Enum
{
private readonly Channel<ValueTuple<byte[], TPacketIdEnum, TSession>> _channel;
private readonly IServiceProvider _serviceProvider;
private readonly Assembly[] _sourcesContainingPackets;
private ImmutableDictionary<TPacketIdEnum,
Func<byte[], IIncomingPacket>> _deserializationMap;
private ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?> _packetHandlersInstantiation;
private readonly PacketDistributor<TPacketIdEnum, TSession> _packetDistributor;
public PacketDistributorService(IServiceProvider serviceProvider,
Assembly[] sourcesContainingPackets)
{
_channel = Channel.CreateUnbounded<ValueTuple<byte[], TPacketIdEnum, TSession>>(new UnboundedChannelOptions
{
AllowSynchronousContinuations = false,
SingleReader = false,
SingleWriter = false
});
_serviceProvider = serviceProvider;
_sourcesContainingPackets = sourcesContainingPackets;
}
private Dictionary<TPacketIdEnum, Type> RetrievePacketsDictionary()
{
var packetsWithId = this._sourcesContainingPackets.SelectMany(a => a.GetTypes()
.Where(type => type is { IsInterface: false, IsAbstract: false } &&
type.GetInterfaces().Contains(typeof(IIncomingPacket)))
.Select(type =>
new { Type = type, Attribute = type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() })
.Where(item => item.Attribute is not null)
.ToDictionary(item => item.Attribute!.Code, item => item.Type)).ToDictionary();
return packetsWithId;
}
private Dictionary<TPacketIdEnum, Type> GetAllPacketHandlersWithId()
{
var packetHandlersWithId = this._sourcesContainingPackets.SelectMany(assembly => assembly.GetTypes()
.Where(t =>
t is { IsClass: true, IsAbstract: false } && Array.Exists(t
.GetInterfaces(), i =>
i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IPacketHandler<TSession>)))
.Select(type => new
{
Type = type,
PacketId = type
.GetInterfaces().First(t1 =>
t1 is { IsGenericType: true } &&
t1.GetGenericTypeDefinition() == typeof(IPacketHandler<TSession>))
.GetGenericArguments().First(t => t.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().Any())
.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>().First().Code
})
.ToDictionary(
x => x.PacketId, x => x.Type
)).ToDictionary();
return packetHandlersWithId;
_packetDistributor = new PacketDistributor<TPacketIdEnum, TSession>(serviceProvider, sourcesContainingPackets);
}
public Task StartAsync(CancellationToken cancellationToken)
{
var packetDictionary = RetrievePacketsDictionary();
if (packetDictionary is { Count: 0 })
{
throw new IncompleteInitialization();
}
var packetHandlers = GetAllPacketHandlersWithId();
if (packetHandlers is { Count: 0 })
{
throw new IncompleteInitialization();
}
var tempDeserializationMap =
new Dictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>();
_packetHandlersInstantiation = new ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?>();
packetHandlers.ForEach(x =>
{
var packetHandler =
ActivatorUtilities.GetServiceOrCreateInstance(_serviceProvider,
x.Value);
_packetHandlersInstantiation.TryAdd(x.Key, packetHandler as IPacketHandler<TSession>);
});
foreach (var packetsType in packetDictionary)
{
var lambda = Lambda<Func<byte[], IIncomingPacket>>(fun =>
{
var argPacketData = fun[0];
var newPacket = packetsType.Value.New();
var packetVariable = DeclareVariable(packetsType.Value, "packet");
Assign(packetVariable, newPacket);
Call(packetVariable, nameof(IIncomingPacket.Deserialize), argPacketData);
Return(packetVariable);
}).Compile();
tempDeserializationMap.Add(packetsType.Key, lambda);
}
_deserializationMap = tempDeserializationMap.ToImmutableDictionary();
_ = this.DequeueRawPacketAsync();
return Task.CompletedTask;
return _packetDistributor.DequeueRawPacketAsync(cancellationToken);
}
public async Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
public Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
{
await _channel.Writer.WriteAsync((packetData, operationCode, session));
}
private async Task DequeueRawPacketAsync()
{
while (await _channel.Reader.WaitToReadAsync())
{
while (_channel.Reader.TryRead(out var item))
{
await InvokePacketHandlerAsync(item);
}
}
}
private async Task InvokePacketHandlerAsync((byte[], TPacketIdEnum, TSession) valueTuple)
{
var (packetData, operationCode, session) = valueTuple;
if (!_deserializationMap.TryGetValue(operationCode, out var func))
{
return;
}
var packet = func(packetData);
// ! I don't see how it's possibly null here.
await _packetHandlersInstantiation[operationCode]?.TryHandleAsync(packet, session)!;
return this._packetDistributor.AddPacketAsync(packetData, operationCode, session);
}
public Task StopAsync(CancellationToken cancellationToken)