summaryrefslogtreecommitdiff
path: root/cpp/src/IceSSL/TransceiverI.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'cpp/src/IceSSL/TransceiverI.cpp')
-rw-r--r--cpp/src/IceSSL/TransceiverI.cpp264
1 files changed, 150 insertions, 114 deletions
diff --git a/cpp/src/IceSSL/TransceiverI.cpp b/cpp/src/IceSSL/TransceiverI.cpp
index d4b86a2a427..d930c3d8cff 100644
--- a/cpp/src/IceSSL/TransceiverI.cpp
+++ b/cpp/src/IceSSL/TransceiverI.cpp
@@ -54,6 +54,11 @@ IceSSL::TransceiverI::close()
void
IceSSL::TransceiverI::shutdownWrite()
{
+ if(_state < StateConnected)
+ {
+ return;
+ }
+
if(_instance->networkTraceLevel() >= 2)
{
Trace out(_logger, _instance->networkTraceCategory());
@@ -69,6 +74,11 @@ IceSSL::TransceiverI::shutdownWrite()
void
IceSSL::TransceiverI::shutdownReadWrite()
{
+ if(_state < StateConnected)
+ {
+ return;
+ }
+
if(_instance->networkTraceLevel() >= 2)
{
Trace out(_logger, _instance->networkTraceCategory());
@@ -81,7 +91,7 @@ IceSSL::TransceiverI::shutdownReadWrite()
IceInternal::shutdownSocketReadWrite(_fd);
}
-void
+bool
IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout)
{
// Its impossible for the packetSize to be more than an Int.
@@ -102,12 +112,11 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout)
ERR_clear_error(); // Clear any spurious errors.
assert(_fd != INVALID_SOCKET);
int ret, err;
- bool wantRead, wantWrite;
+ bool wantWrite;
{
IceUtil::Mutex::Lock sync(_sslMutex);
ret = SSL_write(_ssl, reinterpret_cast<const void*>(&*buf.i), packetSize);
err = SSL_get_error(_ssl, ret);
- wantRead = SSL_want_read(_ssl);
wantWrite = SSL_want_write(_ssl);
}
@@ -126,14 +135,15 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout)
}
case SSL_ERROR_WANT_READ:
{
- if(!selectRead(_fd, timeout))
- {
- throw TimeoutException(__FILE__, __LINE__);
- }
- continue;
+ assert(false);
+ break;
}
case SSL_ERROR_WANT_WRITE:
{
+ if(timeout == 0)
+ {
+ return false;
+ }
if(!selectWrite(_fd, timeout))
{
throw TimeoutException(__FILE__, __LINE__);
@@ -157,15 +167,12 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout)
if(IceInternal::wouldBlock())
{
- if(wantRead)
+ if(wantWrite)
{
- if(!selectRead(_fd, timeout))
+ if(timeout == 0)
{
- throw TimeoutException(__FILE__, __LINE__);
+ return false;
}
- }
- else if(wantWrite)
- {
if(!selectWrite(_fd, timeout))
{
throw TimeoutException(__FILE__, __LINE__);
@@ -221,9 +228,11 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf, int timeout)
packetSize = static_cast<int>(buf.b.end() - buf.i);
}
}
+
+ return true;
}
-void
+bool
IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
{
// It's impossible for the packetSize to be more than an Int.
@@ -234,13 +243,12 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
ERR_clear_error(); // Clear any spurious errors.
assert(_fd != INVALID_SOCKET);
int ret, err;
- bool wantRead, wantWrite;
+ bool wantRead;
{
IceUtil::Mutex::Lock sync(_sslMutex);
ret = SSL_read(_ssl, reinterpret_cast<void*>(&*buf.i), packetSize);
err = SSL_get_error(_ssl, ret);
wantRead = SSL_want_read(_ssl);
- wantWrite = SSL_want_write(_ssl);
}
if(ret <= 0)
@@ -269,6 +277,10 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
}
case SSL_ERROR_WANT_READ:
{
+ if(timeout == 0)
+ {
+ return false;
+ }
if(!selectRead(_fd, timeout))
{
throw TimeoutException(__FILE__, __LINE__);
@@ -277,11 +289,8 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
}
case SSL_ERROR_WANT_WRITE:
{
- if(!selectWrite(_fd, timeout))
- {
- throw TimeoutException(__FILE__, __LINE__);
- }
- continue;
+ assert(false);
+ break;
}
case SSL_ERROR_SYSCALL:
{
@@ -302,19 +311,15 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
{
if(wantRead)
{
- if(!selectRead(_fd, timeout))
+ if(timeout == 0)
{
- throw TimeoutException(__FILE__, __LINE__);
+ return false;
}
- }
- else if(wantWrite)
- {
- if(!selectWrite(_fd, timeout))
+ if(!selectRead(_fd, timeout))
{
throw TimeoutException(__FILE__, __LINE__);
}
}
-
continue;
}
@@ -361,8 +366,8 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
//if(ERR_GET_LIB(e) == ERR_LIB_SSL && ERR_GET_REASON(e) == SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC)
//
unsigned long e = ERR_peek_error();
- if(ERR_GET_LIB(e) == ERR_LIB_SSL &&
- strcmp(ERR_reason_error_string(e), "decryption failed or bad record mac") == 0)
+ const char* estr = ERR_GET_LIB(e) == ERR_LIB_SSL ? ERR_reason_error_string(e) : 0;
+ if(estr && strcmp(estr, "decryption failed or bad record mac") == 0)
{
ConnectionLostException ex(__FILE__, __LINE__);
ex.error = 0;
@@ -396,6 +401,8 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, int timeout)
packetSize = static_cast<int>(buf.b.end() - buf.i);
}
}
+
+ return true;
}
string
@@ -410,129 +417,157 @@ IceSSL::TransceiverI::toString() const
return _desc;
}
-void
+IceInternal::SocketStatus
IceSSL::TransceiverI::initialize(int timeout)
{
- if(_incoming)
+ if(_state == StateNeedConnect && timeout == 0)
+ {
+ _state = StateConnectPending;
+ return IceInternal::NeedConnect;
+ }
+ else if(_state <= StateConnectPending)
+ {
+ IceInternal::doFinishConnect(_fd, timeout);
+ _state = StateConnected;
+ _desc = IceInternal::fdToString(_fd);
+ }
+ assert(_state == StateConnected);
+
+ do
{
- // TODO: The timeout is 0 when called by the thread pool.
- // Make this configurable?
- if(timeout == 0)
+ //
+ // Only one thread calls initialize(), so synchronization is not necessary here.
+ //
+ int ret = _incoming ? SSL_accept(_ssl) : SSL_connect(_ssl);
+ switch(SSL_get_error(_ssl, ret))
{
- timeout = -1;
+ case SSL_ERROR_NONE:
+ assert(SSL_is_init_finished(_ssl));
+ break;
+ case SSL_ERROR_ZERO_RETURN:
+ {
+ ConnectionLostException ex(__FILE__, __LINE__);
+ ex.error = IceInternal::getSocketErrno();
+ throw ex;
}
-
- do
+ case SSL_ERROR_WANT_READ:
{
- //
- // Only one thread calls initialize(), so synchronization is not necessary here.
- //
- int ret = SSL_accept(_ssl);
- switch(SSL_get_error(_ssl, ret))
+ if(timeout == 0)
{
- case SSL_ERROR_NONE:
- assert(SSL_is_init_finished(_ssl));
- break;
- case SSL_ERROR_ZERO_RETURN:
+ return IceInternal::NeedRead;
+ }
+ if(!selectRead(_fd, timeout))
{
- ConnectionLostException ex(__FILE__, __LINE__);
- ex.error = IceInternal::getSocketErrno();
- throw ex;
+ throw ConnectTimeoutException(__FILE__, __LINE__);
}
- case SSL_ERROR_WANT_READ:
+ break;
+ }
+ case SSL_ERROR_WANT_WRITE:
+ {
+ if(timeout == 0)
{
- if(!selectRead(_fd, timeout))
- {
- throw ConnectTimeoutException(__FILE__, __LINE__);
- }
- break;
+ return IceInternal::NeedWrite;
}
- case SSL_ERROR_WANT_WRITE:
+ if(!selectWrite(_fd, timeout))
{
- if(!selectWrite(_fd, timeout))
- {
- throw ConnectTimeoutException(__FILE__, __LINE__);
- }
- break;
+ throw ConnectTimeoutException(__FILE__, __LINE__);
}
- case SSL_ERROR_SYSCALL:
+ break;
+ }
+ case SSL_ERROR_SYSCALL:
+ {
+ if(ret == -1)
{
- if(ret == -1)
+ if(IceInternal::interrupted())
{
- if(IceInternal::interrupted())
- {
- break;
- }
-
- if(IceInternal::wouldBlock())
+ break;
+ }
+
+ if(IceInternal::wouldBlock())
+ {
+ if(SSL_want_read(_ssl))
{
- if(SSL_want_read(_ssl))
+ if(timeout == 0)
{
- if(!selectRead(_fd, timeout))
- {
- throw ConnectTimeoutException(__FILE__, __LINE__);
- }
+ return IceInternal::NeedRead;
}
- else if(SSL_want_write(_ssl))
+ if(!selectRead(_fd, timeout))
{
- if(!selectWrite(_fd, timeout))
- {
- throw ConnectTimeoutException(__FILE__, __LINE__);
- }
+ throw ConnectTimeoutException(__FILE__, __LINE__);
}
-
- break;
}
-
- if(IceInternal::connectionLost())
+ else if(SSL_want_write(_ssl))
{
- ConnectionLostException ex(__FILE__, __LINE__);
- ex.error = IceInternal::getSocketErrno();
- throw ex;
+ if(timeout == 0)
+ {
+ return IceInternal::NeedWrite;
+ }
+ if(!selectWrite(_fd, timeout))
+ {
+ throw ConnectTimeoutException(__FILE__, __LINE__);
+ }
}
+
+ break;
}
-
- if(ret == 0)
+
+ if(IceInternal::connectionLost())
{
ConnectionLostException ex(__FILE__, __LINE__);
- ex.error = 0;
+ ex.error = IceInternal::getSocketErrno();
throw ex;
}
-
- SocketException ex(__FILE__, __LINE__);
- ex.error = IceInternal::getSocketErrno();
- throw ex;
}
- case SSL_ERROR_SSL:
+
+ if(ret == 0)
{
- struct sockaddr_in remoteAddr;
- string desc;
- if(IceInternal::fdToRemoteAddress(_fd, remoteAddr))
- {
- desc = IceInternal::addrToString(remoteAddr);
- }
- ProtocolException ex(__FILE__, __LINE__);
- ex.reason = "SSL error occurred for new incoming connection:\nremote address = " + desc + "\n" +
- _instance->sslErrors();
+ ConnectionLostException ex(__FILE__, __LINE__);
+ ex.error = 0;
throw ex;
}
+
+ SocketException ex(__FILE__, __LINE__);
+ ex.error = IceInternal::getSocketErrno();
+ throw ex;
+ }
+ case SSL_ERROR_SSL:
+ {
+ struct sockaddr_in remoteAddr;
+ string desc;
+ if(IceInternal::fdToRemoteAddress(_fd, remoteAddr))
+ {
+ desc = IceInternal::addrToString(remoteAddr);
}
+ ProtocolException ex(__FILE__, __LINE__);
+ ex.reason = "SSL error occurred for new incoming connection:\nremote address = " + desc + "\n" +
+ _instance->sslErrors();
+ throw ex;
}
- while(!SSL_is_init_finished(_ssl));
-
- _instance->verifyPeer(_ssl, _fd, "", _adapterName, true);
+ }
+ }
+ while(!SSL_is_init_finished(_ssl));
+
+ _instance->verifyPeer(_ssl, _fd, "", _adapterName, _incoming);
- if(_instance->networkTraceLevel() >= 1)
+ if(_instance->networkTraceLevel() >= 1)
+ {
+ Trace out(_logger, _instance->networkTraceCategory());
+ if(_incoming)
{
- Trace out(_logger, _instance->networkTraceCategory());
out << "accepted ssl connection\n" << IceInternal::fdToString(_fd);
}
-
- if(_instance->securityTraceLevel() >= 1)
+ else
{
- _instance->traceConnection(_ssl, true);
+ out << "ssl connection established\n" << IceInternal::fdToString(_fd);
}
}
+
+ if(_instance->securityTraceLevel() >= 1)
+ {
+ _instance->traceConnection(_ssl, _incoming);
+ }
+
+ return IceInternal::Finished;
}
void
@@ -554,7 +589,7 @@ IceSSL::TransceiverI::getConnectionInfo() const
return populateConnectionInfo(_ssl, _fd, _adapterName, _incoming);
}
-IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SSL* ssl, SOCKET fd,
+IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SSL* ssl, SOCKET fd, bool connected,
bool incoming, const string& adapterName) :
_instance(instance),
_logger(instance->communicator()->getLogger()),
@@ -563,6 +598,7 @@ IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SSL* ssl, SOCKET
_fd(fd),
_adapterName(adapterName),
_incoming(incoming),
+ _state(connected ? StateConnected : StateNeedConnect),
_desc(IceInternal::fdToString(fd))
{
#ifdef _WIN32