123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- using Microsoft.AspNetCore.Http;
- using Microsoft.Extensions.DependencyInjection;
- using Microsoft.Extensions.Logging;
- using System.Data.Entity.Core.Common.EntitySql;
- using System.Net;
- using System.Net.Http;
- using System.Net.WebSockets;
- using System.ServiceModel.Channels;
- namespace EVCB_OCPP.WSServer.Service.WsService;
/// <summary>
/// Accepts ASP.NET Core WebSocket upgrade requests and attaches each accepted connection to a
/// <typeparamref name="T"/> session obtained from the service provider.
/// </summary>
/// <typeparam name="T">
/// Session type, resolved with <c>GetRequiredService</c> for every request (whether that yields a
/// fresh instance depends on how <typeparamref name="T"/> is registered).
/// </typeparam>
public class WebsocketService<T> where T : WsSession
{
    private readonly IServiceProvider serviceProvider;
    private readonly ILogger logger;

    /// <param name="serviceProvider">Used to resolve a <typeparamref name="T"/> for every incoming request.</param>
    /// <param name="logger">Receives handshake, forwarded-address and endpoint-resolution diagnostics.</param>
    public WebsocketService(IServiceProvider serviceProvider, ILogger logger)
    {
        this.serviceProvider = serviceProvider;
        this.logger = logger;
    }

    /// <summary>
    /// Raised once a WebSocket has been accepted and attached to its session, before this service
    /// starts waiting for the session to release <c>EndConnSemaphore</c>.
    /// </summary>
    public event EventHandler<T> NewSessionConnected;

    /// <summary>
    /// Handles one HTTP request. If it is a WebSocket upgrade that passes sub-protocol negotiation
    /// and handshake validation, the socket is accepted. The call then stays pending until the
    /// session releases <c>EndConnSemaphore</c>.
    /// </summary>
    /// <remarks>
    /// Requests that are not WebSocket upgrades, or that fail either validation step, return without
    /// this method writing any response. The base <c>ValidateSupportedPortocol</c> returns an empty
    /// string, so a derived class must override it for any connection to be accepted.
    /// </remarks>
    /// <param name="context">The incoming HTTP request.</param>
    public async Task AcceptWebSocket(HttpContext context)
    {
        if (!context.WebSockets.IsWebSocketRequest)
        {
            return;
        }

        var protocol = await ValidateSupportedPortocol(context);
        if (string.IsNullOrEmpty(protocol))
        {
            return;
        }

        T session = GetSession(context);
        if (!await ValidateHandshake(context, session))
        {
            return;
        }

        // Disposed only after AddWebSocket completes, i.e. once the session signals the end of the connection.
        using WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(protocol);
        LogHandshakeResponse(context);
        await AddWebSocket(webSocket, session);
    }

    /// <summary>
    /// Hook for rejecting a request after its session has been populated but before the upgrade
    /// is accepted. The base implementation accepts every request.
    /// </summary>
    /// <param name="context">The incoming HTTP request.</param>
    /// <param name="data">The session built for this request.</param>
    /// <returns><c>true</c> to accept the WebSocket; <c>false</c> to stop handling the request.</returns>
    internal virtual ValueTask<bool> ValidateHandshake(HttpContext context, T data)
    {
        return ValueTask.FromResult(true);
    }

    /// <summary>
    /// Hook that selects the WebSocket sub-protocol to accept for this request.
    /// The base implementation always returns an empty string.
    /// </summary>
    /// <param name="context">The incoming HTTP request.</param>
    /// <returns>
    /// The sub-protocol passed to <c>AcceptWebSocketAsync</c>, or null/empty to stop handling the request.
    /// </returns>
    internal virtual ValueTask<string> ValidateSupportedPortocol(HttpContext context)
    {
        return ValueTask.FromResult(string.Empty);
    }

    /// <summary>
    /// Attaches the accepted socket to the session and notifies subscribers. It then waits until
    /// the session releases <c>EndConnSemaphore</c>.
    /// </summary>
    private async Task AddWebSocket(WebSocket webSocket, T data)
    {
        data.ClientWebSocket = webSocket;
        NewSessionConnected?.Invoke(this, data);
        await data.EndConnSemaphore.WaitAsync();
    }

    /// <summary>
    /// Resolves a session from the service provider and fills in the request path, trace id,
    /// WebSocket scheme and client endpoint.
    /// </summary>
    private T GetSession(HttpContext context)
    {
        T data = serviceProvider.GetRequiredService<T>();
        data.Path = context?.Request?.Path;
        data.SessionID = context.TraceIdentifier;
        data.UriScheme = GetScheme(context);
        try
        {
            data.Endpoint = GetForwardedClientEndpoint(context) ?? GetRemoteEndpoint(context);
        }
        catch (Exception e)
        {
            // Endpoint resolution is best effort: leave it null rather than failing the upgrade,
            // but keep a trace of why it could not be determined.
            logger.LogWarning(e, "{traceIdentifier} could not resolve client endpoint", context.TraceIdentifier);
            data.Endpoint = null;
        }
        return data;
    }

    /// <summary>
    /// Returns the first IPv4/IPv6 endpoint listed in the X-Forwarded-For header(s), or null when
    /// no entry parses.
    /// </summary>
    /// <remarks>
    /// Entries are trimmed before parsing. Proxies commonly separate them with ", ", and the leading
    /// space can make <c>IPEndPoint.TryParse</c> reject the entry.
    /// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy overwrites it;
    /// confirm the resulting endpoint is not relied on for security decisions.
    /// </remarks>
    private IPEndPoint GetForwardedClientEndpoint(HttpContext context)
    {
        foreach (string headerValue in context.Request.Headers["X-Forwarded-For"])
        {
            if (string.IsNullOrEmpty(headerValue))
            {
                continue;
            }

            foreach (string candidate in headerValue.Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries))
            {
                logger.LogDebug("X-Forwarded-For {ip}", candidate);
                if (IPEndPoint.TryParse(candidate, out var parsed) &&
                    (parsed.AddressFamily is System.Net.Sockets.AddressFamily.InterNetwork or System.Net.Sockets.AddressFamily.InterNetworkV6))
                {
                    return parsed;
                }
            }
        }
        return null;
    }

    /// <summary>
    /// Endpoint of the directly connected peer, or null when the connection reports no remote IP
    /// (previously this case surfaced as an ArgumentNullException from the IPEndPoint constructor).
    /// </summary>
    private static IPEndPoint GetRemoteEndpoint(HttpContext context)
    {
        var remoteIp = context.Connection.RemoteIpAddress;
        return remoteIp is null ? null : new IPEndPoint(remoteIp, context.Connection.RemotePort);
    }

    /// <summary>
    /// Determines whether the client connected over "ws" or "wss".
    /// </summary>
    /// <remarks>
    /// The sources below are tried in order, and the first one that maps to ws/wss wins:
    /// <list type="number">
    /// <item>the Origin header's scheme;</item>
    /// <item>the first X-Forwarded-Proto entry;</item>
    /// <item>the scheme of x-original-host, when it is an absolute URI;</item>
    /// <item>the request's own scheme.</item>
    /// </list>
    /// http/ws map to "ws" and https/wss map to "wss", case-insensitively. An empty string is
    /// returned when no source maps.
    /// NOTE(review): Origin describes the page that opened the socket, not necessarily the
    /// transport of this connection; its precedence is kept from the previous implementation.
    /// </remarks>
    private string GetScheme(HttpContext context)
    {
        var headers = context.Request.Headers;
        string[] candidates =
        {
            GetAbsoluteUriScheme(headers.Origin.FirstOrDefault()),
            // Chained proxies may append to X-Forwarded-Proto ("https,http"); conventionally the
            // first entry was set by the proxy closest to the client.
            headers["X-Forwarded-Proto"].FirstOrDefault()?.Split(',')[0],
            // TryCreate rather than new Uri: a bare host such as "example.com" is not an absolute
            // URI, and the constructor would throw and abort the whole upgrade.
            GetAbsoluteUriScheme(headers["x-original-host"]),
            context.Request.Scheme,
        };

        foreach (string candidate in candidates)
        {
            string scheme = ToWebSocketScheme(candidate);
            if (scheme.Length > 0)
            {
                return scheme;
            }
        }
        return string.Empty;
    }

    /// <summary>
    /// Returns the scheme of <paramref name="value"/> when it parses as an absolute URI; otherwise null.
    /// </summary>
    private static string GetAbsoluteUriScheme(string value)
    {
        return Uri.TryCreate(value, UriKind.Absolute, out var uri) ? uri.Scheme : null;
    }

    /// <summary>
    /// Maps "http"/"ws" to "ws" and "https"/"wss" to "wss", ignoring case and surrounding
    /// whitespace. Anything else, including null, maps to an empty string.
    /// </summary>
    private static string ToWebSocketScheme(string scheme)
    {
        return scheme?.Trim().ToLowerInvariant() switch
        {
            "http" or "ws" => "ws",
            "https" or "wss" => "wss",
            _ => string.Empty,
        };
    }

    /// <summary>
    /// Logs the handshake response (status line plus the Date, Upgrade, Connection,
    /// Sec-WebSocket-Accept and Sec-WebSocket-Protocol headers), prefixed with the trace id.
    /// </summary>
    private void LogHandshakeResponse(HttpContext context)
    {
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Date:", context.Response.Headers["Date"]);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, context.Request.Protocol + " " + context.Response.StatusCode, "Switching Protocols");
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Upgrade:", context.Response.Headers.Upgrade);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "Connection:", context.Response.Headers.Connection);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "SecWebSocketAccept:", context.Response.Headers.SecWebSocketAccept);
        logger.LogInformation("{0} {1} {2}", context.TraceIdentifier, "SecWebSocketProtocol:", context.Response.Headers.SecWebSocketProtocol);
    }
}
|