using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using System.Data.Entity.Core.Common.EntitySql;
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.ServiceModel.Channels;

namespace EVCB_OCPP.WSServer.Service.WsService;

/// <summary>
/// Accepts incoming WebSocket upgrade requests and binds each accepted connection to a session
/// object of type <typeparamref name="T"/> resolved from the service provider.
/// </summary>
/// <typeparam name="T">
/// Session type, resolved via <c>GetRequiredService</c> once per connection (whether that yields a
/// fresh instance depends on how it is registered in DI).
/// </typeparam>
public class WebsocketService<T> where T : WsSession
{
    private readonly IServiceProvider serviceProvider;
    private readonly ILogger logger;

    /// <summary>Creates the service.</summary>
    /// <param name="serviceProvider">Used to resolve a <typeparamref name="T"/> for every connection.</param>
    /// <param name="logger">Receives handshake and client-endpoint diagnostics.</param>
    public WebsocketService(IServiceProvider serviceProvider, ILogger logger)
    {
        this.serviceProvider = serviceProvider;
        this.logger = logger;
    }

    /// <summary>
    /// Raised after a WebSocket has been accepted; the session's <c>ClientWebSocket</c> is already
    /// set when handlers run.
    /// </summary>
    public event EventHandler<T> NewSessionConnected;

    /// <summary>
    /// Handles one HTTP request. If it is a WebSocket upgrade, a sub-protocol is selected and the
    /// handshake validates, the socket is accepted. The method then waits until the session signals
    /// that the connection has ended.
    /// </summary>
    /// <remarks>
    /// Rejected requests simply return; this method writes no error response itself.
    /// NOTE(review): presumably overrides of the validation hooks set any status code — confirm.
    /// </remarks>
    /// <param name="context">The incoming HTTP request.</param>
    public async Task AcceptWebSocket(HttpContext context)
    {
        if (!context.WebSockets.IsWebSocketRequest)
        {
            return;
        }

        var protocol = await ValidateSupportedPortocol(context);
        if (string.IsNullOrEmpty(protocol))
        {
            return;
        }

        T data = GetSession(context);

        if (!await ValidateHandshake(context, data))
        {
            return;
        }

        // Disposed when this method returns, which happens only after AddWebSocket stops waiting
        // for the session to end.
        using WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(protocol);
        LogHandshakeResponse(context);

        await AddWebSocket(webSocket, data);
    }

    /// <summary>
    /// Hook for derived classes to accept or reject the handshake for a prepared session.
    /// The base implementation accepts every request.
    /// </summary>
    /// <returns><c>true</c> to accept the WebSocket; <c>false</c> to abort.</returns>
    internal virtual ValueTask<bool> ValidateHandshake(HttpContext context, T data)
    {
        return ValueTask.FromResult(true);
    }

    /// <summary>
    /// Hook for derived classes to choose the WebSocket sub-protocol to accept.
    /// The base implementation returns an empty string, so every request is rejected unless
    /// this is overridden.
    /// </summary>
    /// <returns>The sub-protocol to accept, or null/empty to reject the request.</returns>
    internal virtual ValueTask<string> ValidateSupportedPortocol(HttpContext context)
    {
        return ValueTask.FromResult(string.Empty);
    }

    /// <summary>
    /// Attaches the accepted socket to the session and raises <see cref="NewSessionConnected"/>.
    /// It then waits on the session's <c>EndConnSemaphore</c>, presumably released by the session
    /// when the connection closes. Waiting keeps the request, and the caller's <c>using</c> scope
    /// around the socket, alive for the connection's lifetime.
    /// </summary>
    private async Task AddWebSocket(WebSocket webSocket, T data)
    {
        data.ClientWebSocket = webSocket;

        NewSessionConnected?.Invoke(this, data);
        await data.EndConnSemaphore.WaitAsync();
    }

    /// <summary>
    /// Resolves a session for <paramref name="context"/> and fills in its path, session id
    /// (the request's TraceIdentifier), URI scheme and client endpoint.
    /// </summary>
    /// <remarks>
    /// The endpoint is the first valid IPv4/IPv6 address in X-Forwarded-For, falling back to the
    /// connection's remote address. X-Forwarded-For is client-controllable unless an upstream proxy
    /// overwrites it; this code trusts it as-is. Endpoint resolution is best-effort: failures leave
    /// <c>Endpoint</c> null instead of rejecting the connection.
    /// </remarks>
    private T GetSession(HttpContext context)
    {
        T data = serviceProvider.GetRequiredService<T>();
        data.Path = context?.Request?.Path;
        data.SessionID = context.TraceIdentifier;
        data.UriScheme = GetScheme(context);

        try
        {
            foreach (var headerValue in context.Request.Headers["X-Forwarded-For"])
            {
                var forwardedEndpoint = ParseForwardedFor(headerValue);
                if (forwardedEndpoint is not null)
                {
                    data.Endpoint = forwardedEndpoint;
                    break;
                }
            }

            if (data.Endpoint is null)
            {
                // RemoteIpAddress is nullable; check it instead of letting the IPEndPoint
                // constructor throw ArgumentNullException.
                var remoteAddress = context.Connection.RemoteIpAddress;
                if (remoteAddress is not null)
                {
                    data.Endpoint = new IPEndPoint(remoteAddress, context.Connection.RemotePort);
                }
            }
        }
        catch (Exception ex)
        {
            logger.LogWarning(ex, "{TraceIdentifier} failed to resolve client endpoint", context.TraceIdentifier);
            data.Endpoint = null;
        }

        return data;
    }

    /// <summary>
    /// Returns the first entry of a comma-separated X-Forwarded-For value that parses as an
    /// IPv4/IPv6 endpoint, or null if none does.
    /// </summary>
    private IPEndPoint ParseForwardedFor(string headerValue)
    {
        if (string.IsNullOrEmpty(headerValue))
        {
            return null;
        }

        // Entries are usually separated by ", "; trim them so the whitespace does not
        // defeat IPEndPoint.TryParse.
        foreach (var candidate in headerValue.Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries))
        {
            logger.LogDebug("X-Forwarded-For {ip}", candidate);
            if (IPEndPoint.TryParse(candidate, out var endpoint) &&
                endpoint.AddressFamily is System.Net.Sockets.AddressFamily.InterNetwork or System.Net.Sockets.AddressFamily.InterNetworkV6)
            {
                return endpoint;
            }
        }

        return null;
    }

    /// <summary>
    /// Determines the URI scheme for the session. Checked in order: the "x-original-host" header,
    /// the Origin header (each used only if it holds an absolute URI), then the request's own
    /// scheme mapped to "ws"/"wss".
    /// </summary>
    /// <returns>
    /// A header URI's scheme as-is, otherwise "ws" or "wss", otherwise an empty string.
    /// </returns>
    private string GetScheme(HttpContext context)
    {
        var headers = context.Request.Headers;

        // TryCreate instead of new Uri(...): this method runs outside GetSession's try block, so a
        // malformed header value would otherwise throw UriFormatException and fail the request.
        if (headers.TryGetValue("x-original-host", out var originalHost) &&
            Uri.TryCreate(originalHost.ToString(), UriKind.Absolute, out var originalHostUri))
        {
            return originalHostUri.Scheme;
        }

        if (Uri.TryCreate(headers.Origin.FirstOrDefault(), UriKind.Absolute, out var originUri))
        {
            return originUri.Scheme;
        }

        // NOTE(review): the header branches above can yield "http"/"https" while this one yields
        // "ws"/"wss"; confirm which form consumers of UriScheme expect.
        return context.Request.Scheme.ToLowerInvariant() switch
        {
            "http" or "ws" => "ws",
            "https" or "wss" => "wss",
            _ => string.Empty,
        };
    }

    /// <summary>
    /// Logs the status line and selected response headers of the completed handshake at
    /// Information level, each entry prefixed with the request's TraceIdentifier.
    /// </summary>
    private void LogHandshakeResponse(HttpContext context)
    {
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Date:", context.Response.Headers["Date"]);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, context.Request.Protocol + " " + context.Response.StatusCode, "Switching Protocols");
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Upgrade:", context.Response.Headers.Upgrade);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Connection:", context.Response.Headers.Connection);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "SecWebSocketAccept:", context.Response.Headers.SecWebSocketAccept);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "SecWebSocketProtocol:", context.Response.Headers.SecWebSocketProtocol);
    }
}