using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using System.Data.Entity.Core.Common.EntitySql;
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.ServiceModel.Channels;

namespace EVCB_OCPP.WSServer.Service.WsService;

/// <summary>
/// Accepts WebSocket upgrade requests, creates a <typeparamref name="T"/> session for every accepted
/// connection, raises <see cref="NewSessionConnected"/> and then keeps the request pipeline alive
/// until the session's <c>EndConnSemaphore</c> is released.
/// </summary>
/// <typeparam name="T">Session type resolved from the service provider for each connection.</typeparam>
public class WebsocketService<T> where T : WsSession
{
    public WebsocketService(IServiceProvider serviceProvider, ILogger logger)
    {
        this.serviceProvider = serviceProvider;
        this.logger = logger;
    }

    private readonly IServiceProvider serviceProvider;
    private readonly ILogger logger;

    /// <summary>Raised after a session's WebSocket has been accepted and attached to it.</summary>
    public event EventHandler<T> NewSessionConnected;

    /// <summary>
    /// Handles one incoming HTTP request. Non-WebSocket requests, requests whose sub-protocol is not
    /// supported (empty result from <see cref="ValidateSupportedPortocol"/>) and requests rejected by
    /// <see cref="ValidateHandshake"/> are returned from without accepting the socket.
    /// The returned task completes only when the session signals its end.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task AcceptWebSocket(HttpContext context)
    {
        if (!context.WebSockets.IsWebSocketRequest)
        {
            return;
        }

        var subProtocol = await ValidateSupportedPortocol(context);
        if (string.IsNullOrEmpty(subProtocol))
        {
            return;
        }

        T data = GetSession(context);
        if (!await ValidateHandshake(context, data))
        {
            return;
        }

        // The socket is disposed once AddWebSocket returns, i.e. after the session has ended.
        using WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(subProtocol);
        LogHandshakeResponse(context);
        await AddWebSocket(webSocket, data);
    }

    /// <summary>
    /// Hook for derived services to reject a handshake after the session has been built.
    /// The base implementation accepts every request.
    /// </summary>
    /// <returns><c>true</c> to accept the connection; <c>false</c> to abandon it.</returns>
    internal virtual ValueTask<bool> ValidateHandshake(HttpContext context, T data)
    {
        return ValueTask.FromResult(true);
    }

    /// <summary>
    /// Hook for derived services to choose the WebSocket sub-protocol to accept.
    /// The base implementation returns an empty string, which makes <see cref="AcceptWebSocket"/>
    /// reject the request.
    /// </summary>
    /// <returns>The sub-protocol to accept, or null/empty to reject the request.</returns>
    internal virtual ValueTask<string> ValidateSupportedPortocol(HttpContext context)
    {
        return ValueTask.FromResult(string.Empty);
    }

    /// <summary>
    /// Attaches the accepted socket to the session, announces it and waits on the session's
    /// <c>EndConnSemaphore</c> (presumably released by the session when the connection ends — confirm in WsSession).
    /// </summary>
    private async Task AddWebSocket(WebSocket webSocket, T data)
    {
        data.ClientWebSocket = webSocket;
        NewSessionConnected?.Invoke(this, data);
        await data.EndConnSemaphore.WaitAsync();
    }

    /// <summary>
    /// Resolves a new session from the service provider and fills in the request-derived fields:
    /// path, session id (the request's TraceIdentifier), URI scheme and client endpoint.
    /// </summary>
    private T GetSession(HttpContext context)
    {
        T data = serviceProvider.GetRequiredService<T>();
        data.Path = context?.Request?.Path;
        data.SessionID = context.TraceIdentifier;
        data.UriScheme = GetScheme(context);

        try
        {
            data.Endpoint = ResolveClientEndpoint(context);
        }
        catch (Exception ex)
        {
            // Best effort: failing to determine the peer address must not abort the handshake.
            logger.LogWarning(ex, "{TraceIdentifier} failed to resolve client endpoint", context.TraceIdentifier);
            data.Endpoint = null;
        }

        return data;
    }

    /// <summary>
    /// Returns the first IPv4/IPv6 endpoint found in the X-Forwarded-For header(s); otherwise the
    /// connection's remote address and port. Returns null when neither is available
    /// (e.g. no RemoteIpAddress on the connection).
    /// </summary>
    /// <remarks>
    /// NOTE(review): X-Forwarded-For is client-controllable unless a trusted proxy overwrites it, so the
    /// left-most entry can be spoofed. Consider ASP.NET Core's ForwardedHeaders middleware with KnownProxies.
    /// </remarks>
    private IPEndPoint ResolveClientEndpoint(HttpContext context)
    {
        foreach (var headerValue in context.Request.Headers["X-Forwarded-For"])
        {
            if (string.IsNullOrEmpty(headerValue))
            {
                continue;
            }

            // Entries are conventionally separated by ", "; trim so entries after the first still parse.
            var candidates = headerValue.Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);
            foreach (var candidate in candidates)
            {
                logger.LogDebug("X-Forwarded-For {ip}", candidate);
                if (IPEndPoint.TryParse(candidate, out var parsed)
                    && (parsed.AddressFamily is System.Net.Sockets.AddressFamily.InterNetwork or System.Net.Sockets.AddressFamily.InterNetworkV6))
                {
                    return parsed;
                }
            }
        }

        var remoteAddress = context.Connection.RemoteIpAddress;
        return remoteAddress is null
            ? null
            : new IPEndPoint(remoteAddress, context.Connection.RemotePort);
    }

    /// <summary>
    /// Determines the URI scheme of the connection. Sources are checked in order: the scheme of an
    /// absolute URI in the <c>x-original-host</c> header, the scheme of the <c>Origin</c> header, and
    /// finally the raw request scheme mapped to "ws"/"wss". Unparseable headers are skipped rather
    /// than failing the handshake. Returns an empty string when nothing matches.
    /// </summary>
    private string GetScheme(HttpContext context)
    {
        if (context.Request.Headers.TryGetValue("x-original-host", out var originalHost)
            && Uri.TryCreate(originalHost.ToString(), UriKind.Absolute, out var originalHostUri))
        {
            return originalHostUri.Scheme;
        }

        // NOTE(review): a browser Origin carries "http"/"https", not "ws"/"wss" — confirm consumers expect that.
        var origin = context.Request.Headers.Origin.FirstOrDefault();
        if (Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
        {
            return originUri.Scheme;
        }

        return context.Request.Scheme.ToLowerInvariant() switch
        {
            "http" or "ws" => "ws",
            "https" or "wss" => "wss",
            _ => string.Empty,
        };
    }

    /// <summary>
    /// Logs the handshake response headers (Date, status line, Upgrade, Connection,
    /// Sec-WebSocket-Accept, Sec-WebSocket-Protocol) prefixed with the request's TraceIdentifier.
    /// The status-line text "Switching Protocols" is logged regardless of the actual status code.
    /// </summary>
    private void LogHandshakeResponse(HttpContext context)
    {
        var traceId = context.TraceIdentifier;
        var headers = context.Response.Headers;
        logger.LogInformation("{0} {1} {2}", traceId, "Date:", headers["Date"]);
        logger.LogInformation("{0} {1} {2}", traceId, context.Request.Protocol + " " + context.Response.StatusCode, "Switching Protocols");
        logger.LogInformation("{0} {1} {2}", traceId, "Upgrade:", headers.Upgrade);
        logger.LogInformation("{0} {1} {2}", traceId, "Connection:", headers.Connection);
        logger.LogInformation("{0} {1} {2}", traceId, "SecWebSocketAccept:", headers.SecWebSocketAccept);
        logger.LogInformation("{0} {1} {2}", traceId, "SecWebSocketProtocol:", headers.SecWebSocketProtocol);
    }
}