WebsocketService.cs 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. using Microsoft.AspNetCore.Http;
  2. using Microsoft.Extensions.DependencyInjection;
  3. using Microsoft.Extensions.Logging;
  4. using System.Net;
  5. using System.Net.Http;
  6. using System.Net.WebSockets;
  7. using System.ServiceModel.Channels;
  8. namespace EVCB_OCPP.WSServer.Service.WsService;
  /// <summary>
  /// Accepts ASP.NET Core WebSocket upgrade requests and attaches each accepted socket to a
  /// session object of type <typeparamref name="T"/> resolved from the service provider.
  /// </summary>
  /// <remarks>
  /// The base <see cref="ValidateSupportedPortocol"/> returns an empty string, which makes
  /// <see cref="AcceptWebSocket"/> return without accepting anything; a derived class is
  /// expected to override it to pick the sub-protocol to accept.
  /// </remarks>
  /// <typeparam name="T">Session type; it must be resolvable via <c>GetRequiredService</c>.</typeparam>
  public class WebsocketService<T> where T : WsSession
  {
      private readonly IServiceProvider serviceProvider;
      private readonly ILogger logger;

      /// <summary>Creates the service.</summary>
      /// <param name="serviceProvider">Used to resolve a new <typeparamref name="T"/> per connection.</param>
      /// <param name="logger">Receives the handshake-response log lines.</param>
      public WebsocketService(IServiceProvider serviceProvider, ILogger logger)
      {
          this.serviceProvider = serviceProvider;
          this.logger = logger;
      }

      /// <summary>
      /// Raised synchronously from <see cref="AcceptWebSocket"/> after the accepted socket has been
      /// assigned to the session's <c>ClientWebSocket</c>.
      /// </summary>
      public event EventHandler<T> NewSessionConnected;

      /// <summary>
      /// Runs the upgrade pipeline for one HTTP request: checks that it is a WebSocket request,
      /// negotiates the sub-protocol, builds the session, validates the handshake, accepts the
      /// socket and then keeps the request alive until the session signals its end.
      /// </summary>
      /// <param name="context">The incoming HTTP request.</param>
      /// <remarks>
      /// Every rejection path simply returns; this method does not set a response status code itself.
      /// The accepted socket is disposed when this method returns.
      /// </remarks>
      public async Task AcceptWebSocket(HttpContext context)
      {
          if (!context.WebSockets.IsWebSocketRequest)
          {
              return;
          }

          var protocol = await ValidateSupportedPortocol(context);
          if (string.IsNullOrEmpty(protocol))
          {
              return;
          }

          T data = GetSession(context);
          if (!await ValidateHandshake(context, data))
          {
              return;
          }

          using WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(protocol);
          LogHandshakeResponse(context);
          await AddWebSocket(webSocket, data);
      }

      /// <summary>
      /// Hook for derived classes to veto a handshake before the socket is accepted.
      /// The base implementation accepts everything.
      /// </summary>
      /// <returns><c>true</c> to continue with the upgrade; <c>false</c> to stop.</returns>
      internal virtual ValueTask<bool> ValidateHandshake(HttpContext context, T data)
      {
          return ValueTask.FromResult(true);
      }

      /// <summary>
      /// Hook for derived classes to choose the WebSocket sub-protocol to accept.
      /// A null or empty result stops the upgrade. The base implementation returns
      /// <see cref="string.Empty"/>, so it never accepts a connection.
      /// </summary>
      internal virtual ValueTask<string> ValidateSupportedPortocol(HttpContext context)
      {
          return ValueTask.FromResult(string.Empty);
      }

      /// <summary>
      /// Binds the accepted socket to the session, raises <see cref="NewSessionConnected"/>, then
      /// waits on the session's <c>EndConnSemaphore</c>. The caller's request (and therefore the
      /// socket's <c>using</c> scope) stays open until that semaphore is released — presumably by the
      /// session when the connection ends; confirm in <c>WsSession</c>.
      /// </summary>
      private async Task AddWebSocket(WebSocket webSocket, T data)
      {
          data.ClientWebSocket = webSocket;
          NewSessionConnected?.Invoke(this, data);
          await data.EndConnSemaphore.WaitAsync();
      }

      /// <summary>
      /// Resolves a fresh session from the service provider and fills in the request metadata:
      /// path, trace identifier (as session id), URI scheme and remote endpoint.
      /// </summary>
      private T GetSession(HttpContext context)
      {
          T data = serviceProvider.GetRequiredService<T>();
          data.Path = context?.Request?.Path;
          data.SessionID = context.TraceIdentifier;
          data.UriScheme = GetScheme(context);

          var remoteAddress = context.Connection.RemoteIpAddress;
          var remotePort = context.Connection.RemotePort;
          // RemoteIpAddress can be null (e.g. non-TCP transports); IPEndPoint's constructor throws on a
          // null address or an out-of-range port, so check explicitly instead of catching everything.
          data.Endpoint = remoteAddress is not null && remotePort is >= IPEndPoint.MinPort and <= IPEndPoint.MaxPort
              ? new IPEndPoint(remoteAddress, remotePort)
              : null;
          return data;
      }

      /// <summary>
      /// Determines the URI scheme of the connection, trying in order:
      /// the absolute URI in the <c>x-original-host</c> header, the <c>Origin</c> header, and finally
      /// the request's own scheme mapped to <c>ws</c>/<c>wss</c>.
      /// </summary>
      /// <returns>The scheme, or an empty string when none of the sources yields one.</returns>
      /// <remarks>
      /// Header values are client-controlled, so unparsable values are skipped rather than allowed to
      /// throw and abort the handshake.
      /// NOTE(review): the header-derived branches return the scheme as written (e.g. "https"), while
      /// the fallback maps to "ws"/"wss" — confirm that consumers of <c>UriScheme</c> handle both forms.
      /// </remarks>
      private static string GetScheme(HttpContext context)
      {
          var headers = context.Request.Headers;

          // Presumably set by a fronting reverse proxy with the original request URI — confirm.
          if (headers.TryGetValue("x-original-host", out var originalHost) &&
              Uri.TryCreate(originalHost.ToString(), UriKind.Absolute, out var originalUri))
          {
              return originalUri.Scheme;
          }

          if (Uri.TryCreate(headers.Origin.FirstOrDefault(), UriKind.Absolute, out var originUri))
          {
              return originUri.Scheme;
          }

          return context.Request.Scheme.ToLowerInvariant() switch
          {
              "http" or "ws" => "ws",
              "https" or "wss" => "wss",
              _ => string.Empty,
          };
      }

      /// <summary>
      /// Logs the main headers of the 101 handshake response, one Information line each,
      /// prefixed with the request's trace identifier.
      /// </summary>
      private void LogHandshakeResponse(HttpContext context)
      {
          var traceId = context.TraceIdentifier;
          var headers = context.Response.Headers;

          LogHandshakeLine(traceId, "Date:", headers.Date);
          LogHandshakeLine(traceId, $"{context.Request.Protocol} {context.Response.StatusCode}", "Switching Protocols");
          LogHandshakeLine(traceId, "Upgrade:", headers.Upgrade);
          LogHandshakeLine(traceId, "Connection:", headers.Connection);
          LogHandshakeLine(traceId, "SecWebSocketAccept:", headers.SecWebSocketAccept);
          LogHandshakeLine(traceId, "SecWebSocketProtocol:", headers.SecWebSocketProtocol);
      }

      /// <summary>
      /// Writes one "traceId name value" line with a constant, named-placeholder message template
      /// so that structured-logging sinks get meaningful property names.
      /// </summary>
      private void LogHandshakeLine(string traceId, string name, object value)
      {
          logger.LogInformation("{TraceIdentifier} {Name} {Value}", traceId, name, value);
      }
  }