using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using SuperSocket.Common;
using SuperSocket.SocketBase;
using SuperSocket.SocketBase.Command;
using SuperSocket.SocketBase.Config;
using SuperSocket.SocketBase.Logging;
using SuperSocket.SocketBase.Protocol;
using SuperSocket.SocketEngine.AsyncSocket;

namespace SuperSocket.SocketEngine
{
    /// <summary>
    /// The interface for socket sessions which require a negotiation step before communication.
    /// </summary>
    interface INegotiateSocketSession
    {
        /// <summary>
        /// Starts the negotiation.
        /// </summary>
        void Negotiate();

        /// <summary>
        /// Gets a value indicating whether the negotiation succeeded.
        /// </summary>
        /// <value>
        ///   <c>true</c> if the negotiation succeeded; otherwise, <c>false</c>.
        /// </value>
        bool Result { get; }

        /// <summary>
        /// Gets the app session.
        /// </summary>
        /// <value>
        /// The app session.
        /// </value>
        IAppSession AppSession { get; }

        /// <summary>
        /// Occurs when the negotiation has completed; check <see cref="Result"/> for the outcome.
        /// </summary>
        event EventHandler NegotiateCompleted;
    }

    /// <summary>
    /// A socket session that receives and sends through a <see cref="Stream"/>: a plain
    /// <see cref="NetworkStream"/> when <see cref="SslProtocols.None"/> is configured, otherwise an
    /// <see cref="SslStream"/> authenticated as server. Received data is read into the buffer segment
    /// provided by the <see cref="SocketAsyncEventArgsProxy"/>.
    /// </summary>
    class AsyncStreamSocketSession : SocketSession, IAsyncSocketSessionBase, INegotiateSocketSession
    {
        private const string UnknownRemoteEndPoint = "(unknown remote endpoint)";

        private byte[] m_ReadBuffer;
        private int m_Offset;
        private int m_Length;

        private bool m_IsReset;

        /// <summary>
        /// Initializes a new session; equivalent to the four-argument constructor with <c>isReset = false</c>.
        /// </summary>
        public AsyncStreamSocketSession(Socket client, SslProtocols security, SocketAsyncEventArgsProxy socketAsyncProxy)
            : this(client, security, socketAsyncProxy, false)
        {
        }

        /// <summary>
        /// Initializes a new session.
        /// </summary>
        /// <param name="client">The connected socket.</param>
        /// <param name="security">The security protocol used to build the session stream.</param>
        /// <param name="socketAsyncProxy">Supplies the receive buffer, offset and count.</param>
        /// <param name="isReset">When true, <see cref="Start"/> begins receiving without calling StartSession().</param>
        public AsyncStreamSocketSession(Socket client, SslProtocols security, SocketAsyncEventArgsProxy socketAsyncProxy, bool isReset)
            : base(client)
        {
            SecureProtocol = security;
            SocketAsyncProxy = socketAsyncProxy;
            var e = socketAsyncProxy.SocketEventArgs;
            m_ReadBuffer = e.Buffer;
            m_Offset = e.Offset;
            m_Length = e.Count;

            m_IsReset = isReset;
        }

        /// <summary>
        /// Starts this session communication.
        /// </summary>
        public override void Start()
        {
            //Hasn't started, but already closed
            if (IsClosed)
                return;

            OnSessionStarting();
        }

        /// <summary>
        /// Issues the first read, then starts the session unless this is a reset session.
        /// </summary>
        private void OnSessionStarting()
        {
            try
            {
                OnReceiveStarted();
                m_Stream.BeginRead(m_ReadBuffer, m_Offset, m_Length, OnStreamEndRead, m_Stream);
            }
            catch (Exception e)
            {
                LogError(e);
                OnReceiveTerminated(CloseReason.SocketError);
                return;
            }

            if (!m_IsReset)
                StartSession();
        }

        /// <summary>
        /// Read completion callback: hands the received bytes to the app session and issues the next read.
        /// </summary>
        private void OnStreamEndRead(IAsyncResult result)
        {
            var stream = result.AsyncState as Stream;

            int thisRead = 0;

            try
            {
                thisRead = stream.EndRead(result);
            }
            catch (Exception e)
            {
                LogError(e);
                OnReceiveTerminated(CloseReason.SocketError);
                return;
            }

            // A non-positive read count is treated as the client closing the connection.
            if (thisRead <= 0)
            {
                OnReceiveTerminated(CloseReason.ClientClosing);
                return;
            }

            OnReceiveEnded();

            int offsetDelta;

            try
            {
                offsetDelta = AppSession.ProcessRequest(m_ReadBuffer, m_Offset, thisRead, true);
            }
            catch (Exception ex)
            {
                LogError("Protocol error", ex);
                this.Close(CloseReason.ProtocolError);
                return;
            }

            try
            {
                // offsetDelta is the position, relative to the original receive offset, where the next read starts.
                if (offsetDelta < 0 || offsetDelta >= Config.ReceiveBufferSize)
                    throw new ArgumentException(string.Format("Illigal offsetDelta: {0}", offsetDelta), "offsetDelta");

                m_Offset = SocketAsyncProxy.OrigOffset + offsetDelta;
                m_Length = Config.ReceiveBufferSize - offsetDelta;

                OnReceiveStarted();
                m_Stream.BeginRead(m_ReadBuffer, m_Offset, m_Length, OnStreamEndRead, m_Stream);
            }
            catch (Exception exc)
            {
                LogError(exc);
                OnReceiveTerminated(CloseReason.SocketError);
                return;
            }
        }

        private Stream m_Stream;

        /// <summary>
        /// Creates the server-side <see cref="SslStream"/> over the client socket.
        /// </summary>
        private SslStream CreateSslStream(ICertificateConfig certConfig)
        {
            //Enable client certificate function only if ClientCertificateRequired is true in the configuration
            if (!certConfig.ClientCertificateRequired)
                return new SslStream(new NetworkStream(Client), false);

            //Subscribe the client validation callback
            return new SslStream(new NetworkStream(Client), false, ValidateClientCertificate);
        }

        /// <summary>
        /// Validates the client certificate, delegating to the app server when it implements
        /// <see cref="IRemoteCertificateValidator"/>.
        /// </summary>
        private bool ValidateClientCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            var session = AppSession;

            //Invoke the AppServer's method ValidateClientCertificate
            var clientCertificateValidator = session.AppServer as IRemoteCertificateValidator;

            if (clientCertificateValidator != null)
                return clientCertificateValidator.Validate(session, sender, certificate, chain, sslPolicyErrors);

            //Return the native validation result
            return sslPolicyErrors == SslPolicyErrors.None;
        }

        /// <summary>
        /// Builds the session stream for the configured <see cref="SocketSession.SecureProtocol"/>.
        /// </summary>
        /// <param name="asyncCallback">Invoked when the server-side TLS/SSL handshake completes.</param>
        /// <returns>
        /// <c>null</c> when no handshake is needed (plain <see cref="NetworkStream"/> assigned to m_Stream
        /// synchronously); otherwise the pending handshake's <see cref="IAsyncResult"/>, whose AsyncState is the SslStream.
        /// </returns>
        private IAsyncResult BeginInitStream(AsyncCallback asyncCallback)
        {
            var certConfig = AppSession.Config.Certificate;
            var secureProtocol = SecureProtocol;

            if (secureProtocol == SslProtocols.None)
            {
                m_Stream = new NetworkStream(Client);
                return null;
            }

            // NOTE(review): Default, Tls and Ssl3 are all negotiated with SslProtocols.Default rather than the
            // exact configured value; every other value (including Ssl2) is passed through unchanged.
            SslProtocols enabledProtocols;

            switch (secureProtocol)
            {
                case SslProtocols.Default:
                case SslProtocols.Tls:
                case SslProtocols.Ssl3:
                    enabledProtocols = SslProtocols.Default;
                    break;
                default:
                    enabledProtocols = secureProtocol;
                    break;
            }

            var sslStream = CreateSslStream(certConfig);
            return sslStream.BeginAuthenticateAsServer(AppSession.AppServer.Certificate, certConfig.ClientCertificateRequired, enabledProtocols, false, asyncCallback, sslStream);
        }

        private void OnBeginInitStreamOnSessionConnected(IAsyncResult result)
        {
            OnBeginInitStream(result, true);
        }

        private void OnBeginInitStream(IAsyncResult result)
        {
            OnBeginInitStream(result, false);
        }

        /// <summary>
        /// Completes the server-side handshake and reports the negotiation result.
        /// </summary>
        /// <param name="result">The handshake's async result; AsyncState is the SslStream.</param>
        /// <param name="connect">
        /// True when called during initial negotiation (session not yet registered, so it is not closed here);
        /// false when applying the secure protocol to an existing session, which is closed on failure.
        /// </param>
        private void OnBeginInitStream(IAsyncResult result, bool connect)
        {
            var sslStream = result.AsyncState as SslStream;

            try
            {
                sslStream.EndAuthenticateAsServer(result);
            }
            catch (Exception e)
            {
                // GetRemoteEndPointText does not throw, so OnNegotiateCompleted(false) is always reached.
                LogError(GetRemoteEndPointText(), e);

                // The failed stream is never assigned to m_Stream, so release it here.
                if (sslStream != null)
                    sslStream.Dispose();

                if (!connect)//Session was already registered
                    this.Close(CloseReason.SocketError);

                OnNegotiateCompleted(false);
                return;
            }

            m_Stream = sslStream;
            OnNegotiateCompleted(true);
        }

        /// <summary>
        /// Writes every queued segment synchronously, then completes the send.
        /// </summary>
        protected override void SendSync(SendingQueue queue)
        {
            try
            {
                for (var i = 0; i < queue.Count; i++)
                {
                    var item = queue[i];
                    m_Stream.Write(item.Array, item.Offset, item.Count);
                }

                OnSendingCompleted(queue);
            }
            catch (Exception e)
            {
                LogError(e);
                OnSendError(queue, CloseReason.SocketError);
                return;
            }
        }

        /// <summary>
        /// Flushes the stream before running the base completion logic.
        /// </summary>
        protected override void OnSendingCompleted(SendingQueue queue)
        {
            try
            {
                m_Stream.Flush();
            }
            catch (Exception e)
            {
                LogError(e);
                OnSendError(queue, CloseReason.SocketError);
                return;
            }

            base.OnSendingCompleted(queue);
        }

        /// <summary>
        /// Begins writing the segment at the queue's current position; OnEndWrite continues with the rest.
        /// </summary>
        protected override void SendAsync(SendingQueue queue)
        {
            try
            {
                var item = queue[queue.Position];
                m_Stream.BeginWrite(item.Array, item.Offset, item.Count, OnEndWrite, queue);
            }
            catch (Exception e)
            {
                LogError(e);
                OnSendError(queue, CloseReason.SocketError);
            }
        }

        private void OnEndWrite(IAsyncResult result)
        {
            var queue = result.AsyncState as SendingQueue;

            try
            {
                m_Stream.EndWrite(result);
            }
            catch (Exception e)
            {
                LogError(e);
                OnSendError(queue, CloseReason.SocketError);
                return;
            }

            var nextPos = queue.Position + 1;

            //Has more data to send
            if (nextPos < queue.Count)
            {
                queue.Position = nextPos;
                SendAsync(queue);
                return;
            }

            OnSendingCompleted(queue);
        }

        /// <summary>
        /// Switches this session to the configured secure protocol and blocks until the handshake
        /// completion callback has finished.
        /// </summary>
        /// <remarks>
        /// The APM contract does not guarantee that <see cref="IAsyncResult.AsyncWaitHandle"/> is signaled
        /// only after the <see cref="AsyncCallback"/> has run, so waiting on it could return before m_Stream
        /// is assigned (or the session closed). A dedicated event set at the end of the callback is used instead.
        /// </remarks>
        public override void ApplySecureProtocol()
        {
            using (var callbackCompleted = new ManualResetEvent(false))
            {
                var asyncResult = BeginInitStream(r =>
                {
                    try
                    {
                        OnBeginInitStream(r);
                    }
                    finally
                    {
                        callbackCompleted.Set();
                    }
                });

                // Null means no handshake was started; the plain stream is already assigned.
                if (asyncResult != null)
                    callbackCompleted.WaitOne();
            }
        }

        public SocketAsyncEventArgsProxy SocketAsyncProxy { get; private set; }

        ILog ILoggerProvider.Logger
        {
            get { return AppSession.Logger; }
        }

        public override int OrigReceiveOffset
        {
            get { return SocketAsyncProxy.OrigOffset; }
        }

        private bool m_NegotiateResult = false;

        /// <summary>
        /// Starts building the session stream. The result is reported through NegotiateCompleted:
        /// immediately when no handshake is needed or starting it fails, otherwise from the handshake callback.
        /// </summary>
        void INegotiateSocketSession.Negotiate()
        {
            IAsyncResult asyncResult;

            try
            {
                asyncResult = BeginInitStream(OnBeginInitStreamOnSessionConnected);
            }
            catch (Exception e)
            {
                LogError(GetRemoteEndPointText(), e);
                OnNegotiateCompleted(false);
                return;
            }

            if (asyncResult == null)
                OnNegotiateCompleted(true);
        }

        bool INegotiateSocketSession.Result
        {
            get { return m_NegotiateResult; }
        }

        private EventHandler m_NegotiateCompleted;

        event EventHandler INegotiateSocketSession.NegotiateCompleted
        {
            add { m_NegotiateCompleted += value; }
            remove { m_NegotiateCompleted -= value; }
        }

        /// <summary>
        /// Records the negotiation result and raises NegotiateCompleted at most once.
        /// </summary>
        private void OnNegotiateCompleted(bool negotiateResult)
        {
            m_NegotiateResult = negotiateResult;

            //One time event handler
            var handler = Interlocked.Exchange(ref m_NegotiateCompleted, null);

            if (handler == null)
                return;

            handler(this, EventArgs.Empty);
        }

        /// <summary>
        /// Returns the client's remote endpoint as text for log messages without throwing when the
        /// socket has already been closed; the client reference is also checked for null defensively.
        /// </summary>
        private string GetRemoteEndPointText()
        {
            try
            {
                var client = Client;

                if (client == null)
                    return UnknownRemoteEndPoint;

                var endPoint = client.RemoteEndPoint;
                return endPoint == null ? UnknownRemoteEndPoint : endPoint.ToString();
            }
            catch (ObjectDisposedException)
            {
                return UnknownRemoteEndPoint;
            }
            catch (SocketException)
            {
                return UnknownRemoteEndPoint;
            }
        }
    }
}