diff options
author | Jose <jose@zeroc.com> | 2014-08-16 00:19:02 +0200 |
---|---|---|
committer | Jose <jose@zeroc.com> | 2014-08-16 00:19:02 +0200 |
commit | bdc771e3b1e1ccaa7db2510f3872c284bf292e8c (patch) | |
tree | 821dcdf747be1ba5b8c44b86b59437ceaaabfef0 /cpp/src/IceSSL/SChannelEngine.cpp | |
parent | Make ant builder watch test slice files (diff) | |
download | ice-bdc771e3b1e1ccaa7db2510f3872c284bf292e8c.tar.bz2 ice-bdc771e3b1e1ccaa7db2510f3872c284bf292e8c.tar.xz ice-bdc771e3b1e1ccaa7db2510f3872c284bf292e8c.zip |
ICE-5592 - IceSSL.FindCert (SChannel Windows C++ impl)
Diffstat (limited to 'cpp/src/IceSSL/SChannelEngine.cpp')
-rw-r--r-- | cpp/src/IceSSL/SChannelEngine.cpp | 442 |
1 files changed, 401 insertions, 41 deletions
diff --git a/cpp/src/IceSSL/SChannelEngine.cpp b/cpp/src/IceSSL/SChannelEngine.cpp index dee5965d975..c5d99dbb344 100644 --- a/cpp/src/IceSSL/SChannelEngine.cpp +++ b/cpp/src/IceSSL/SChannelEngine.cpp @@ -21,15 +21,40 @@ using namespace std; using namespace Ice; using namespace IceUtil; +using namespace IceUtilInternal; using namespace IceSSL; #ifdef ICE_USE_SCHANNEL -IceUtil::Shared* IceSSL::upCast(IceSSL::SChannelEngine* p) { return p; } +Shared* IceSSL::upCast(IceSSL::SChannelEngine* p) { return p; } namespace { +# ifdef __MINGW32__ +// +// CERT_CHAIN_ENGINE_CONFIG struct in mingw headers doesn't include +// new members added in Windows 7, we add our ouwn definition and +// then cast it to CERT_CHAIN_ENGINE_CONFIG this works because the +// linked libraries include the new version. +// +struct CertChainEngineConfig +{ + DWORD cbSize; + HCERTSTORE hRestrictedRoot; + HCERTSTORE hRestrictedTrust; + HCERTSTORE hRestrictedOther; + DWORD cAdditionalStore; + HCERTSTORE *rghAdditionalStore; + DWORD dwFlags; + DWORD dwUrlRetrievalTimeout; + DWORD MaximumCachedCertificates; + DWORD CycleDetectionModulus; + HCERTSTORE hExclusiveRoot; + HCERTSTORE hExclusiveTrustedPeople; +}; +# endif + void addCertificateToStore(const string& file, HCERTSTORE store, PCCERT_CONTEXT* cert = 0) { @@ -47,7 +72,7 @@ addCertificateToStore(const string& file, HCERTSTORE store, PCCERT_CONTEXT* cert // assert(GetLastError() != ERROR_MORE_DATA); throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error decoding certificate:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error decoding certificate:\n" + lastErrorToString()); } if(!CertAddEncodedCertificateToStore(store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, &outBuffer[0], @@ -56,7 +81,7 @@ addCertificateToStore(const string& file, HCERTSTORE store, PCCERT_CONTEXT* cert if(GetLastError() != static_cast<DWORD>(CRYPT_E_EXISTS)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error decoding certificate:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error decoding certificate:\n" + lastErrorToString()); } } } @@ -99,7 +124,7 @@ parseProtocols(const StringSeq& protocols) return v; } -const ALG_ID supportedChipers[] = {CALG_3DES, CALG_AES_128, CALG_AES_256, CALG_DES, CALG_RC2, CALG_RC4}; +const ALG_ID supportedChipers[] = { CALG_3DES, CALG_AES_128, CALG_AES_256, CALG_DES, CALG_RC2, CALG_RC4 }; const int supportedChipersSize = sizeof(supportedChipers)/sizeof(ALG_ID); ALG_ID @@ -132,6 +157,312 @@ algorithmId(const string& name) return 0; } +// +// Parse a string of the form "location.name" into two parts. +// +void +parseStore(const string& prop, const string& store, DWORD& loc, string& sname) +{ + size_t pos = store.find('.'); + if(pos == string::npos) + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: property `" + prop + "' has invalid format"); + } + + const string sloc = toUpper(store.substr(0, pos)); + if(sloc == "CURRENTUSER") + { + loc = CERT_SYSTEM_STORE_CURRENT_USER; + } + else if(sloc == "LOCALMACHINE") + { + loc = CERT_SYSTEM_STORE_LOCAL_MACHINE; + } + else + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: unknown store location `" + sloc + "' in " + prop); + } + + sname = store.substr(pos + 1); + if(sname.empty()) + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: invalid store name in " + prop); + } +} + + +bool +parseBytes(const string& arg, vector<BYTE>& buffer) +{ + string v = toUpper(arg); + + // + // Check for any invalid characters. + // + size_t pos = v.find_first_not_of(" :0123456789ABCDEF"); + if(pos != string::npos) + { + return false; + } + + // + // Remove any separator characters. + // + ostringstream s; + for(string::const_iterator i = v.begin(); i != v.end(); ++i) + { + if(*i == ' ' || *i == ':') + { + continue; + } + s << *i; + } + v = s.str(); + + // + // Convert the bytes. + // + for(size_t i = 0, length = v.size(); i + 2 <= length;) + { + buffer.push_back(static_cast<BYTE>(strtol(v.substr(i, 2).c_str(), 0, 16))); + i += 2; + } + return true; +} + +void +addMatchingCertificates(HCERTSTORE source, HCERTSTORE target, DWORD findType, const void* findParam) +{ + PCCERT_CONTEXT next = 0; + do + { + if((next = CertFindCertificateInStore(source, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, + findType, findParam, next))) + { + if(!CertAddCertificateContextToStore(target, next, CERT_STORE_ADD_ALWAYS, 0)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: error adding certificate to store:\n" + lastErrorToString()); + } + } + } + while(next); +} + +vector<PCCERT_CONTEXT> +findCertificates(const string& prop, const string& storeSpec, const string& value, vector<HCERTSTORE>& stores) +{ + DWORD storeLoc = 0; + string storeName; + parseStore(prop, storeSpec, storeLoc, storeName); + + HCERTSTORE store = CertOpenStore(CERT_STORE_PROV_SYSTEM, 0, 0, storeLoc, stringToWstring(storeName).c_str()); + if(!store) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: failure while opening store specified by " + prop + ":\n" + lastErrorToString()); + } + + // + // Start with all of the certificates in the collection and filter as necessary. + // + // - If the value is "*", return all certificates. + // - Otherwise, search using key:value pairs. The following keys are supported: + // + // Issuer + // IssuerDN + // Serial + // Subject + // SubjectDN + // SubjectKeyId + // Thumbprint + // + // A value must be enclosed in single or double quotes if it contains whitespace. + // + + HCERTSTORE tmpStore = 0; + try + { + if(value != "*") + { + size_t start = 0; + size_t pos; + while((pos = value.find(':', start)) != string::npos) + { + string field = toUpper(trim(value.substr(start, pos - start))); + if(field != "SUBJECT" && field != "SUBJECTDN" && field != "ISSUER" && field != "ISSUERDN" && + field != "THUMBPRINT" && field != "SUBJECTKEYID" && field != "SERIAL") + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: unknown key in `" + value + "'"); + } + + start = pos + 1; + while(start < value.size() && (value[start] == ' ' || value[start] == '\t')) + { + ++start; + } + + if(start == value.size()) + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: missing argument in `" + value + "'"); + } + + string arg; + if(value[start] == '"' || value[start] == '\'') + { + size_t end = start; + ++end; + while(end < value.size()) + { + if(value[end] == value[start] && value[end - 1] != '\\') + { + break; + } + ++end; + } + if(end == value.size() || value[end] != value[start]) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: unmatched quote in `" + value + "'"); + } + ++start; + arg = value.substr(start, end - start); + start = end + 1; + } + else + { + size_t end = value.find_first_of(" \t", start); + if(end == string::npos) + { + arg = value.substr(start); + start = value.size(); + } + else + { + arg = value.substr(start, end - start); + start = end + 1; + } + } + + DWORD findType = 0; + + tmpStore = CertOpenStore(CERT_STORE_PROV_MEMORY, 0, 0, 0, 0); + if(!tmpStore) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: error adding certificate to store:\n" + lastErrorToString()); + } + + if(field == "SUBJECT" || field == "ISSUER") + { + const wstring argW = stringToWstring(arg); + DWORD findType = field == "SUBJECT" ? CERT_FIND_SUBJECT_STR : CERT_FIND_ISSUER_STR; + addMatchingCertificates(store, tmpStore, findType, argW.c_str()); + } + else if(field == "SUBJECTDN" || field == "ISSUERDN") + { + const wstring argW = stringToWstring(arg); + DWORD length = 0; + if(!CertStrToNameW(X509_ASN_ENCODING, argW.c_str(), CERT_OID_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + 0, 0, &length, 0)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for property `" + prop + "'\n" + + lastErrorToString()); + } + + vector<BYTE> buffer(length); + if(!CertStrToNameW(X509_ASN_ENCODING, argW.c_str(), CERT_OID_NAME_STR | CERT_NAME_STR_REVERSE_FLAG, + 0, &buffer[0], &length, 0)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for property `" + prop + "'\n" + + lastErrorToString()); + } + + CERT_NAME_BLOB name = { length, &buffer[0] }; + DWORD findType = field == "SUBJECTDN" ? CERT_FIND_SUBJECT_NAME : CERT_FIND_ISSUER_NAME; + addMatchingCertificates(store, tmpStore, findType, &name); + } + else if(field == "THUMBPRINT" || field == "SUBJECTKEYID") + { + vector<BYTE> buffer; + if(!parseBytes(arg, buffer)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for property `" + prop + "'"); + } + + CRYPT_HASH_BLOB hash = { buffer.size(), &buffer[0] }; + DWORD findType = field == "THUMBPRINT" ? CERT_FIND_HASH : CERT_FIND_KEY_IDENTIFIER; + addMatchingCertificates(store, tmpStore, findType, &hash); + } + else if(field == "SERIAL") + { + vector<BYTE> buffer; + if(!parseBytes(arg, buffer)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: invalid value `" + value + "' for property `" + prop + "'"); + } + + CRYPT_INTEGER_BLOB serial = { buffer.size(), &buffer[0] }; + PCCERT_CONTEXT next = 0; + do + { + if((next = CertFindCertificateInStore(store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, + CERT_FIND_ANY, 0, next))) + { + if(CertCompareIntegerBlob(&serial, &next->pCertInfo->SerialNumber)) + { + if(!CertAddCertificateContextToStore(tmpStore, next, CERT_STORE_ADD_ALWAYS, 0)) + { + throw PluginInitializationException(__FILE__, __LINE__, + "IceSSL: error adding certificate to store:\n" + lastErrorToString()); + } + } + } + } + while(next); + } + CertCloseStore(store, 0); + store = tmpStore; + } + } + } + catch(...) + { + if(store && store != tmpStore) + { + CertCloseStore(store, 0); + } + + if(tmpStore) + { + CertCloseStore(tmpStore, 0); + tmpStore = 0; + } + throw; + } + + vector<PCCERT_CONTEXT> certs; + if(store) + { + PCCERT_CONTEXT next = 0; + do + { + if((next = CertFindCertificateInStore(store, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 0, CERT_FIND_ANY, 0, + next))) + { + certs.push_back(next); + } + } + while(next); + stores.push_back(store); + } + return certs; +} + } SChannelEngine::SChannelEngine(const CommunicatorPtr& communicator) : @@ -145,7 +476,7 @@ SChannelEngine::SChannelEngine(const CommunicatorPtr& communicator) : void SChannelEngine::initialize() { - IceUtil::Mutex::Lock lock(_mutex); + Mutex::Lock lock(_mutex); if(_initialized) { return; @@ -215,7 +546,7 @@ SChannelEngine::initialize() if(!_rootStore) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error creating in memory certificate store:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error creating in memory certificate store:\n" + lastErrorToString()); } if(!checkPath(caFile, defaultDir, false)) @@ -256,7 +587,7 @@ SChannelEngine::initialize() #endif { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error creating certificate chain engine:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error creating certificate chain engine:\n" + lastErrorToString()); } } else @@ -282,14 +613,14 @@ SChannelEngine::initialize() if(!certFile.empty()) { vector<string> certFiles; - if(!IceUtilInternal::splitString(certFile, IceUtilInternal::pathsep, certFiles) || certFiles.size() > 2) + if(!splitString(certFile, IceUtilInternal::pathsep, certFiles) || certFiles.size() > 2) { throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: invalid value for " + prefix + "CertFile:\n" + certFile); } vector<string> keyFiles; - if(!IceUtilInternal::splitString(keyFile, IceUtilInternal::pathsep, keyFiles) || keyFiles.size() > 2) + if(!splitString(keyFile, IceUtilInternal::pathsep, keyFiles) || keyFiles.size() > 2) { throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: invalid value for " + prefix + "KeyFile:\n" + keyFile); @@ -336,7 +667,7 @@ SChannelEngine::initialize() if(!cert) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: certificate error:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: certificate error:\n" + lastErrorToString()); } _certs.push_back(cert); continue; @@ -347,7 +678,7 @@ SChannelEngine::initialize() if(err != CRYPT_E_BAD_ENCODE) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error decoding certificate:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error decoding certificate:\n" + lastErrorToString()); } // @@ -373,7 +704,7 @@ SChannelEngine::initialize() &outBuffer[0], &outLength, 0, 0)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error decoding key:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error decoding key:\n" + lastErrorToString()); } PCRYPT_PRIVATE_KEY_INFO keyInfo = 0; @@ -387,7 +718,7 @@ SChannelEngine::initialize() CRYPT_DECODE_ALLOC_FLAG, 0, &keyInfo, &decodedLength)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error decoding key:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error decoding key:\n" + lastErrorToString()); } // @@ -402,7 +733,7 @@ SChannelEngine::initialize() // // Create a new RSA key set to store our key // - const wstring keySetName = stringToWstring(IceUtil::generateUUID()); + const wstring keySetName = stringToWstring(generateUUID()); HCRYPTPROV cryptProv = 0; DWORD contextFlags = (keySet == "MachineKeySet") ? CRYPT_MACHINE_KEYSET | CRYPT_NEWKEYSET : @@ -411,7 +742,7 @@ SChannelEngine::initialize() if(!CryptAcquireContextW(&cryptProv, keySetName.c_str(), MS_DEF_PROV_W, PROV_RSA_FULL, contextFlags)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error acquiring cryptographic context:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error acquiring cryptographic context:\n" + lastErrorToString()); } // @@ -422,7 +753,7 @@ SChannelEngine::initialize() CRYPT_DECODE_ALLOC_FLAG, 0, &key, &outLength)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error decoding key:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error decoding key:\n" + lastErrorToString()); } LocalFree(keyInfo); keyInfo = 0; @@ -433,7 +764,7 @@ SChannelEngine::initialize() if(!CryptImportKey(cryptProv, key, outLength, 0, 0, &hKey)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error importing key:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error importing key:\n" + lastErrorToString()); } LocalFree(key); key = 0; @@ -448,7 +779,7 @@ SChannelEngine::initialize() if(!store) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error creating certificate store:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error creating certificate store:\n" + lastErrorToString()); } addCertificateToStore(certFile, store, &cert); @@ -466,7 +797,7 @@ SChannelEngine::initialize() if(!CertSetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, 0, &keyProvInfo)) { throw PluginInitializationException(__FILE__, __LINE__, - "IceSSL: error seting certificate property:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: error seting certificate property:\n" + lastErrorToString()); } _certs.push_back(cert); @@ -501,6 +832,31 @@ SChannelEngine::initialize() throw; } } + + _allCerts.insert(_allCerts.end(), _certs.begin(), _certs.end()); + } + + const string findPrefix = prefix + "FindCert."; + map<string, string> certProps = properties->getPropertiesForPrefix(findPrefix); + if(!certProps.empty()) + { + for(map<string, string>::const_iterator i = certProps.begin(); i != certProps.end(); ++i) + { + const string name = i->first; + const string val = i->second; + + if(!val.empty()) + { + string storeSpec = name.substr(findPrefix.size()); + vector<PCCERT_CONTEXT> certs = findCertificates(name, storeSpec, val, _stores); + _allCerts.insert(_allCerts.end(), certs.begin(), certs.end()); + } + } + + if(_allCerts.empty()) + { + throw PluginInitializationException(__FILE__, __LINE__, "IceSSL: no certificates found"); + } } _initialized = true; } @@ -544,7 +900,7 @@ SChannelEngine::getCipherName(ALG_ID cipher) const bool SChannelEngine::initialized() const { - IceUtil::Mutex::Lock lock(_mutex); + Mutex::Lock lock(_mutex); return _initialized; } @@ -555,10 +911,10 @@ SChannelEngine::newCredentialsHandle(bool incoming) memset(&cred, 0, sizeof(cred)); cred.dwVersion = SCHANNEL_CRED_VERSION; - if(!_certs.empty()) + if(!_allCerts.empty()) { - cred.cCreds = static_cast<DWORD>(_certs.size()); - cred.paCred = &_certs[0]; + cred.cCreds = static_cast<DWORD>(_allCerts.size()); + cred.paCred = &_allCerts[0]; } cred.grbitEnabledProtocols = _protocols; @@ -593,7 +949,7 @@ SChannelEngine::newCredentialsHandle(bool incoming) if(err != SEC_E_OK) { throw SecurityException(__FILE__, __LINE__, - "IceSSL: failed to acquire credentials handle:\n" + IceUtilInternal::lastErrorToString()); + "IceSSL: failed to acquire credentials handle:\n" + lastErrorToString()); } return credHandle; } @@ -608,7 +964,7 @@ void SChannelEngine::parseCiphers(const std::string& ciphers) { vector<string> tokens; - IceUtilInternal::splitString(ciphers, " \t", tokens); + splitString(ciphers, " \t", tokens); for(vector<string>::const_iterator i = tokens.begin(); i != tokens.end(); ++i) { ALG_ID id = algorithmId(*i); @@ -632,32 +988,36 @@ SChannelEngine::destroy() CertCloseStore(_rootStore, 0); } - for(vector<PCCERT_CONTEXT>::const_iterator i = _certs.begin(); i != _certs.end(); ++i) + for(vector<PCCERT_CONTEXT>::const_iterator i = _allCerts.begin(); i != _allCerts.end(); ++i) { PCCERT_CONTEXT cert = *i; - UniquePtr<CRYPT_KEY_PROV_INFO> keyProvInfo; - DWORD size = 0; - // - // Retrieve the certificate CERT_KEY_PROV_INFO_PROP_ID property, we use the CRYPT_KEY_PROV_INFO - // data to then remove the key set associated with the certificate. + // Only remove the keysets we create. // - if(CertGetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, 0, &size)) + if(find(_certs.begin(), _certs.end(), cert) != _certs.end()) { - vector<char> buf(size); - if(CertGetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, &buf[0], &size)) + // + // Retrieve the certificate CERT_KEY_PROV_INFO_PROP_ID property, we use the CRYPT_KEY_PROV_INFO + // data to then remove the key set associated with the certificate. + // + DWORD length = 0; + if(CertGetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, 0, &length)) { - CRYPT_KEY_PROV_INFO* keyProvInfo = reinterpret_cast<CRYPT_KEY_PROV_INFO*>(&buf[0]); - HCRYPTPROV cryptProv = 0; - if(CryptAcquireContextW(&cryptProv, keyProvInfo->pwszContainerName, keyProvInfo->pwszProvName, - keyProvInfo->dwProvType, 0)) + vector<char> buf(length); + if(CertGetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, &buf[0], &length)) { - CryptAcquireContextW(&cryptProv, keyProvInfo->pwszContainerName, keyProvInfo->pwszProvName, - keyProvInfo->dwProvType, CRYPT_DELETEKEYSET); + CRYPT_KEY_PROV_INFO* keyProvInfo = reinterpret_cast<CRYPT_KEY_PROV_INFO*>(&buf[0]); + HCRYPTPROV cryptProv = 0; + if(CryptAcquireContextW(&cryptProv, keyProvInfo->pwszContainerName, keyProvInfo->pwszProvName, + keyProvInfo->dwProvType, 0)) + { + CryptAcquireContextW(&cryptProv, keyProvInfo->pwszContainerName, keyProvInfo->pwszProvName, + keyProvInfo->dwProvType, CRYPT_DELETEKEYSET); + } } + CertFreeCertificateContext(cert); } - CertFreeCertificateContext(cert); } } |