diff options
Diffstat (limited to 'cpp/src/IceSSL/TransceiverI.cpp')
-rw-r--r-- | cpp/src/IceSSL/TransceiverI.cpp | 264 |
1 files changed, 150 insertions, 114 deletions
diff --git a/cpp/src/IceSSL/TransceiverI.cpp b/cpp/src/IceSSL/TransceiverI.cpp index d4b86a2a427..d930c3d8cff 100644 --- a/cpp/src/IceSSL/TransceiverI.cpp +++ b/cpp/src/IceSSL/TransceiverI.cpp @@ -54,6 +54,11 @@ IceSSL::TransceiverI::close() void IceSSL::TransceiverI::shutdownWrite() { + if(_state < StateConnected) + { + return; + } + if(_instance->networkTraceLevel() >= 2) { Trace out(_logger, _instance->networkTraceCategory()); @@ -69,6 +74,11 @@ IceSSL::TransceiverI::shutdownWrite() void IceSSL::TransceiverI::shutdownReadWrite() { + if(_state < StateConnected) + { + return; + } + if(_instance->networkTraceLevel() >= 2) { Trace out(_logger, _instance->networkTraceCategory()); @@ -81,7 +91,7 @@ IceSSL::TransceiverI::shutdownReadWrite() IceInternal::shutdownSocketReadWrite(_fd); } -void +bool IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout) { // Its impossible for the packetSize to be more than an Int. @@ -102,12 +112,11 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout) ERR_clear_error(); // Clear any spurious errors. assert(_fd != INVALID_SOCKET); int ret, err; - bool wantRead, wantWrite; + bool wantWrite; { IceUtil::Mutex::Lock sync(_sslMutex); ret = SSL_write(_ssl, reinterpret_cast<const void*>(&*buf.i), packetSize); err = SSL_get_error(_ssl, ret); - wantRead = SSL_want_read(_ssl); wantWrite = SSL_want_write(_ssl); } @@ -126,14 +135,15 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout) } case SSL_ERROR_WANT_READ: { - if(!selectRead(_fd, timeout)) - { - throw TimeoutException(__FILE__, __LINE__); - } - continue; + assert(false); + break; } case SSL_ERROR_WANT_WRITE: { + if(timeout == 0) + { + return false; + } if(!selectWrite(_fd, timeout)) { throw TimeoutException(__FILE__, __LINE__); @@ -157,15 +167,12 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout) if(IceInternal::wouldBlock()) { - if(wantRead) + if(wantWrite) { - if(!selectRead(_fd, timeout)) + if(timeout == 0) { - throw TimeoutException(__FILE__, __LINE__); + return false; } - } - else if(wantWrite) - { if(!selectWrite(_fd, timeout)) { throw TimeoutException(__FILE__, __LINE__); @@ -221,9 +228,11 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout) packetSize = static_cast<int>(buf.b.end() - buf.i); } } + + return true; } -void +bool IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) { // It's impossible for the packetSize to be more than an Int. @@ -234,13 +243,12 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) ERR_clear_error(); // Clear any spurious errors. assert(_fd != INVALID_SOCKET); int ret, err; - bool wantRead, wantWrite; + bool wantRead; { IceUtil::Mutex::Lock sync(_sslMutex); ret = SSL_read(_ssl, reinterpret_cast<void*>(&*buf.i), packetSize); err = SSL_get_error(_ssl, ret); wantRead = SSL_want_read(_ssl); - wantWrite = SSL_want_write(_ssl); } if(ret <= 0) @@ -269,6 +277,10 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) } case SSL_ERROR_WANT_READ: { + if(timeout == 0) + { + return false; + } if(!selectRead(_fd, timeout)) { throw TimeoutException(__FILE__, __LINE__); @@ -277,11 +289,8 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) } case SSL_ERROR_WANT_WRITE: { - if(!selectWrite(_fd, timeout)) - { - throw TimeoutException(__FILE__, __LINE__); - } - continue; + assert(false); + break; } case SSL_ERROR_SYSCALL: { @@ -302,19 +311,15 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) { if(wantRead) { - if(!selectRead(_fd, timeout)) + if(timeout == 0) { - throw TimeoutException(__FILE__, __LINE__); + return false; } - } - else if(wantWrite) - { - if(!selectWrite(_fd, timeout)) + if(!selectRead(_fd, timeout)) { throw TimeoutException(__FILE__, __LINE__); } } - continue; } @@ -361,8 +366,8 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) //if(ERR_GET_LIB(e) == ERR_LIB_SSL && ERR_GET_REASON(e) == SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC) // unsigned long e = ERR_peek_error(); - if(ERR_GET_LIB(e) == ERR_LIB_SSL && - strcmp(ERR_reason_error_string(e), "decryption failed or bad record mac") == 0) + const char* estr = ERR_GET_LIB(e) == ERR_LIB_SSL ? ERR_reason_error_string(e) : 0; + if(estr && strcmp(estr, "decryption failed or bad record mac") == 0) { ConnectionLostException ex(__FILE__, __LINE__); ex.error = 0; @@ -396,6 +401,8 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout) packetSize = static_cast<int>(buf.b.end() - buf.i); } } + + return true; } string @@ -410,129 +417,157 @@ IceSSL::TransceiverI::toString() const return _desc; } -void +IceInternal::SocketStatus IceSSL::TransceiverI::initialize(int timeout) { - if(_incoming) + if(_state == StateNeedConnect && timeout == 0) + { + _state = StateConnectPending; + return IceInternal::NeedConnect; + } + else if(_state <= StateConnectPending) + { + IceInternal::doFinishConnect(_fd, timeout); + _state = StateConnected; + _desc = IceInternal::fdToString(_fd); + } + assert(_state == StateConnected); + + do { - // TODO: The timeout is 0 when called by the thread pool. - // Make this configurable? - if(timeout == 0) + // + // Only one thread calls initialize(), so synchronization is not necessary here. + // + int ret = _incoming ? SSL_accept(_ssl) : SSL_connect(_ssl); + switch(SSL_get_error(_ssl, ret)) { - timeout = -1; + case SSL_ERROR_NONE: + assert(SSL_is_init_finished(_ssl)); + break; + case SSL_ERROR_ZERO_RETURN: + { + ConnectionLostException ex(__FILE__, __LINE__); + ex.error = IceInternal::getSocketErrno(); + throw ex; } - - do + case SSL_ERROR_WANT_READ: { - // - // Only one thread calls initialize(), so synchronization is not necessary here. - // - int ret = SSL_accept(_ssl); - switch(SSL_get_error(_ssl, ret)) + if(timeout == 0) { - case SSL_ERROR_NONE: - assert(SSL_is_init_finished(_ssl)); - break; - case SSL_ERROR_ZERO_RETURN: + return IceInternal::NeedRead; + } + if(!selectRead(_fd, timeout)) { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; + throw ConnectTimeoutException(__FILE__, __LINE__); } - case SSL_ERROR_WANT_READ: + break; + } + case SSL_ERROR_WANT_WRITE: + { + if(timeout == 0) { - if(!selectRead(_fd, timeout)) - { - throw ConnectTimeoutException(__FILE__, __LINE__); - } - break; + return IceInternal::NeedWrite; } - case SSL_ERROR_WANT_WRITE: + if(!selectWrite(_fd, timeout)) { - if(!selectWrite(_fd, timeout)) - { - throw ConnectTimeoutException(__FILE__, __LINE__); - } - break; + throw ConnectTimeoutException(__FILE__, __LINE__); } - case SSL_ERROR_SYSCALL: + break; + } + case SSL_ERROR_SYSCALL: + { + if(ret == -1) { - if(ret == -1) + if(IceInternal::interrupted()) { - if(IceInternal::interrupted()) - { - break; - } - - if(IceInternal::wouldBlock()) + break; + } + + if(IceInternal::wouldBlock()) + { + if(SSL_want_read(_ssl)) { - if(SSL_want_read(_ssl)) + if(timeout == 0) { - if(!selectRead(_fd, timeout)) - { - throw ConnectTimeoutException(__FILE__, __LINE__); - } + return IceInternal::NeedRead; } - else if(SSL_want_write(_ssl)) + if(!selectRead(_fd, timeout)) { - if(!selectWrite(_fd, timeout)) - { - throw ConnectTimeoutException(__FILE__, __LINE__); - } + throw ConnectTimeoutException(__FILE__, __LINE__); } - - break; } - - if(IceInternal::connectionLost()) + else if(SSL_want_write(_ssl)) { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; + if(timeout == 0) + { + return IceInternal::NeedWrite; + } + if(!selectWrite(_fd, timeout)) + { + throw ConnectTimeoutException(__FILE__, __LINE__); + } } + + break; } - - if(ret == 0) + + if(IceInternal::connectionLost()) { ConnectionLostException ex(__FILE__, __LINE__); - ex.error = 0; + ex.error = IceInternal::getSocketErrno(); throw ex; } - - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; } - case SSL_ERROR_SSL: + + if(ret == 0) { - struct sockaddr_in remoteAddr; - string desc; - if(IceInternal::fdToRemoteAddress(_fd, remoteAddr)) - { - desc = IceInternal::addrToString(remoteAddr); - } - ProtocolException ex(__FILE__, __LINE__); - ex.reason = "SSL error occurred for new incoming connection:\nremote address = " + desc + "\n" + - _instance->sslErrors(); + ConnectionLostException ex(__FILE__, __LINE__); + ex.error = 0; throw ex; } + + SocketException ex(__FILE__, __LINE__); + ex.error = IceInternal::getSocketErrno(); + throw ex; + } + case SSL_ERROR_SSL: + { + struct sockaddr_in remoteAddr; + string desc; + if(IceInternal::fdToRemoteAddress(_fd, remoteAddr)) + { + desc = IceInternal::addrToString(remoteAddr); } + ProtocolException ex(__FILE__, __LINE__); + ex.reason = "SSL error occurred for new incoming connection:\nremote address = " + desc + "\n" + + _instance->sslErrors(); + throw ex; } - while(!SSL_is_init_finished(_ssl)); - - _instance->verifyPeer(_ssl, _fd, "", _adapterName, true); + } + } + while(!SSL_is_init_finished(_ssl)); + + _instance->verifyPeer(_ssl, _fd, "", _adapterName, _incoming); - if(_instance->networkTraceLevel() >= 1) + if(_instance->networkTraceLevel() >= 1) + { + Trace out(_logger, _instance->networkTraceCategory()); + if(_incoming) { - Trace out(_logger, _instance->networkTraceCategory()); out << "accepted ssl connection\n" << IceInternal::fdToString(_fd); } - - if(_instance->securityTraceLevel() >= 1) + else { - _instance->traceConnection(_ssl, true); + out << "ssl connection established\n" << IceInternal::fdToString(_fd); } } + + if(_instance->securityTraceLevel() >= 1) + { + _instance->traceConnection(_ssl, _incoming); + } + + return IceInternal::Finished; } void @@ -554,7 +589,7 @@ IceSSL::TransceiverI::getConnectionInfo() const return populateConnectionInfo(_ssl, _fd, _adapterName, _incoming); } -IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SSL* ssl, SOCKET fd, +IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SSL* ssl, SOCKET fd, bool connected, bool incoming, const string& adapterName) : _instance(instance), _logger(instance->communicator()->getLogger()), @@ -563,6 +598,7 @@ IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SSL* ssl, SOCKET _fd(fd), _adapterName(adapterName), _incoming(incoming), + _state(connected ? StateConnected : StateNeedConnect), _desc(IceInternal::fdToString(fd)) { #ifdef _WIN32 |