using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.ServiceModel.Channels;

namespace EVCB_OCPP.WSServer.Service.WsService;

/// <summary>
/// Accepts incoming WebSocket upgrade requests and binds each accepted socket to a
/// session object of type <typeparamref name="T"/> resolved from the service provider.
/// </summary>
/// <typeparam name="T">Session type; must derive from <see cref="WsSession"/>.</typeparam>
public class WebsocketService<T> where T : WsSession
{
    /// <summary>
    /// Creates the service.
    /// </summary>
    /// <param name="serviceProvider">Provider used to resolve a <typeparamref name="T"/> for every connection.</param>
    /// <param name="logger">Logger that receives the handshake response trace.</param>
    public WebsocketService(IServiceProvider serviceProvider, ILogger logger)
    {
        this.serviceProvider = serviceProvider;
        this.logger = logger;
    }

    private readonly IServiceProvider serviceProvider;
    private readonly ILogger logger;

    /// <summary>
    /// Raised after the WebSocket has been accepted and assigned to the session's
    /// <c>ClientWebSocket</c>. The event argument is the session itself.
    /// </summary>
    public event EventHandler<T> NewSessionConnected;

    /// <summary>
    /// Handles one HTTP request: validates it, upgrades it to a WebSocket and keeps the
    /// request alive until the session's <c>EndConnSemaphore</c> can be acquired.
    /// </summary>
    /// <remarks>
    /// The method returns without accepting the socket when the request is not a WebSocket
    /// request, when <see cref="ValidateSupportedPortocol"/> yields a null or empty
    /// sub-protocol, or when <see cref="ValidateHandshake"/> returns <c>false</c>.
    /// </remarks>
    /// <param name="context">The current HTTP context.</param>
    public async Task AcceptWebSocket(HttpContext context)
    {
        if (!context.WebSockets.IsWebSocketRequest)
        {
            return;
        }

        var portocol = await ValidateSupportedPortocol(context);
        if (string.IsNullOrEmpty(portocol))
        {
            return;
        }

        T data = GetSession(context);

        if (!await ValidateHandshake(context, data))
        {
            return;
        }

        // The socket is disposed when this method exits, i.e. once AddWebSocket completes.
        using WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(portocol);
        LogHandshakeResponse(context);

        await AddWebSocket(webSocket, data);
    }

    /// <summary>
    /// Extension point for additional handshake validation. The base implementation
    /// accepts every request.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="data">The session created for this request.</param>
    /// <returns><c>true</c> to continue with the upgrade; <c>false</c> to abort it.</returns>
    internal virtual ValueTask<bool> ValidateHandshake(HttpContext context, T data)
    {
        return ValueTask.FromResult(true);
    }

    /// <summary>
    /// Extension point that selects the WebSocket sub-protocol to accept.
    /// </summary>
    /// <remarks>
    /// The base implementation returns <see cref="string.Empty"/>, which
    /// <see cref="AcceptWebSocket"/> treats as "no supported protocol"; a subclass has to
    /// override this for any connection to be accepted.
    /// </remarks>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>The sub-protocol to accept, or null/empty to reject the request.</returns>
    internal virtual ValueTask<string> ValidateSupportedPortocol(HttpContext context)
    {
        return ValueTask.FromResult(string.Empty);
    }

    /// <summary>
    /// Attaches the accepted socket to the session, notifies subscribers and then waits
    /// on the session's <c>EndConnSemaphore</c>.
    /// </summary>
    /// <param name="webSocket">The accepted WebSocket.</param>
    /// <param name="data">The session that owns the socket.</param>
    private async Task AddWebSocket(WebSocket webSocket, T data)
    {
        data.ClientWebSocket = webSocket;

        NewSessionConnected?.Invoke(this, data);

        // NOTE(review): presumably the session releases this semaphore when the
        // connection ends, which keeps the HTTP request open until then - confirm in WsSession.
        await data.EndConnSemaphore.WaitAsync();
    }

    /// <summary>
    /// Resolves a new session from the service provider and fills in the request path,
    /// session id, URI scheme and remote endpoint.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>The populated session.</returns>
    private T GetSession(HttpContext context)
    {
        T data = serviceProvider.GetRequiredService<T>();
        data.Path = context?.Request?.Path;
        data.SessionID = context.TraceIdentifier;
        data.UriScheme = GetScheme(context);

        var remoteAddress = context.Connection.RemoteIpAddress;
        var remotePort = context.Connection.RemotePort;

        // IPEndPoint rejects a null address and ports outside [MinPort, MaxPort]; check
        // both up front instead of catching the resulting exceptions.
        if (remoteAddress != null &&
            remotePort >= IPEndPoint.MinPort &&
            remotePort <= IPEndPoint.MaxPort)
        {
            data.Endpoint = new IPEndPoint(remoteAddress, remotePort);
        }
        else
        {
            data.Endpoint = null;
        }

        return data;
    }

    /// <summary>
    /// Determines the URI scheme to record for the session.
    /// </summary>
    /// <remarks>
    /// Sources are tried in order: the scheme of the <c>x-original-host</c> header, the
    /// scheme of the <c>Origin</c> header, then the request scheme mapped to
    /// <c>ws</c>/<c>wss</c>. A header value that is not an absolute URI is skipped rather
    /// than raising an exception.
    /// </remarks>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>The scheme, or <see cref="string.Empty"/> when none could be determined.</returns>
    private string GetScheme(HttpContext context)
    {
        // Header values are client controlled, so parse them without throwing.
        if (context.Request.Headers.TryGetValue("x-original-host", out var originalHost) &&
            Uri.TryCreate(originalHost.ToString(), UriKind.Absolute, out var originalHostUri))
        {
            return originalHostUri.Scheme;
        }

        var origin = context.Request.Headers.Origin.FirstOrDefault();
        if (Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
        {
            return originUri.Scheme;
        }

        var rawScheme = context.Request.Scheme;
        if (string.Equals(rawScheme, "http", StringComparison.OrdinalIgnoreCase) ||
            string.Equals(rawScheme, "ws", StringComparison.OrdinalIgnoreCase))
        {
            return "ws";
        }

        if (string.Equals(rawScheme, "https", StringComparison.OrdinalIgnoreCase) ||
            string.Equals(rawScheme, "wss", StringComparison.OrdinalIgnoreCase))
        {
            return "wss";
        }

        return string.Empty;
    }

    /// <summary>
    /// Writes the handshake response (status line and the upgrade-related response
    /// headers) to the log, one entry per line, each prefixed with the trace identifier.
    /// </summary>
    /// <param name="context">The HTTP context whose response is logged.</param>
    private void LogHandshakeResponse(HttpContext context)
    {
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Date:", context.Response.Headers["Date"]);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, context.Request.Protocol + " " + context.Response.StatusCode, "Switching Protocols");
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Upgrade:", context.Response.Headers.Upgrade);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Connection:", context.Response.Headers.Connection);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "SecWebSocketAccept:", context.Response.Headers.SecWebSocketAccept);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "SecWebSocketProtocol:", context.Response.Headers.SecWebSocketProtocol);
    }
}