// Copyright (c) 2023 Timothy Schenk. Subject to the GNU AGPL Version 3 License.

using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Reflection;
using System.Threading.Channels;
using DotNext.Collections.Generic;
using DotNext.Linq.Expressions;
using DotNext.Metaprogramming;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualBasic.CompilerServices;

namespace Rai.PacketMediator;

using static CodeGenerator;

/// <summary>
/// Hosted service that discovers incoming packet types and packet handlers in the supplied assemblies,
/// queues raw packets on a channel, and dispatches each deserialized packet to the handler registered
/// for its packet id. Packets are taken off the channel by a single loop and handled one at a time.
/// </summary>
/// <typeparam name="TPacketIdEnum">Enum whose values identify packet types.</typeparam>
/// <typeparam name="TSession">Session object passed to packet handlers together with each packet.</typeparam>
public class PacketDistributorService<TPacketIdEnum, TSession> : Microsoft.Extensions.Hosting.IHostedService
    where TPacketIdEnum : Enum
{
    private readonly Channel<(byte[] PacketData, TPacketIdEnum OperationCode, TSession Session)> _channel;
    private readonly IServiceProvider _serviceProvider;
    private readonly Assembly[] _sourcesContainingPackets;

    // Both maps are replaced wholesale in StartAsync; they start empty so nothing dereferences null before then.
    private ImmutableDictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>> _deserializationMap =
        ImmutableDictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>.Empty;

    private ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>> _packetHandlersInstantiation = new();

    // Signals the dequeue loop to stop; created in StartAsync and cancelled in StopAsync.
    private CancellationTokenSource? _stoppingCts;

    // The running dequeue loop, awaited by StopAsync so shutdown waits for the packet in flight.
    private Task? _dequeueTask;

    /// <summary>
    /// Creates the distributor. Discovery and handler instantiation are deferred to <see cref="StartAsync"/>.
    /// </summary>
    /// <param name="serviceProvider">Used to resolve or construct packet handler instances.</param>
    /// <param name="sourcesContainingPackets">Assemblies scanned for packet types and packet handlers.</param>
    /// <exception cref="ArgumentNullException">Either argument is <c>null</c>.</exception>
    public PacketDistributorService(IServiceProvider serviceProvider, Assembly[] sourcesContainingPackets)
    {
        ArgumentNullException.ThrowIfNull(serviceProvider);
        ArgumentNullException.ThrowIfNull(sourcesContainingPackets);

        _channel = Channel.CreateUnbounded<(byte[] PacketData, TPacketIdEnum OperationCode, TSession Session)>(
            new UnboundedChannelOptions
            {
                AllowSynchronousContinuations = false,
                SingleReader = false,
                SingleWriter = false
            });
        _serviceProvider = serviceProvider;
        _sourcesContainingPackets = sourcesContainingPackets;
    }

    /// <summary>
    /// Finds every concrete <see cref="IIncomingPacket"/> type carrying a
    /// <c>PacketIdAttribute&lt;TPacketIdEnum&gt;</c> and maps its packet id to the type.
    /// </summary>
    /// <exception cref="ArgumentException">Two packet types declare the same packet id.</exception>
    private Dictionary<TPacketIdEnum, Type> RetrievePacketsDictionary()
    {
        return _sourcesContainingPackets
            .SelectMany(assembly => assembly.GetTypes())
            .Where(type => type is { IsInterface: false, IsAbstract: false }
                           && type.GetInterfaces().Contains(typeof(IIncomingPacket)))
            .Select(type => (Type: type, Attribute: type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>()))
            .Where(item => item.Attribute is not null)
            .ToDictionary(item => item.Attribute!.Code, item => item.Type);
    }

    /// <summary>
    /// Finds every concrete class implementing <c>IPacketHandler&lt;,&gt;</c> and maps the packet id of the
    /// handled packet type (the generic argument carrying the packet id attribute) to the handler type.
    /// </summary>
    /// <exception cref="ArgumentException">Two handlers handle the same packet id.</exception>
    /// <exception cref="InvalidOperationException">A handler's generic arguments carry no packet id attribute.</exception>
    private Dictionary<TPacketIdEnum, Type> GetAllPacketHandlersWithId()
    {
        return _sourcesContainingPackets
            .SelectMany(assembly => assembly.GetTypes())
            .Where(type => type is { IsClass: true, IsAbstract: false })
            .Select(type => (Type: type, HandlerInterface: Array.Find(type.GetInterfaces(), IsPacketHandlerInterface)))
            .Where(item => item.HandlerInterface is not null)
            .ToDictionary(item => GetHandledPacketId(item.HandlerInterface!), item => item.Type);
    }

    // True for closed constructions of the two-argument packet handler interface.
    private static bool IsPacketHandlerInterface(Type candidate) =>
        candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IPacketHandler<,>);

    // Packet id of the first generic argument of the handler interface that carries a packet id attribute.
    private static TPacketIdEnum GetHandledPacketId(Type handlerInterface) =>
        handlerInterface.GetGenericArguments()
            .SelectMany(argument => argument.GetCustomAttributes<PacketIdAttribute<TPacketIdEnum>>())
            .First()
            .Code;

    /// <summary>
    /// Discovers packets and handlers, instantiates the handlers, compiles one deserializer per packet
    /// type and starts the background dequeue loop.
    /// </summary>
    /// <param name="cancellationToken">Not observed; startup does no asynchronous work.</param>
    /// <exception cref="IncompleteInitialization">No packet types, or no packet handlers, were found.</exception>
    public Task StartAsync(CancellationToken cancellationToken)
    {
        var packetDictionary = RetrievePacketsDictionary();
        if (packetDictionary.Count == 0)
        {
            throw new IncompleteInitialization();
        }

        var packetHandlers = GetAllPacketHandlersWithId();
        if (packetHandlers.Count == 0)
        {
            throw new IncompleteInitialization();
        }

        _packetHandlersInstantiation = InstantiatePacketHandlers(packetHandlers);
        _deserializationMap = packetDictionary.ToImmutableDictionary(
            entry => entry.Key,
            entry => BuildDeserializer(entry.Value));

        _stoppingCts = new CancellationTokenSource();
        var stoppingToken = _stoppingCts.Token;

        // Run the loop on the thread pool so packets queued before start are not handled inside StartAsync.
        _dequeueTask = Task.Run(() => DequeueRawPacketAsync(stoppingToken), CancellationToken.None);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Resolves (or constructs) one instance per handler type. Instances that do not implement
    /// <c>IPacketHandler&lt;TSession&gt;</c> cannot be dispatched to and are not registered.
    /// </summary>
    private ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>> InstantiatePacketHandlers(
        Dictionary<TPacketIdEnum, Type> packetHandlers)
    {
        var instances = new ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>>();
        foreach (var (packetId, handlerType) in packetHandlers)
        {
            var instance = ActivatorUtilities.GetServiceOrCreateInstance(_serviceProvider, handlerType);
            if (instance is IPacketHandler<TSession> handler)
            {
                instances.TryAdd(packetId, handler);
            }
        }

        return instances;
    }

    /// <summary>
    /// Compiles a delegate that creates a new <paramref name="packetType"/> via its parameterless
    /// constructor, calls <c>Deserialize</c> on it with the raw bytes, and returns it.
    /// </summary>
    private static Func<byte[], IIncomingPacket> BuildDeserializer(Type packetType)
    {
        return Lambda<Func<byte[], IIncomingPacket>>(fun =>
        {
            var argPacketData = fun[0];
            var newPacket = packetType.New();
            var packetVariable = DeclareVariable(packetType, "packet");
            Assign(packetVariable, newPacket);
            Call(packetVariable, nameof(IIncomingPacket.Deserialize), argPacketData);
            Return(packetVariable);
        }).Compile();
    }

    /// <summary>
    /// Queues a raw packet for dispatch. Completes once the packet is on the channel; handling happens
    /// later on the dequeue loop.
    /// </summary>
    /// <param name="packetData">Raw packet payload passed to the packet's deserializer.</param>
    /// <param name="operationCode">Packet id used to select deserializer and handler.</param>
    /// <param name="session">Session forwarded to the handler.</param>
    public async Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
    {
        await _channel.Writer.WriteAsync((packetData, operationCode, session)).ConfigureAwait(false);
    }

    /// <summary>
    /// Reads queued packets and handles them sequentially until <paramref name="cancellationToken"/> is
    /// cancelled. A failure while handling one packet is traced and does not stop the loop.
    /// </summary>
    private async Task DequeueRawPacketAsync(CancellationToken cancellationToken)
    {
        try
        {
            while (await _channel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
            {
                // Check the token before TryRead so a stop request never dequeues and drops a packet.
                while (!cancellationToken.IsCancellationRequested && _channel.Reader.TryRead(out var item))
                {
                    try
                    {
                        await InvokePacketHandlerAsync(item).ConfigureAwait(false);
                    }
                    catch (Exception ex)
                    {
                        // This loop is the only consumer: letting one faulty packet or handler escape
                        // would silently end dispatching for every packet that follows.
                        Trace.TraceError($"Failed to handle packet {item.OperationCode}: {ex}");
                    }
                }
            }
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // Normal shutdown requested by StopAsync.
        }
    }

    /// <summary>
    /// Deserializes one queued packet and passes it to the handler registered for its packet id.
    /// Packets with no deserializer or no registered handler are ignored.
    /// </summary>
    private async Task InvokePacketHandlerAsync((byte[], TPacketIdEnum, TSession) valueTuple)
    {
        var (packetData, operationCode, session) = valueTuple;

        // The packet map and the handler map are discovered independently, so an id may be present in
        // one and missing from the other.
        if (!_deserializationMap.TryGetValue(operationCode, out var deserialize)
            || !_packetHandlersInstantiation.TryGetValue(operationCode, out var handler))
        {
            return;
        }

        var packet = deserialize(packetData);
        await handler.TryHandleAsync(packet, session).ConfigureAwait(false);
    }

    /// <summary>
    /// Stops the dequeue loop and waits for the packet currently being handled to finish, or until
    /// <paramref name="cancellationToken"/> (the host's shutdown timeout) is cancelled.
    /// </summary>
    public async Task StopAsync(CancellationToken cancellationToken)
    {
        if (_stoppingCts is null || _dequeueTask is null)
        {
            return;
        }

        _stoppingCts.Cancel();
        await Task.WhenAny(_dequeueTask, Task.Delay(Timeout.Infinite, cancellationToken)).ConfigureAwait(false);
    }
}