// Licensed to Timothy Schenk under the Apache 2.0 License.

using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Reflection;
using System.Threading.Channels;
using DotNext.Collections.Generic;
using DotNext.Linq.Expressions;
using DotNext.Metaprogramming;
using Microsoft.Extensions.DependencyInjection;

namespace RaiNote.PacketMediator;

/// <summary>
///     Queues raw incoming packets in a channel, deserializes them by packet id and forwards them to the
///     packet handler registered for that id.
/// </summary>
/// <typeparam name="TPacketIdEnum">Enum type whose values identify packets.</typeparam>
/// <typeparam name="TSession">Session type that is passed through to the packet handlers.</typeparam>
public class PacketDistributor<TPacketIdEnum, TSession> where TPacketIdEnum : Enum
{
    private readonly Channel<ValueTuple<byte[], TPacketIdEnum, TSession>> _channel;

    private readonly ImmutableDictionary<TPacketIdEnum,
        Func<byte[], IIncomingPacket>> _deserializationMap;

    private readonly ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?> _packetHandlersInstantiation;

    /// <summary>
    ///     Scans the given assemblies for packets and packet handlers, builds the packet id lookup tables and
    ///     instantiates one handler per packet id.
    /// </summary>
    /// <param name="serviceProvider">Provider used to resolve or create the packet handler instances.</param>
    /// <param name="sourcesContainingPackets">Assemblies scanned for incoming and outgoing packet types.</param>
    /// <param name="sourcesContainingPacketHandlers">Assemblies scanned for packet handler types.</param>
    /// <exception cref="ArgumentNullException">Thrown when any argument is <c>null</c>.</exception>
    public PacketDistributor(IServiceProvider serviceProvider,
        IEnumerable<Assembly> sourcesContainingPackets, IEnumerable<Assembly> sourcesContainingPacketHandlers)
    {
        ArgumentNullException.ThrowIfNull(serviceProvider);
        ArgumentNullException.ThrowIfNull(sourcesContainingPackets);
        ArgumentNullException.ThrowIfNull(sourcesContainingPacketHandlers);

        _channel = Channel.CreateUnbounded<ValueTuple<byte[], TPacketIdEnum, TSession>>(
            new UnboundedChannelOptions
            {
                AllowSynchronousContinuations = false,
                SingleReader = false,
                SingleWriter = false
            }
        );

        // Materialize once: the packet assemblies are scanned twice (incoming and outgoing packets).
        var containingPackets = sourcesContainingPackets as Assembly[] ?? sourcesContainingPackets.ToArray();
        var allIncomingPackets = GetAllPackets(containingPackets, typeof(IIncomingPacket));
        var allOutgoingPackets = GetAllPackets(containingPackets, typeof(IOutgoingPacket));

        var packetHandlers = GetAllPacketHandlersWithId(sourcesContainingPacketHandlers);

        PacketIdMap = allOutgoingPackets.ToImmutableDictionary(x => x.Value, x => x.Key);

        _packetHandlersInstantiation = new ConcurrentDictionary<TPacketIdEnum, IPacketHandler<TSession>?>();
        foreach (var packetHandlerPair in packetHandlers)
        {
            var packetHandler =
                ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider, packetHandlerPair.Value);

            // Only register instances that can actually be invoked with TSession; storing a null entry
            // would block a later, compatible handler for the same packet id and fail at dispatch time.
            if (packetHandler is IPacketHandler<TSession> typedHandler)
            {
                _packetHandlersInstantiation.TryAdd(packetHandlerPair.Key, typedHandler);
            }
        }

        var deserializers = new Dictionary<TPacketIdEnum, Func<byte[], IIncomingPacket>>();
        foreach (var packetsType in allIncomingPackets)
        {
            // The first packet type found for an id wins; skip compiling a delegate for duplicates.
            if (!deserializers.ContainsKey(packetsType.Key))
            {
                deserializers.Add(packetsType.Key, BuildDeserializer(packetsType.Value));
            }
        }

        _deserializationMap = deserializers.ToImmutableDictionary();
    }

    /// <summary>
    ///     Maps each discovered outgoing packet type to the packet id declared on it.
    /// </summary>
    public ImmutableDictionary<Type, TPacketIdEnum> PacketIdMap { get; }

    /// <summary>
    ///     Compiles a delegate that creates a new instance of <paramref name="packetType" />, calls its
    ///     <c>Deserialize</c> method with the given bytes and returns the instance.
    /// </summary>
    /// <param name="packetType">Concrete incoming packet type to instantiate.</param>
    /// <returns>The compiled deserialization delegate.</returns>
    private static Func<byte[], IIncomingPacket> BuildDeserializer(Type packetType)
    {
        return CodeGenerator.Lambda<Func<byte[], IIncomingPacket>>(fun =>
                {
                    var argPacketData = fun[0];
                    var newPacket = packetType.New();

                    var packetVariable = CodeGenerator.DeclareVariable(packetType, "packet");
                    CodeGenerator.Assign(packetVariable, newPacket);
                    CodeGenerator.Call(packetVariable, nameof(IIncomingPacket.Deserialize), argPacketData);

                    CodeGenerator.Return(packetVariable);
                }
            )
            .Compile();
    }

    /// <summary>
    ///     Finds all concrete types in the given assemblies that implement <paramref name="packetType" /> and
    ///     carry a <see cref="PacketIdAttribute{TPacketIdEnum}" />.
    /// </summary>
    /// <param name="sourcesContainingPackets">Assemblies to scan.</param>
    /// <param name="packetType">Packet interface the types have to implement.</param>
    /// <returns>A lazily evaluated sequence of packet id / packet type pairs.</returns>
    private static IEnumerable<KeyValuePair<TPacketIdEnum, Type>> GetAllPackets(
        IEnumerable<Assembly> sourcesContainingPackets, Type packetType)
    {
        return sourcesContainingPackets
            .SelectMany(assembly => assembly.GetTypes())
            .Where(type => type is { IsInterface: false, IsAbstract: false } &&
                           type.GetInterfaces().Contains(packetType)
            )
            // Read the attribute once and filter on it instead of reflecting twice per type.
            .Select(type =>
                new { Type = type, Attribute = type.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>() }
            )
            .Where(x => x.Attribute is not null)
            .Select(x => new KeyValuePair<TPacketIdEnum, Type>(x.Attribute!.Code, x.Type));
    }

    /// <summary>
    ///     Finds all concrete classes in the given assemblies that implement <c>IPacketHandler&lt;,&gt;</c>
    ///     for a packet type carrying a <see cref="PacketIdAttribute{TPacketIdEnum}" />.
    /// </summary>
    /// <param name="sourcesContainingPacketHandlers">Assemblies to scan.</param>
    /// <returns>A lazily evaluated sequence of packet id / handler type pairs.</returns>
    private static IEnumerable<KeyValuePair<TPacketIdEnum, Type>> GetAllPacketHandlersWithId(
        IEnumerable<Assembly> sourcesContainingPacketHandlers)
    {
        return sourcesContainingPacketHandlers
            .SelectMany(assembly => assembly.GetTypes())
            .Where(type => type is { IsClass: true, IsAbstract: false })
            .Select(type => new { Type = type, PacketId = FindHandledPacketId(type) })
            .Where(x => x.PacketId is not null)
            .Select(x => new KeyValuePair<TPacketIdEnum, Type>(x.PacketId!.Code, x.Type));
    }

    /// <summary>
    ///     Resolves the packet id attribute of the packet type handled by <paramref name="handlerType" />.
    /// </summary>
    /// <param name="handlerType">Candidate packet handler type.</param>
    /// <returns>
    ///     The attribute of the first generic argument of the type's <c>IPacketHandler&lt;,&gt;</c> interface
    ///     that implements <c>IPacket</c>; <c>null</c> when the type does not implement that interface, no
    ///     generic argument implements <c>IPacket</c>, or the packet type has no packet id attribute.
    /// </returns>
    private static PacketIdAttribute<TPacketIdEnum>? FindHandledPacketId(Type handlerType)
    {
        var handlerInterface = Array.Find(handlerType.GetInterfaces(),
            i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IPacketHandler<,>)
        );
        if (handlerInterface is null)
        {
            return null;
        }

        // Array.Find yields null instead of throwing when no generic argument is a packet type.
        var handledPacketType = Array.Find(handlerInterface.GetGenericArguments(),
            genericType => genericType.GetInterfaces().Contains(typeof(IPacket))
        );

        return handledPacketType?.GetCustomAttribute<PacketIdAttribute<TPacketIdEnum>>();
    }

    /// <summary>
    ///     Writes a raw packet to the internal channel so it can be processed by
    ///     <see cref="DequeuePacketAsync" />.
    /// </summary>
    /// <param name="packetData">Serialized packet payload.</param>
    /// <param name="operationCode">Packet id used to select the deserializer and the handler.</param>
    /// <param name="session">Session the packet belongs to.</param>
    public async Task AddPacketAsync(byte[] packetData, TPacketIdEnum operationCode, TSession session)
    {
        await _channel.Writer.WriteAsync((packetData, operationCode, session));
    }

    /// <summary>
    ///     Reads queued packets from the internal channel and dispatches each of them to its handler, one
    ///     after another, until the channel reports that no more data will arrive or
    ///     <paramref name="cancellationToken" /> is cancelled.
    /// </summary>
    /// <param name="cancellationToken">Token observed while waiting for packets and passed to the handlers.</param>
    public async Task DequeuePacketAsync(CancellationToken cancellationToken)
    {
        while (await _channel.Reader.WaitToReadAsync(cancellationToken))
        {
            // Drain everything that is currently available before waiting again.
            while (_channel.Reader.TryRead(out var item))
            {
                await InvokePacketHandlerAsync(item, cancellationToken);
            }
        }
    }

    /// <summary>
    ///     Deserializes a queued packet and passes it to the handler registered for its packet id.
    ///     Packets without a known deserializer or without a registered handler are dropped.
    /// </summary>
    /// <param name="valueTuple">Raw packet data, packet id and session as queued in the channel.</param>
    /// <param name="cancellationToken">Token passed to the handler.</param>
    private async Task InvokePacketHandlerAsync((byte[], TPacketIdEnum, TSession) valueTuple,
        CancellationToken cancellationToken)
    {
        var (packetData, operationCode, session) = valueTuple;
        if (!_deserializationMap.TryGetValue(operationCode, out var deserialize))
        {
            return;
        }

        // Look the handler up defensively: the indexer throws for unknown ids, and awaiting the result of
        // a null-conditional call on a missing handler would throw a NullReferenceException.
        if (!_packetHandlersInstantiation.TryGetValue(operationCode, out var packetHandler) ||
            packetHandler is null)
        {
            return;
        }

        var packet = deserialize(packetData);

        await packetHandler.TryHandleAsync(packet, session, cancellationToken);
    }
}