diff options
Diffstat (limited to 'cpp/src/IceSSL/SChannelEngine.cpp')
-rw-r--r-- | cpp/src/IceSSL/SChannelEngine.cpp | 309 |
1 files changed, 287 insertions, 22 deletions
diff --git a/cpp/src/IceSSL/SChannelEngine.cpp b/cpp/src/IceSSL/SChannelEngine.cpp index 37bd41b2ff0..25c62a81ff0 100644 --- a/cpp/src/IceSSL/SChannelEngine.cpp +++ b/cpp/src/IceSSL/SChannelEngine.cpp @@ -7,8 +7,10 @@ // // ********************************************************************** -#include <IceSSL/SSLEngine.h> +#include <IceSSL/SChannelEngine.h> +#include <IceSSL/SChannelTransceiverI.h> #include <IceSSL/Plugin.h> +#include <IceSSL/Util.h> #include <Ice/LocalException.h> #include <Ice/Logger.h> @@ -19,19 +21,275 @@ #include <IceUtil/FileUtil.h> #include <Ice/UUID.h> +#include <wincrypt.h> + using namespace std; using namespace Ice; using namespace IceUtil; using namespace IceUtilInternal; using namespace IceSSL; -#ifdef ICE_USE_SCHANNEL - -Shared* IceSSL::upCast(IceSSL::SChannelEngine* p) { return p; } +Shared* SChannel::upCast(SChannel::SSLEngine* p) +{ + return p; +} namespace { +void +addMatchingCertificates(HCERTSTORE source, HCERTSTORE target, DWORD findType, const void* findParam) +{ + PCCERT_CONTEXT next = 0; + do + { + if((next = CertFindCertificateInStore(source, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, + findType, findParam, next))) + { + if(!CertAddCertificateContextToStore(target, next, CERT_STORE_ADD_ALWAYS, 0)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: error adding certificate to store:\n" + IceUtilInternal::lastErrorToString()); + } + } + } + while(next); +} + +vector<PCCERT_CONTEXT> +findCertificates(const string& location, const string& name, const string& value, vector<HCERTSTORE>& stores) +{ + DWORD storeLoc; + if(location == "CurrentUser") + { + storeLoc = CERT_SYSTEM_STORE_CURRENT_USER; + } + else + { + storeLoc = CERT_SYSTEM_STORE_LOCAL_MACHINE; + } + + HCERTSTORE store = CertOpenStore(CERT_STORE_PROV_SYSTEM, 0, 0, storeLoc, Ice::stringToWstring(name).c_str()); + if(!store) + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: failed to open certificate store `" + name + + "':\n" + IceUtilInternal::lastErrorToString()); + } + + // + // Start with all of the certificates in the collection and filter as necessary. + // + // - If the value is "*", return all certificates. + // - Otherwise, search using key:value pairs. The following keys are supported: + // + // Issuer + // IssuerDN + // Serial + // Subject + // SubjectDN + // SubjectKeyId + // Thumbprint + // + // A value must be enclosed in single or double quotes if it contains whitespace. + // + HCERTSTORE tmpStore = 0; + try + { + if(value != "*") + { + if(value.find(':', 0) == string::npos) + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: no key in `" + value + "'"); + } + size_t start = 0; + size_t pos; + while((pos = value.find(':', start)) != string::npos) + { + string field = IceUtilInternal::toUpper(IceUtilInternal::trim(value.substr(start, pos - start))); + if(field != "SUBJECT" && field != "SUBJECTDN" && field != "ISSUER" && field != "ISSUERDN" && + field != "THUMBPRINT" && field != "SUBJECTKEYID" && field != "SERIAL") + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: unknown key in `" + value + "'"); + } + + start = pos + 1; + while(start < value.size() && (value[start] == ' ' || value[start] == '\t')) + { + ++start; + } + + if(start == value.size()) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: missing argument in `" + value + "'"); + } + + string arg; + if(value[start] == '"' || value[start] == '\'') + { + size_t end = start; + ++end; + while(end < value.size()) + { + if(value[end] == value[start] && value[end - 1] != '\\') + { + break; + } + ++end; + } + if(end == value.size() || value[end] != value[start]) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: unmatched quote in `" + value + "'"); + } + ++start; + arg = value.substr(start, end - start); + start = end + 1; + } + else + { + size_t end = value.find_first_of(" \t", start); + if(end == string::npos) + { + arg = value.substr(start); + start = value.size(); + } + else + { + arg = value.substr(start, end - start); + start = end + 1; + } + } + + tmpStore = CertOpenStore(CERT_STORE_PROV_MEMORY, 0, 0, 0, 0); + if(!tmpStore) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: error adding certificate to store:\n" + IceUtilInternal::lastErrorToString()); + } + + if(field == "SUBJECT" || field == "ISSUER") + { + const wstring argW = Ice::stringToWstring(arg); + DWORD findType = field == "SUBJECT" ? CERT_FIND_SUBJECT_STR : CERT_FIND_ISSUER_STR; + addMatchingCertificates(store, tmpStore, findType, argW.c_str()); + } + else if(field == "SUBJECTDN" || field == "ISSUERDN") + { + const wstring argW = Ice::stringToWstring(arg); + DWORD flags[] = { + CERT_OID_NAME_STR, + CERT_OID_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + CERT_OID_NAME_STR | CERT_NAME_STR_FORCE_UTF8_DIR_STR_FLAG, + CERT_OID_NAME_STR | CERT_NAME_STR_FORCE_UTF8_DIR_STR_FLAG | CERT_NAME_STR_REVERSE_FLAG + }; + for(size_t i = 0; i < sizeof(flags) / sizeof(DWORD); ++i) + { + DWORD length = 0; + if(!CertStrToNameW(X509_ASN_ENCODING, argW.c_str(), flags[i], 0, 0, &length, 0)) + { + throw PluginInitializationException( + __FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for `IceSSL.FindCert' property:\n" + + IceUtilInternal::lastErrorToString()); + } + + vector<BYTE> buffer(length); + if(!CertStrToNameW(X509_ASN_ENCODING, argW.c_str(), flags[i], 0, &buffer[0], &length, 0)) + { + throw PluginInitializationException( + __FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for `IceSSL.FindCert' property:\n" + + IceUtilInternal::lastErrorToString()); + } + + CERT_NAME_BLOB name = { length, &buffer[0] }; + + DWORD findType = field == "SUBJECTDN" ? CERT_FIND_SUBJECT_NAME : CERT_FIND_ISSUER_NAME; + addMatchingCertificates(store, tmpStore, findType, &name); + } + } + else if(field == "THUMBPRINT" || field == "SUBJECTKEYID") + { + vector<BYTE> buffer; + if(!parseBytes(arg, buffer)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: invalid `IceSSL.FindCert' property: can't decode the value"); + } + + CRYPT_HASH_BLOB hash = { static_cast<DWORD>(buffer.size()), &buffer[0] }; + DWORD findType = field == "THUMBPRINT" ? CERT_FIND_HASH : CERT_FIND_KEY_IDENTIFIER; + addMatchingCertificates(store, tmpStore, findType, &hash); + } + else if(field == "SERIAL") + { + vector<BYTE> buffer; + if(!parseBytes(arg, buffer)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for `IceSSL.FindCert' property"); + } + + CRYPT_INTEGER_BLOB serial = { static_cast<DWORD>(buffer.size()), &buffer[0] }; + PCCERT_CONTEXT next = 0; + do + { + if((next = CertFindCertificateInStore(store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, + CERT_FIND_ANY, 0, next))) + { + if(CertCompareIntegerBlob(&serial, &next->pCertInfo->SerialNumber)) + { + if(!CertAddCertificateContextToStore(tmpStore, next, CERT_STORE_ADD_ALWAYS, 0)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: error adding certificate to store:\n" + + IceUtilInternal::lastErrorToString()); + } + } + } + } + while(next); + } + CertCloseStore(store, 0); + store = tmpStore; + } + } + } + catch(...) + { + if(store && store != tmpStore) + { + CertCloseStore(store, 0); + } + + if(tmpStore) + { + CertCloseStore(tmpStore, 0); + tmpStore = 0; + } + throw; + } + + vector<PCCERT_CONTEXT> certs; + if(store) + { + PCCERT_CONTEXT next = 0; + do + { + if((next = CertFindCertificateInStore(store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, 0, + next))) + { + certs.push_back(next); + } + } + while(next); + stores.push_back(store); + } + return certs; +} + + #if defined(__MINGW32__) || (defined(_MSC_VER) && (_MSC_VER <= 1500)) // // CERT_CHAIN_ENGINE_CONFIG struct in mingw headers doesn't include @@ -187,16 +445,15 @@ algorithmId(const string& name) } -SChannelEngine::SChannelEngine(const CommunicatorPtr& communicator) : - SSLEngine(communicator), - _initialized(false), +SChannel::SSLEngine::SSLEngine(const CommunicatorPtr& communicator) : + IceSSL::SSLEngine(communicator), _rootStore(0), _chainEngine(0) { } void -SChannelEngine::initialize() +SChannel::SSLEngine::initialize() { Mutex::Lock lock(_mutex); if(_initialized) @@ -204,7 +461,7 @@ SChannelEngine::initialize() return; } - SSLEngine::initialize(); + IceSSL::SSLEngine::initialize(); const string prefix = "IceSSL."; const PropertiesPtr properties = communicator()->getProperties(); @@ -643,7 +900,7 @@ SChannelEngine::initialize() } string -SChannelEngine::getCipherName(ALG_ID cipher) const +SChannel::SSLEngine::getCipherName(ALG_ID cipher) const { switch(cipher) { @@ -678,15 +935,8 @@ SChannelEngine::getCipherName(ALG_ID cipher) const } } -bool -SChannelEngine::initialized() const -{ - Mutex::Lock lock(_mutex); - return _initialized; -} - CredHandle -SChannelEngine::newCredentialsHandle(bool incoming) +SChannel::SSLEngine::newCredentialsHandle(bool incoming) { SCHANNEL_CRED cred; memset(&cred, 0, sizeof(cred)); @@ -745,13 +995,13 @@ SChannelEngine::newCredentialsHandle(bool incoming) } HCERTCHAINENGINE -SChannelEngine::chainEngine() const +SChannel::SSLEngine::chainEngine() const { return _chainEngine; } void -SChannelEngine::parseCiphers(const std::string& ciphers) +SChannel::SSLEngine::parseCiphers(const std::string& ciphers) { vector<string> tokens; splitString(ciphers, " \t", tokens); @@ -766,7 +1016,7 @@ SChannelEngine::parseCiphers(const std::string& ciphers) } void -SChannelEngine::destroy() +SChannel::SSLEngine::destroy() { if(_chainEngine && _chainEngine != HCCE_CURRENT_USER && _chainEngine != HCCE_LOCAL_MACHINE) { @@ -809,4 +1059,19 @@ SChannelEngine::destroy() CertCloseStore(*i, 0); } } -#endif + +void +SChannel::SSLEngine::verifyPeer(const string& address, const NativeConnectionInfoPtr& info, const string& desc) +{ + verifyPeerCertName(address, info); + IceSSL::SSLEngine::verifyPeer(address, info, desc); +} + +IceInternal::TransceiverPtr +SChannel::SSLEngine::createTransceiver(const InstancePtr& instance, + const IceInternal::TransceiverPtr& delegate, + const string& hostOrAdapterName, + bool incoming) +{ + return new SChannel::TransceiverI(instance, delegate, hostOrAdapterName, incoming); +} |