diff options
Diffstat (limited to 'cpp/src')
-rw-r--r-- | cpp/src/IceSSL/ClientContext.cpp | 6 | ||||
-rw-r--r-- | cpp/src/IceSSL/ClientContext.h | 4 | ||||
-rw-r--r-- | cpp/src/IceSSL/Context.cpp | 9 | ||||
-rw-r--r-- | cpp/src/IceSSL/Context.h | 8 | ||||
-rw-r--r-- | cpp/src/IceSSL/OpenSSLPluginI.cpp | 6 | ||||
-rw-r--r-- | cpp/src/IceSSL/OpenSSLPluginI.h | 2 | ||||
-rw-r--r-- | cpp/src/IceSSL/ServerContext.cpp | 8 | ||||
-rw-r--r-- | cpp/src/IceSSL/ServerContext.h | 4 | ||||
-rw-r--r-- | cpp/src/IceSSL/SslAcceptor.cpp | 2 | ||||
-rw-r--r-- | cpp/src/IceSSL/SslClientTransceiver.cpp | 7 | ||||
-rw-r--r-- | cpp/src/IceSSL/SslConnector.cpp | 6 | ||||
-rw-r--r-- | cpp/src/IceSSL/SslServerTransceiver.cpp | 7 | ||||
-rw-r--r-- | cpp/src/IceSSL/SslTransceiver.cpp | 51 | ||||
-rw-r--r-- | cpp/src/IceSSL/SslTransceiver.h | 5 |
14 files changed, 98 insertions, 27 deletions
diff --git a/cpp/src/IceSSL/ClientContext.cpp b/cpp/src/IceSSL/ClientContext.cpp index 95331f60d1c..da7ea42bcec 100644 --- a/cpp/src/IceSSL/ClientContext.cpp +++ b/cpp/src/IceSSL/ClientContext.cpp @@ -54,7 +54,7 @@ IceSSL::ClientContext::configure(const GeneralConfig& generalConfig, } IceSSL::SslTransceiverPtr -IceSSL::ClientContext::createTransceiver(int socket, const OpenSSLPluginIPtr& plugin) +IceSSL::ClientContext::createTransceiver(int socket, const OpenSSLPluginIPtr& plugin, int timeout) { if(_sslContext == 0) { @@ -66,7 +66,7 @@ IceSSL::ClientContext::createTransceiver(int socket, const OpenSSLPluginIPtr& pl SSL* ssl = createSSLConnection(socket); SslTransceiverPtr transceiver = new SslClientTransceiver(plugin, socket, _certificateVerifier, ssl); - transceiverSetup(transceiver); + transceiverSetup(transceiver, timeout); return transceiver; } @@ -80,6 +80,6 @@ IceSSL::ClientContext::ClientContext(const TraceLevelsPtr& traceLevels, const Lo _dsaPrivateKeyProperty = "IceSSL.Client.Overrides.DSA.PrivateKey"; _dsaPublicKeyProperty = "IceSSL.Client.Overrides.DSA.Certificate"; _caCertificateProperty = "IceSSL.Client.Overrides.CACertificate"; - _handshakeTimeoutProperty = "IceSSL.Client.Handshake.ReadTimeout"; _passphraseRetriesProperty = "IceSSL.Client.Passphrase.Retries"; + _connectionHandshakeRetries = "IceSSL.Client.Handshake.Retries"; } diff --git a/cpp/src/IceSSL/ClientContext.h b/cpp/src/IceSSL/ClientContext.h index 3e1308d52de..c093adfc51c 100644 --- a/cpp/src/IceSSL/ClientContext.h +++ b/cpp/src/IceSSL/ClientContext.h @@ -28,8 +28,8 @@ public: const CertificateAuthority&, const BaseCertificates&); - // Takes a socket fd as the first parameter. - virtual SslTransceiverPtr createTransceiver(int, const OpenSSLPluginIPtr&); + // Takes a socket fd as the first parameter, and the initial handshake timeout as the final. + virtual SslTransceiverPtr createTransceiver(int, const OpenSSLPluginIPtr&, int); protected: diff --git a/cpp/src/IceSSL/Context.cpp b/cpp/src/IceSSL/Context.cpp index 0d8da2ba191..611af358085 100644 --- a/cpp/src/IceSSL/Context.cpp +++ b/cpp/src/IceSSL/Context.cpp @@ -610,12 +610,13 @@ IceSSL::Context::createSSLConnection(int socket) } void -IceSSL::Context::transceiverSetup(const SslTransceiverPtr& transceiver) +IceSSL::Context::transceiverSetup(const SslTransceiverPtr& transceiver, int timeout) { - // Set the Post-Handshake Read timeout // This timeout is implemented once on the first read after hanshake. - int handshakeReadTimeout = _properties->getPropertyAsIntWithDefault(_handshakeTimeoutProperty, 5000); - transceiver->setHandshakeReadTimeout(handshakeReadTimeout); + transceiver->setHandshakeReadTimeout(timeout < 5000 ? 5000 : timeout); + + int retries = _properties->getPropertyAsIntWithDefault(_connectionHandshakeRetries, 10); + transceiver->setHandshakeRetries(retries); } void diff --git a/cpp/src/IceSSL/Context.h b/cpp/src/IceSSL/Context.h index cb392e83353..81c1ac94330 100644 --- a/cpp/src/IceSSL/Context.h +++ b/cpp/src/IceSSL/Context.h @@ -55,8 +55,8 @@ public: const CertificateAuthority&, const BaseCertificates&); - // Takes a socket fd as the first parameter. - virtual SslTransceiverPtr createTransceiver(int, const OpenSSLPluginIPtr&) = 0; + // Takes a socket fd as the first parameter, and the initial handshake timeout as the final. + virtual SslTransceiverPtr createTransceiver(int, const OpenSSLPluginIPtr&, int) = 0; protected: @@ -83,7 +83,7 @@ protected: SSL* createSSLConnection(int); - void transceiverSetup(const SslTransceiverPtr&); + void transceiverSetup(const SslTransceiverPtr&, int); void setCipherList(const std::string&); @@ -98,9 +98,9 @@ protected: std::string _dsaPrivateKeyProperty; std::string _dsaPublicKeyProperty; std::string _caCertificateProperty; - std::string _handshakeTimeoutProperty; std::string _passphraseRetriesProperty; std::string _maxPassphraseRetriesDefault; + std::string _connectionHandshakeRetries; CertificateVerifierPtr _certificateVerifier; diff --git a/cpp/src/IceSSL/OpenSSLPluginI.cpp b/cpp/src/IceSSL/OpenSSLPluginI.cpp index 3bfbba484bd..74096530822 100644 --- a/cpp/src/IceSSL/OpenSSLPluginI.cpp +++ b/cpp/src/IceSSL/OpenSSLPluginI.cpp @@ -192,7 +192,7 @@ IceSSL::OpenSSLPluginI::~OpenSSLPluginI() } IceSSL::SslTransceiverPtr -IceSSL::OpenSSLPluginI::createTransceiver(ContextType connectionType, int socket) +IceSSL::OpenSSLPluginI::createTransceiver(ContextType connectionType, int socket, int timeout) { IceUtil::RecMutex::Lock sync(_configMutex); @@ -215,11 +215,11 @@ IceSSL::OpenSSLPluginI::createTransceiver(ContextType connectionType, int socket if(connectionType == Client) { - transceiver = _clientContext.createTransceiver(socket, this); + transceiver = _clientContext.createTransceiver(socket, this, timeout); } else if(connectionType == Server) { - transceiver = _serverContext.createTransceiver(socket, this); + transceiver = _serverContext.createTransceiver(socket, this, timeout); } return transceiver; diff --git a/cpp/src/IceSSL/OpenSSLPluginI.h b/cpp/src/IceSSL/OpenSSLPluginI.h index 46c6f6a8b26..e9097bb5e9e 100644 --- a/cpp/src/IceSSL/OpenSSLPluginI.h +++ b/cpp/src/IceSSL/OpenSSLPluginI.h @@ -54,7 +54,7 @@ public: virtual ~OpenSSLPluginI(); - virtual SslTransceiverPtr createTransceiver(ContextType, int); + virtual SslTransceiverPtr createTransceiver(ContextType, int, int); virtual bool isConfigured(ContextType); virtual void configure(); diff --git a/cpp/src/IceSSL/ServerContext.cpp b/cpp/src/IceSSL/ServerContext.cpp index ae025a0af64..9d3ac620fcf 100644 --- a/cpp/src/IceSSL/ServerContext.cpp +++ b/cpp/src/IceSSL/ServerContext.cpp @@ -75,7 +75,7 @@ IceSSL::ServerContext::configure(const GeneralConfig& generalConfig, } IceSSL::SslTransceiverPtr -IceSSL::ServerContext::createTransceiver(int socket, const OpenSSLPluginIPtr& plugin) +IceSSL::ServerContext::createTransceiver(int socket, const OpenSSLPluginIPtr& plugin, int timeout) { if(_sslContext == 0) { @@ -87,7 +87,7 @@ IceSSL::ServerContext::createTransceiver(int socket, const OpenSSLPluginIPtr& pl SSL* ssl = createSSLConnection(socket); SslTransceiverPtr transceiver = new SslServerTransceiver(plugin, socket, _certificateVerifier, ssl); - transceiverSetup(transceiver); + transceiverSetup(transceiver, timeout); return transceiver; } @@ -105,8 +105,8 @@ IceSSL::ServerContext::ServerContext(const TraceLevelsPtr& traceLevels, const Lo _dsaPrivateKeyProperty = "IceSSL.Server.Overrides.DSA.PrivateKey"; _dsaPublicKeyProperty = "IceSSL.Server.Overrides.DSA.Certificate"; _caCertificateProperty = "IceSSL.Server.Overrides.CACertificate"; - _handshakeTimeoutProperty = "IceSSL.Server.Handshake.ReadTimeout"; - _passphraseRetriesProperty = "IceSSL.Client.Passphrase.Retries"; + _passphraseRetriesProperty = "IceSSL.Server.Passphrase.Retries"; + _connectionHandshakeRetries = "IceSSL.Server.Handshake.Retries"; } void diff --git a/cpp/src/IceSSL/ServerContext.h b/cpp/src/IceSSL/ServerContext.h index 5e802d84dbf..a4dec84af1a 100644 --- a/cpp/src/IceSSL/ServerContext.h +++ b/cpp/src/IceSSL/ServerContext.h @@ -28,8 +28,8 @@ public: const CertificateAuthority&, const BaseCertificates&); - // Takes a socket fd as the first parameter. - virtual SslTransceiverPtr createTransceiver(int, const OpenSSLPluginIPtr&); + // Takes a socket fd as the first parameter, and the initial handshake timeout as the final. + virtual SslTransceiverPtr createTransceiver(int, const OpenSSLPluginIPtr&, int); protected: diff --git a/cpp/src/IceSSL/SslAcceptor.cpp b/cpp/src/IceSSL/SslAcceptor.cpp index dc0b6665041..f0ff59b2d81 100644 --- a/cpp/src/IceSSL/SslAcceptor.cpp +++ b/cpp/src/IceSSL/SslAcceptor.cpp @@ -76,7 +76,7 @@ IceSSL::SslAcceptor::accept(int timeout) out << "accepted ssl connection\n" << fdToString(fd); } - return _plugin->createTransceiver(IceSSL::Server, fd); + return _plugin->createTransceiver(IceSSL::Server, fd, timeout); } string diff --git a/cpp/src/IceSSL/SslClientTransceiver.cpp b/cpp/src/IceSSL/SslClientTransceiver.cpp index 045b402a6f7..23730d8a242 100644 --- a/cpp/src/IceSSL/SslClientTransceiver.cpp +++ b/cpp/src/IceSSL/SslClientTransceiver.cpp @@ -211,6 +211,13 @@ IceSSL::SslClientTransceiver::handshake(int timeout) _initWantWrite = 0; } + if(_traceLevels->security >= IceSSL::SECURITY_PROTOCOL) + { + Trace out(_logger, _traceLevels->securityCat); + out << "Performing handshake.\n"; + out << fdToString(SSL_get_fd(_sslConnection)); + } + int result = connect(); switch(getLastError()) diff --git a/cpp/src/IceSSL/SslConnector.cpp b/cpp/src/IceSSL/SslConnector.cpp index 42f57f351c0..e05c56704b9 100644 --- a/cpp/src/IceSSL/SslConnector.cpp +++ b/cpp/src/IceSSL/SslConnector.cpp @@ -48,7 +48,11 @@ IceSSL::SslConnector::connect(int timeout) logger->trace(traceLevels->networkCat, s.str()); } - return _plugin->createTransceiver(IceSSL::Client, fd); + SslTransceiverPtr transceiver = _plugin->createTransceiver(IceSSL::Client, fd, timeout); + + transceiver->forceHandshake(); + + return transceiver; } string diff --git a/cpp/src/IceSSL/SslServerTransceiver.cpp b/cpp/src/IceSSL/SslServerTransceiver.cpp index 91c44421278..12d2feb403e 100644 --- a/cpp/src/IceSSL/SslServerTransceiver.cpp +++ b/cpp/src/IceSSL/SslServerTransceiver.cpp @@ -196,6 +196,13 @@ IceSSL::SslServerTransceiver::handshake(int timeout) } } + if(_traceLevels->security >= IceSSL::SECURITY_PROTOCOL) + { + Trace out(_logger, _traceLevels->securityCat); + out << "Performing handshake.\n"; + out << fdToString(SSL_get_fd(_sslConnection)); + } + int result = accept(); // We're doing an Accept and we don't get a retry on the socket. diff --git a/cpp/src/IceSSL/SslTransceiver.cpp b/cpp/src/IceSSL/SslTransceiver.cpp index 30ae5ba54b2..334fc554295 100644 --- a/cpp/src/IceSSL/SslTransceiver.cpp +++ b/cpp/src/IceSSL/SslTransceiver.cpp @@ -273,11 +273,60 @@ IceSSL::SslTransceiver::toString() const } void +IceSSL::SslTransceiver::forceHandshake() +{ + int retryCount = 0; + + while(retryCount < _handshakeRetries) + { + ++retryCount; + + try + { + if(handshake(_handshakeReadTimeout) > 0) + { + // Handshake complete. + break; + } + } + catch(TimeoutException) + { + // Do nothing. + } + } + + if(retryCount >= _handshakeRetries) + { + if(_traceLevels->security >= IceSSL::SECURITY_WARNINGS) + { + Trace out(_logger, _traceLevels->securityCat); + out << "Handshake retry maximum reached.\n"; + out << fdToString(SSL_get_fd(_sslConnection)); + } + + // If the handshake fails, the connection failed. + ConnectFailedException ex(__FILE__, __LINE__); +#ifdef _WIN32 + ex.error = WSAECONNREFUSED; +#else + ex.error = ECONNREFUSED; +#endif + throw ex; + } +} + +void IceSSL::SslTransceiver::setHandshakeReadTimeout(int timeout) { _handshakeReadTimeout = timeout; } +void +IceSSL::SslTransceiver::setHandshakeRetries(int retries) +{ + _handshakeRetries = retries; +} + IceSSL::SslTransceiverPtr IceSSL::SslTransceiver::getTransceiver(SSL* sslPtr) { @@ -958,8 +1007,8 @@ IceSSL::SslTransceiver::SslTransceiver(const OpenSSLPluginIPtr& plugin, _initWantRead = 0; _initWantWrite = 0; - // None configured, default to indicated timeout _handshakeReadTimeout = 0; + _handshakeRetries = 0; // Set up the SSL to be able to refer back to our connection object. addTransceiver(_sslConnection, this); diff --git a/cpp/src/IceSSL/SslTransceiver.h b/cpp/src/IceSSL/SslTransceiver.h index 7af6f5f2f37..160798759e8 100644 --- a/cpp/src/IceSSL/SslTransceiver.h +++ b/cpp/src/IceSSL/SslTransceiver.h @@ -138,8 +138,10 @@ public: virtual void read(IceInternal::Buffer&, int); virtual std::string toString() const; + void forceHandshake(); virtual int handshake(int timeout = 0) = 0; - void setHandshakeReadTimeout(int timeout); + void setHandshakeReadTimeout(int); + void setHandshakeRetries(int); static SslTransceiverPtr getTransceiver(SSL*); // Callback from OpenSSL for purposes of certificate verification @@ -195,6 +197,7 @@ protected: int _initWantRead; int _initWantWrite; int _handshakeReadTimeout; + int _handshakeRetries; int _readTimeout; ConnectPhase _phase; |