diff options
Diffstat (limited to 'cpp/src/IceSSL/SChannelTransceiverI.cpp')
-rw-r--r-- | cpp/src/IceSSL/SChannelTransceiverI.cpp | 501 |
1 files changed, 81 insertions, 420 deletions
diff --git a/cpp/src/IceSSL/SChannelTransceiverI.cpp b/cpp/src/IceSSL/SChannelTransceiverI.cpp index 8b3867d571e..108a130af12 100644 --- a/cpp/src/IceSSL/SChannelTransceiverI.cpp +++ b/cpp/src/IceSSL/SChannelTransceiverI.cpp @@ -208,26 +208,9 @@ getSecBufferWithType(const SecBufferDesc& desc, ULONG bufferType) IceInternal::NativeInfoPtr IceSSL::TransceiverI::getNativeInfo() { - return this; + return _stream; } -#ifdef ICE_USE_IOCP -IceInternal::AsyncInfo* -IceSSL::TransceiverI::getAsyncInfo(IceInternal::SocketOperation status) -{ - switch(status) - { - case IceInternal::SocketOperationRead: - return &_read; - case IceInternal::SocketOperationWrite: - return &_write; - default: - assert(false); - return 0; - } -} -#endif - IceInternal::SocketOperation IceSSL::TransceiverI::sslHandshake() { @@ -248,15 +231,18 @@ IceSSL::TransceiverI::sslHandshake() SECURITY_STATUS err = SEC_E_OK; DWORD ctxFlags = 0; - - while(true) + if(_state == StateHandshakeNotStarted) { - if(_state == StateConnected) + _readBuffer.b.resize(2048); + _readBuffer.i = _readBuffer.b.begin(); + _credentials = _engine->newCredentialsHandle(_incoming); + _credentialsInitialized = true; + + if(!_incoming) { - assert(!_incoming); SecBuffer outBuffer = { 0, SECBUFFER_TOKEN, 0 }; SecBufferDesc outBufferDesc = { SECBUFFER_VERSION, 1, &outBuffer }; - + err = InitializeSecurityContext(&_credentials, 0, const_cast<char *>(_host.c_str()), flags, 0, 0, 0, 0, &_ssl, &outBufferDesc, &ctxFlags, 0); _sslInitialized = true; @@ -265,7 +251,7 @@ IceSSL::TransceiverI::sslHandshake() throw SecurityException(__FILE__, __LINE__, "IceSSL: handshake failure:\n" + IceUtilInternal::lastErrorToString()); } - + // // Copy the data to the write buffer // @@ -273,10 +259,17 @@ IceSSL::TransceiverI::sslHandshake() _writeBuffer.i = _writeBuffer.b.begin(); memcpy(_writeBuffer.i, outBuffer.pvBuffer, outBuffer.cbBuffer); FreeContextBuffer(outBuffer.pvBuffer); - + _state = StateHandshakeWriteContinue; } + else + { + _state = StateHandshakeReadContinue; + } + } + while(true) + { if(_state == StateHandshakeReadContinue) { // If read buffer is empty, try to read some data. @@ -637,64 +630,13 @@ IceSSL::TransceiverI::encryptMessage(IceInternal::Buffer& buffer) IceInternal::SocketOperation IceSSL::TransceiverI::initialize(IceInternal::Buffer& readBuffer, IceInternal::Buffer& writeBuffer, bool& hasMoreData) { - if(_state == StateNeedConnect) - { - _state = StateConnectPending; - return IceInternal::SocketOperationConnect; - } - else if(_state <= StateConnectPending) - { - IceInternal::doFinishConnectAsync(_fd, _write); - - _desc = IceInternal::fdToString(_fd, _proxy, _addr, true); - - if(_proxy) - { - // - // Prepare the read & write buffers in advance. - // - _proxy->beginWriteConnectRequest(_addr, writeBuffer); - _proxy->beginReadConnectRequestResponse(readBuffer); - - // - // Return SocketOperationWrite to indicate we need to start a write. - // - _state = StateProxyConnectRequest; // Send proxy connect request - return IceInternal::SocketOperationWrite; - } - - _state = StateConnected; - } - else if(_state == StateProxyConnectRequest) - { - // - // Write completed. - // - _proxy->endWriteConnectRequest(writeBuffer); - _state = StateProxyConnectRequestPending; // Wait for proxy response - return IceInternal::SocketOperationRead; - } - else if(_state == StateProxyConnectRequestPending) - { - // - // Read completed. - // - _proxy->endReadConnectRequestResponse(readBuffer); - _state = StateConnected; - } - - assert(_state >= StateConnected && _state <= StateHandshakeWriteContinue); - - if(!_credentialsInitialized) + IceInternal::SocketOperation op = _stream->connect(readBuffer, writeBuffer); + if(op != IceInternal::SocketOperationNone) { - _readBuffer.b.resize(2048); - _readBuffer.i = _readBuffer.b.begin(); - - _credentials = _engine->newCredentialsHandle(_incoming); - _credentialsInitialized = true; + return op; } - - IceInternal::SocketOperation op = sslHandshake(); + + op = sslHandshake(); if(op != IceInternal::SocketOperationNone) { return op; @@ -776,7 +718,7 @@ IceSSL::TransceiverI::initialize(IceInternal::Buffer& readBuffer, IceInternal::B } } } - _engine->verifyPeer(_fd, _host, getNativeConnectionInfo()); + _engine->verifyPeer(_stream->fd(), _host, getNativeConnectionInfo()); _state = StateHandshakeComplete; if(_instance->engine()->securityTraceLevel() >= 1) @@ -805,7 +747,7 @@ IceSSL::TransceiverI::initialize(IceInternal::Buffer& readBuffer, IceInternal::B << "\nkey exchange = " << sslKeyExchangeAlgorithm << "\nprotocol = " << sslProtocolName << "\n"; } - out << IceInternal::fdToString(_fd); + out << toString(); } hasMoreData = !_readUnprocessed.b.empty() || _readBuffer.i != _readBuffer.b.begin(); return IceInternal::SocketOperationNone; @@ -832,22 +774,17 @@ IceSSL::TransceiverI::close() FreeCredentialsHandle(&_credentials); } - assert(_fd != INVALID_SOCKET); - try - { - IceInternal::closeSocket(_fd); - _fd = INVALID_SOCKET; - } - catch(const SocketException&) - { - _fd = INVALID_SOCKET; - throw; - } + _stream->close(); } IceInternal::SocketOperation IceSSL::TransceiverI::write(IceInternal::Buffer& buf) { + if(!_stream->isConnected()) + { + return _stream->write(buf); + } + if(buf.i == buf.b.end()) { return IceInternal::SocketOperationNone; @@ -878,6 +815,11 @@ IceSSL::TransceiverI::write(IceInternal::Buffer& buf) IceInternal::SocketOperation IceSSL::TransceiverI::read(IceInternal::Buffer& buf, bool& hasMoreData) { + if(!_stream->isConnected()) + { + return _stream->read(buf); + } + if(buf.i == buf.b.end()) { return IceInternal::SocketOperationNone; @@ -913,173 +855,74 @@ IceSSL::TransceiverI::read(IceInternal::Buffer& buf, bool& hasMoreData) bool IceSSL::TransceiverI::startWrite(IceInternal::Buffer& buffer) { - if(_state == StateConnectPending) + if(!_stream->isConnected()) { - IceInternal::Address addr = _proxy ? _proxy->getAddress() : _addr; - doConnectAsync(_fd, addr, _sourceAddr, _write); - return false; + return _stream->startWrite(buffer); } - IceInternal::Buffer& buf = _state == StateProxyConnectRequest ? buffer : _writeBuffer; - if(_state == StateHandshakeComplete && _bufferedW == 0) { assert(_writeBuffer.i == _writeBuffer.b.end()); - _bufferedW = encryptMessage(buf); + _bufferedW = encryptMessage(buffer); } - assert(buf.i != buf.b.end()); - int packetSize = static_cast<int>(buf.b.end() - buf.i); - if(_maxSendPacketSize > 0 && packetSize > _maxSendPacketSize) - { - packetSize = _maxSendPacketSize; - } - assert(packetSize > 0); - _write.buf.len = static_cast<DWORD>(packetSize); - _write.buf.buf = reinterpret_cast<char*>(buf.i); - int err = WSASend(_fd, &_write.buf, 1, &_write.count, 0, &_write, NULL); - if(err == SOCKET_ERROR) - { - if(!IceInternal::wouldBlock()) - { - if(IceInternal::connectionLost()) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - else - { - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - } - } - return packetSize == static_cast<int>(buf.b.end() - buf.i); + return _stream->startWrite(_writeBuffer); } void IceSSL::TransceiverI::finishWrite(IceInternal::Buffer& buf) { - if(_state < StateConnected && _state != StateProxyConnectRequest) + if(!_stream->isConnected()) { + _stream->finishWrite(buf); return; } - if(static_cast<int>(_write.count) == SOCKET_ERROR) + _stream->finishWrite(_writeBuffer); + if(_writeBuffer.i != _writeBuffer.b.end()) { - WSASetLastError(_write.error); - if(IceInternal::connectionLost()) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - else - { - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } + return; // We're not finished yet with writing the write buffer. } - if(_state == StateProxyConnectRequest) + if(_state == StateHandshakeComplete) { - buf.i += _write.count; - } - else - { - _writeBuffer.i += _write.count; - if(_writeBuffer.i == _writeBuffer.b.end()) - { - buf.i += _bufferedW; - _bufferedW = 0; - } + buf.i += _bufferedW; + _bufferedW = 0; } } void IceSSL::TransceiverI::startRead(IceInternal::Buffer& buffer) { - IceInternal::Buffer& buf = _state == StateProxyConnectRequest ? buffer : _readBuffer; - - int packetSize = static_cast<int>(buf.b.end() - buf.i); - if(_maxReceivePacketSize > 0 && packetSize > _maxReceivePacketSize) - { - packetSize = _maxReceivePacketSize; - } - assert(!buf.b.empty() && buf.i != buf.b.end()); - - _read.buf.len = static_cast<DWORD>(packetSize); - _read.buf.buf = reinterpret_cast<char*>(buf.i); - - int err = WSARecv(_fd, &_read.buf, 1, &_read.count, &_read.flags, &_read, NULL); - if(err == SOCKET_ERROR) + if(!_stream->isConnected()) { - if(!IceInternal::wouldBlock()) - { - if(IceInternal::connectionLost()) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - else - { - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - } + _stream->startRead(buffer); + return; } + _stream->startRead(_readBuffer); } void IceSSL::TransceiverI::finishRead(IceInternal::Buffer& buf, bool& hasMoreData) { - if(static_cast<int>(_read.count) == SOCKET_ERROR) - { - WSASetLastError(_read.error); - if(IceInternal::connectionLost()) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - else - { - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - } - else if(_read.count == 0) + if(!_stream->isConnected()) { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = 0; - throw ex; + _stream->finishRead(buf); + return; } - if(_state == StateProxyConnectRequest) + _stream->finishRead(_readBuffer); + if(_state == StateHandshakeComplete) { - buf.i += _read.count; - } - else - { - _readBuffer.i += _read.count; - if(_state == StateHandshakeComplete) + size_t decrypted = decryptMessage(buf); + if(decrypted > 0) { - size_t decrypted = decryptMessage(buf); - if(decrypted > 0) - { - buf.i += decrypted; - hasMoreData = !_readUnprocessed.b.empty() || _readBuffer.i != _readBuffer.b.begin(); - } - else - { - hasMoreData = false; - } + buf.i += decrypted; + hasMoreData = !_readUnprocessed.b.empty() || _readBuffer.i != _readBuffer.b.begin(); + } + else + { + hasMoreData = false; } } } @@ -1094,7 +937,7 @@ IceSSL::TransceiverI::protocol() const string IceSSL::TransceiverI::toString() const { - return _desc; + return _stream->toString(); } string @@ -1118,113 +961,36 @@ IceSSL::TransceiverI::checkSendSize(const IceInternal::Buffer& buf, size_t messa } } -IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SOCKET fd, const IceInternal::NetworkProxyPtr& proxy, - const string& host, const IceInternal::Address& addr, - const IceInternal::Address& sourceAddr) : - IceInternal::NativeInfo(fd), +IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, + const IceInternal::StreamSocketPtr& stream, + const string& host, + const string& adapterName) : _instance(instance), _engine(SChannelEnginePtr::dynamicCast(instance->engine())), - _proxy(proxy), _host(host), - _addr(addr), - _sourceAddr(sourceAddr), - _incoming(false), - _state(StateNeedConnect), - _writeBuffer(0), - _bufferedW(0), - _readBuffer(0), - _readUnprocessed(0), - _sslInitialized(false), - _credentialsInitialized(false) -#ifdef ICE_USE_IOCP - , _read(IceInternal::SocketOperationRead), - _write(IceInternal::SocketOperationWrite) -#endif -{ - IceInternal::setBlock(fd, false); - IceInternal::setTcpBufSize(fd, _instance->properties(), _instance->logger()); - - // - // On Windows, limiting the buffer size is important to prevent - // poor throughput performances when transfering large amount of - // data. See Microsoft KB article KB823764. - // - _maxSendPacketSize = IceInternal::getSendBufferSize(_fd) / 2; - if(_maxSendPacketSize < 512) - { - _maxSendPacketSize = 0; - } - - _maxReceivePacketSize = IceInternal::getRecvBufferSize(_fd); - if(_maxReceivePacketSize < 512) - { - _maxReceivePacketSize = 0; - } - -#ifndef ICE_USE_IOCP - IceInternal::Address connectAddr = proxy ? proxy->getAddress() : addr; - if(IceInternal::doConnect(_fd, connectAddr, _sourceAddr)) - { - _state = StateConnected; - _desc = IceInternal::fdToString(_fd, _proxy, _addr, true); - } - else - { - _desc = IceInternal::fdToString(_fd, _proxy, _addr, true); - } -#endif -} - -IceSSL::TransceiverI::TransceiverI(const InstancePtr& instance, SOCKET fd, const string& adapterName) : - IceInternal::NativeInfo(fd), - _instance(instance), - _engine(SChannelEnginePtr::dynamicCast(instance->engine())), _adapterName(adapterName), - _incoming(true), - _state(StateHandshakeReadContinue), - _desc(IceInternal::fdToString(fd)), + _incoming(host.empty()), + _stream(stream), + _state(StateHandshakeNotStarted), _writeBuffer(0), _bufferedW(0), _readBuffer(0), _readUnprocessed(0), _sslInitialized(false), _credentialsInitialized(false) -#ifdef ICE_USE_IOCP - , _read(IceInternal::SocketOperationRead), - _write(IceInternal::SocketOperationWrite) -#endif { - IceInternal::setBlock(fd, false); - IceInternal::setTcpBufSize(fd, _instance->properties(), _instance->logger()); - - // - // On Windows, limiting the buffer size is important to prevent - // poor throughput performances when transfering large amount of - // data. See Microsoft KB article KB823764. - // - _maxSendPacketSize = IceInternal::getSendBufferSize(_fd) / 2; - if(_maxSendPacketSize < 512) - { - _maxSendPacketSize = 0; - } - - _maxReceivePacketSize = IceInternal::getRecvBufferSize(_fd); - if(_maxReceivePacketSize < 512) - { - _maxReceivePacketSize = 0; - } } IceSSL::TransceiverI::~TransceiverI() { - assert(_fd == INVALID_SOCKET); } NativeConnectionInfoPtr IceSSL::TransceiverI::getNativeConnectionInfo() const { NativeConnectionInfoPtr info = new NativeConnectionInfo(); - IceInternal::fdToAddressAndPort(_fd, info->localAddress, info->localPort, info->remoteAddress, info->remotePort); + IceInternal::fdToAddressAndPort(_stream->fd(), info->localAddress, info->localPort, info->remoteAddress, + info->remotePort); if(_sslInitialized) { @@ -1282,121 +1048,16 @@ IceSSL::TransceiverI::getNativeConnectionInfo() const bool IceSSL::TransceiverI::writeRaw(IceInternal::Buffer& buf) { - int packetSize = static_cast<int>(buf.b.end() - buf.i); -#ifdef ICE_USE_IOCP - // - // Limit packet size to avoid performance problems on WIN32 - // - if(_maxSendPacketSize > 0 && packetSize > _maxSendPacketSize) - { - packetSize = _maxSendPacketSize; - } -#endif - - while(buf.i != buf.b.end()) - { - assert(_fd != INVALID_SOCKET); - - int ret = ::send(_fd, reinterpret_cast<const char*>(buf.i), packetSize, 0); - if(ret == 0) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = 0; - throw ex; - } - - if(ret == SOCKET_ERROR) - { - if(IceInternal::interrupted()) - { - continue; - } - - if(IceInternal::noBuffers() && packetSize > 1024) - { - packetSize /= 2; - continue; - } - - if(IceInternal::wouldBlock()) - { - return false; - } - - if(IceInternal::connectionLost()) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - else - { - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - } - buf.i += ret; - if(packetSize > static_cast<int>(buf.b.end() - buf.i)) - { - packetSize = static_cast<int>(buf.b.end() - buf.i); - } - } - return true; + ssize_t ret = _stream->write(reinterpret_cast<const char*>(&*buf.i), buf.b.end() - buf.i); + buf.i += ret; + return buf.i == buf.b.end(); } bool IceSSL::TransceiverI::readRaw(IceInternal::Buffer& buf) { - assert(buf.i != buf.b.end()); - int packetSize = static_cast<int>(buf.b.end() - buf.i); - Byte* i = buf.i; - while(buf.i != buf.b.end()) - { - assert(_fd != INVALID_SOCKET); - ssize_t ret = ::recv(_fd, reinterpret_cast<char*>(buf.i), packetSize, 0); - - if(ret == 0) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = 0; - throw ex; - } - - if(ret == SOCKET_ERROR) - { - if(IceInternal::interrupted()) - { - continue; - } - - if(IceInternal::noBuffers() && packetSize > 1024) - { - packetSize /= 2; - continue; - } - - if(IceInternal::wouldBlock()) - { - return buf.i != i; - } - - if(IceInternal::connectionLost()) - { - ConnectionLostException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - else - { - SocketException ex(__FILE__, __LINE__); - ex.error = IceInternal::getSocketErrno(); - throw ex; - } - } - buf.i += ret; - packetSize = static_cast<int>(buf.b.end() - buf.i); - } - return buf.i != i; + ssize_t ret = _stream->read(reinterpret_cast<char*>(&*buf.i), buf.b.end() - buf.i); + buf.i += ret; + return ret > 0; } #endif |