From 53f55ff3ff65de2f9bd4410e7f245d0e26f29ca3 Mon Sep 17 00:00:00 2001 From: randomdan Date: Tue, 24 Dec 2013 18:09:08 +0000 Subject: Manage database connections on a per thread basis --- project2/sql/rdbmsDataSource.cpp | 162 +++++++++++++++++++++++++++++++-------- project2/sql/rdbmsDataSource.h | 38 +++++++-- project2/sql/sqlBulkLoad.cpp | 10 +-- project2/sql/sqlCache.cpp | 22 ++++-- project2/sql/sqlMergeTask.cpp | 6 +- project2/sql/sqlMergeTask.h | 2 +- project2/sql/sqlRows.cpp | 3 +- project2/sql/sqlTask.cpp | 3 +- project2/sql/sqlTest.cpp | 3 +- 9 files changed, 192 insertions(+), 57 deletions(-) diff --git a/project2/sql/rdbmsDataSource.cpp b/project2/sql/rdbmsDataSource.cpp index bf69dcd..7965a29 100644 --- a/project2/sql/rdbmsDataSource.cpp +++ b/project2/sql/rdbmsDataSource.cpp @@ -8,6 +8,7 @@ #include SimpleMessageException(UnknownConnectionProvider); +#define LOCK(l) std::lock_guard _lock##l(l) /// Specialized ElementLoader for instances of RdbmsDataSource; handles persistent DB connections class RdbmsDataSourceLoader : public ElementLoader::For { @@ -38,7 +39,8 @@ DECLARE_CUSTOM_LOADER("rdbmsdatasource", RdbmsDataSourceLoader); RdbmsDataSource::DBHosts RdbmsDataSource::dbhosts; RdbmsDataSource::FailedHosts RdbmsDataSource::failedhosts; -RdbmsDataSource::DSNSet RdbmsDataSource::changedDSNs; +RdbmsDataSource::ChangedDSNs RdbmsDataSource::changedDSNs; +std::mutex RdbmsDataSource::glock; RdbmsDataSource::RdbmsDataSource(ScriptNodePtr p) : DataSource(p), @@ -54,24 +56,31 @@ RdbmsDataSource::~RdbmsDataSource() { } -const DB::Connection & +RdbmsDataSource::ConnectionRef RdbmsDataSource::getWritable() const { + LOCK(ilock); ConnectionPtr master = connectTo(masterDsn); if (!master->txOpen) { master->connection->beginTx(); master->txOpen = true; } - changedDSNs.insert(name); - return *master->connection; + LOCK(glock); + changedDSNs.insert({name, std::this_thread::get_id()}); + return master.get(); } -const DB::Connection & +RdbmsDataSource::ConnectionRef RdbmsDataSource::getReadonly() const { - if (changedDSNs.find(name) != changedDSNs.end()) { - return *connectTo(masterDsn)->connection; + { + LOCK(glock); + if (changedDSNs.find({name, std::this_thread::get_id()}) != changedDSNs.end()) { + glock.unlock(); + return connectTo(masterDsn).get(); + } } + LOCK(ilock); if (localhost.length() == 0 && preferLocal) { struct utsname name; if (uname(&name)) { @@ -89,9 +98,9 @@ RdbmsDataSource::getReadonly() const if (ro == roDSNs.end()) { Logger()->messagef(LOG_INFO, "%s: No database host matches local host name (%s) Will use master DSN", __PRETTY_FUNCTION__, localhost.c_str()); - return *connectTo(masterDsn)->connection; + return connectTo(masterDsn).get(); } - return *connectTo(ro->second)->connection; + return connectTo(ro->second).get(); } catch (...) { // Failed to connect to a preferred DB... carry on and try the others... @@ -99,38 +108,55 @@ RdbmsDataSource::getReadonly() const } BOOST_FOREACH(ReadonlyDSNs::value_type db, roDSNs) { try { - return *connectTo(db.second)->connection; + return connectTo(db.second).get(); } catch (...) { } } - return *connectTo(masterDsn)->connection; + return connectTo(masterDsn).get(); } void RdbmsDataSource::commit() { - DBHosts::const_iterator m = dbhosts.find(masterDsn); - if (m != dbhosts.end() && m->second->txOpen) { - m->second->connection->commitTx(); - m->second->txOpen = false; + LOCK(ilock); + LOCK(glock); + auto masters = dbhosts.equal_range(masterDsn); + for (auto m = masters.first; m != masters.second; m++) { + if (m->second->threadId) { + } + if (m->second->txOpen && m->second->threadId && *m->second->threadId == std::this_thread::get_id()) { + m->second->connection->commitTx(); + m->second->txOpen = false; + if (m->second->users == 0) { + m->second->threadId.reset(); + } + } } } void RdbmsDataSource::rollback() { - DBHosts::const_iterator m = dbhosts.find(masterDsn); - if (m != dbhosts.end() && m->second->txOpen) { - m->second->connection->rollbackTx(); - m->second->txOpen = false; + LOCK(ilock); + LOCK(glock); + auto masters = dbhosts.equal_range(masterDsn); + for (auto m = masters.first; m != masters.second; m++) { + if (m->second->txOpen && m->second->threadId && *m->second->threadId == std::this_thread::get_id()) { + m->second->connection->rollbackTx(); + m->second->txOpen = false; + if (m->second->users == 0) { + m->second->threadId.reset(); + } + } } - changedDSNs.erase(name); + changedDSNs.erase({name, std::this_thread::get_id()}); } RdbmsDataSource::ConnectionPtr RdbmsDataSource::connectTo(const ConnectionInfo & dsn) { + LOCK(glock); FailedHosts::iterator dbf = failedhosts.find(dsn); if (dbf != failedhosts.end()) { if (time(NULL) - 20 > dbf->second.FailureTime) { @@ -140,23 +166,27 @@ RdbmsDataSource::connectTo(const ConnectionInfo & dsn) throw dbf->second; } } - DBHosts::const_iterator dbi = dbhosts.find(dsn); - if (dbi != dbhosts.end()) { - try { - dbi->second->connection->ping(); - dbi->second->touch(); - return dbi->second; - } - catch (...) { - // Connection in failed state - Logger()->messagef(LOG_DEBUG, "%s: Cached connection failed", __PRETTY_FUNCTION__); + auto dbis = dbhosts.equal_range(dsn); + for (auto dbi = dbis.first; dbi != dbis.second; dbi++) { + if (!dbi->second->threadId || *dbi->second->threadId == std::this_thread::get_id()) { + try { + dbi->second->connection->ping(); + dbi->second->threadId = std::this_thread::get_id(); + dbi->second->touch(); + return dbi->second; + } + catch (...) { + // Connection in failed state + Logger()->messagef(LOG_DEBUG, "%s: Cached connection failed", __PRETTY_FUNCTION__); + } } } try { ConnectionPtr db = ConnectionPtr(new RdbmsConnection(dsn.connect(), 300)); - dbhosts[dsn] = db; + db->threadId = std::this_thread::get_id(); db->touch(); + dbhosts.insert({dsn, db}); return db; } catch (const DB::ConnectionError & e) { @@ -171,6 +201,7 @@ RdbmsDataSource::RdbmsConnection::RdbmsConnection(const DB::Connection * con, ti connection(con), txOpen(false), lastUsedTime(0), + users(0), keepAliveTime(kat) { } @@ -187,10 +218,25 @@ RdbmsDataSource::RdbmsConnection::touch() const time(&lastUsedTime); } +void +RdbmsDataSource::RdbmsConnection::incRef() +{ + users += 1; +} + +void +RdbmsDataSource::RdbmsConnection::decRef() +{ + users -= 1; + if (users == 0 && !txOpen) { + threadId.reset(); + } +} + bool RdbmsDataSource::RdbmsConnection::isExpired() const { - return (time(NULL) > lastUsedTime + keepAliveTime); + return ((time(NULL) > lastUsedTime + keepAliveTime) && (users == 0)); } RdbmsDataSource::ConnectionInfo::ConnectionInfo(ScriptNodePtr node) : @@ -211,3 +257,55 @@ RdbmsDataSource::ConnectionInfo::operator<(const RdbmsDataSource::ConnectionInfo return ((typeId < other.typeId) || ((typeId == other.typeId) && (dsn < other.dsn))); } +RdbmsDataSource::ConnectionRef::ConnectionRef() : + conn(NULL) +{ +} + +RdbmsDataSource::ConnectionRef::ConnectionRef(RdbmsConnection * c) : + conn(c) +{ + if (conn) + conn->incRef(); +} + +RdbmsDataSource::ConnectionRef::ConnectionRef(const ConnectionRef & ref) : + conn(ref.conn) +{ + if (conn) + conn->incRef(); +} + +RdbmsDataSource::ConnectionRef::~ConnectionRef() +{ + if (conn) + conn->decRef(); +} + +void RdbmsDataSource::ConnectionRef::operator=(const ConnectionRef & ref) +{ + if (conn) + conn->decRef(); + conn = ref.conn; + if (conn) + conn->incRef(); +} + +const DB::Connection * +RdbmsDataSource::ConnectionRef::operator->() const +{ + return conn->connection; +} + +const DB::Connection & +RdbmsDataSource::ConnectionRef::operator*() const +{ + return *conn->connection; +} + +const DB::Connection * +RdbmsDataSource::ConnectionRef::get() const +{ + return conn->connection; +} + diff --git a/project2/sql/rdbmsDataSource.h b/project2/sql/rdbmsDataSource.h index 0497aab..0a1ddfd 100644 --- a/project2/sql/rdbmsDataSource.h +++ b/project2/sql/rdbmsDataSource.h @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include "dataSource.h" #include #include @@ -13,6 +15,8 @@ /// Project2 component to provide access to transactional RDBMS data sources class RdbmsDataSource : public DataSource { public: + class ConnectionRef; + class RdbmsConnection { public: RdbmsConnection(const DB::Connection * connection, time_t kat); @@ -22,12 +26,33 @@ class RdbmsDataSource : public DataSource { bool isExpired() const; const DB::Connection * const connection; bool txOpen; + boost::optional threadId; private: + friend class ConnectionRef; + friend class RdbmsDataSource; mutable time_t lastUsedTime; + void incRef(); + void decRef(); + unsigned int users; const time_t keepAliveTime; }; + class ConnectionRef { + public: + ConnectionRef(); + ConnectionRef(RdbmsConnection *); + ConnectionRef(const ConnectionRef &); + ~ConnectionRef(); + void operator=(const ConnectionRef &); + + const DB::Connection * operator->() const; + const DB::Connection & operator*() const; + const DB::Connection * get() const; + private: + RdbmsConnection * conn; + }; + class ConnectionInfo { public: ConnectionInfo(ScriptNodePtr); @@ -42,14 +67,14 @@ class RdbmsDataSource : public DataSource { typedef boost::shared_ptr ConnectionPtr; typedef std::map ReadonlyDSNs; // Map hostname to DSN string - typedef std::map DBHosts; // Map DSN strings to connections + typedef std::multimap DBHosts; // Map DSN strings to connections typedef std::map FailedHosts; // Map DSN strings to failures RdbmsDataSource(ScriptNodePtr p); ~RdbmsDataSource(); - const DB::Connection & getReadonly() const; - const DB::Connection & getWritable() const; + ConnectionRef getReadonly() const; + ConnectionRef getWritable() const; virtual void commit(); virtual void rollback(); @@ -61,11 +86,14 @@ class RdbmsDataSource : public DataSource { ReadonlyDSNs roDSNs; private: + mutable std::mutex ilock; + static std::mutex glock; mutable std::string localhost; static DBHosts dbhosts; static FailedHosts failedhosts; - typedef std::set DSNSet; - static DSNSet changedDSNs; + typedef std::pair ChangedDSN; + typedef std::set ChangedDSNs; + static ChangedDSNs changedDSNs; friend class RdbmsDataSourceLoader; }; diff --git a/project2/sql/sqlBulkLoad.cpp b/project2/sql/sqlBulkLoad.cpp index 8787d3e..813323c 100644 --- a/project2/sql/sqlBulkLoad.cpp +++ b/project2/sql/sqlBulkLoad.cpp @@ -25,12 +25,12 @@ class SqlBulkLoad : public Task { void execute(ExecContext * ec) const { - const DB::Connection & wdb = db->getWritable(); - wdb.beginBulkUpload(targetTable(ec), extras(ec)); + auto wdb = db->getWritable(); + wdb->beginBulkUpload(targetTable(ec), extras(ec)); ScopeObject tidy([]{}, - [&]{ wdb.endBulkUpload(NULL); }, - [&]{ wdb.endBulkUpload("Stack unwind in progress"); }); - stream->runStream(boost::bind(&DB::Connection::bulkUploadData, &wdb, _1, _2), ec); + [&]{ wdb->endBulkUpload(NULL); }, + [&]{ wdb->endBulkUpload("Stack unwind in progress"); }); + stream->runStream(boost::bind(&DB::Connection::bulkUploadData, wdb.get(), _1, _2), ec); } const Variable dataSource; diff --git a/project2/sql/sqlCache.cpp b/project2/sql/sqlCache.cpp index 66e9288..af59ee5 100644 --- a/project2/sql/sqlCache.cpp +++ b/project2/sql/sqlCache.cpp @@ -138,7 +138,8 @@ class SqlCache : public Cache { HeaderTable.c_str(),n.c_str(), f.c_str()); applyKeys(ec, boost::bind(appendKeyAnds, &sql, _1), ps); sql.appendf(" ORDER BY r.p2_cacheid DESC, r.p2_row"); - SelectPtr gh(db->getReadonly().newSelectCommand(sql)); + auto con = db->getReadonly(); + SelectPtr gh(con->newSelectCommand(sql)); unsigned int offset = 0; gh->bindParamT(offset++, time(NULL) - CacheLife); applyKeys(ec, boost::bind(bindKeyValues, gh.get(), &offset, _2), ps); @@ -181,7 +182,8 @@ class SqlCache : public Cache { sql.append(", ?"); } sql.appendf(")"); - ModifyPtr m(db->getWritable().newModifyCommand(sql)); + auto con = db->getWritable(); + ModifyPtr m(con->newModifyCommand(sql)); unsigned int offset = 0; m->bindParamI(offset++, row++); BOOST_FOREACH(const Values::value_type & a, attrs) { @@ -206,12 +208,13 @@ class SqlCache : public Cache { { Buffer sp; sp.appendf("SAVEPOINT sp%p", this); - ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp)); + auto con = db->getWritable(); + ModifyPtr s = ModifyPtr(con->newModifyCommand(sp)); s->execute(); // Header Buffer del; del.appendf("INSERT INTO %s(p2_time) VALUES(?)", HeaderTable.c_str()); - ModifyPtr h = ModifyPtr(db->getWritable().newModifyCommand(del)); + ModifyPtr h = ModifyPtr(con->newModifyCommand(del)); h->bindParamT(0, time(NULL)); h->execute(); // Record set header @@ -223,7 +226,7 @@ class SqlCache : public Cache { offset = 0; applyKeys(ec, boost::bind(appendKeyBinds, &sql, &offset), ps); sql.appendf(")"); - ModifyPtr m(db->getWritable().newModifyCommand(sql)); + ModifyPtr m(con->newModifyCommand(sql)); offset = 0; applyKeys(ec, boost::bind(bindKeyValues, m.get(), &offset, _2), ps); m->execute(); @@ -234,7 +237,8 @@ class SqlCache : public Cache { { Buffer sp; sp.appendf("RELEASE SAVEPOINT sp%p", this); - ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp)); + auto con = db->getWritable(); + ModifyPtr s = ModifyPtr(con->newModifyCommand(sp)); s->execute(); } @@ -242,7 +246,8 @@ class SqlCache : public Cache { { Buffer sp; sp.appendf("ROLLBACK TO SAVEPOINT sp%p", this); - ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp)); + auto con = db->getWritable(); + ModifyPtr s = ModifyPtr(con->newModifyCommand(sp)); s->execute(); } @@ -268,7 +273,8 @@ class CustomSqlCacheLoader : public ElementLoader::For { RdbmsDataSource * db = co->dataSource(SqlCache::DataSource); Buffer del; del.appendf("DELETE FROM %s WHERE p2_time < ?", SqlCache::HeaderTable.c_str()); - ModifyPtr m(db->getWritable().newModifyCommand(del)); + auto con = db->getWritable(); + ModifyPtr m(con->newModifyCommand(del)); m->bindParamT(0, time(NULL) - SqlCache::CacheLife); m->execute(); db->commit(); diff --git a/project2/sql/sqlMergeTask.cpp b/project2/sql/sqlMergeTask.cpp index 4356279..880a85e 100644 --- a/project2/sql/sqlMergeTask.cpp +++ b/project2/sql/sqlMergeTask.cpp @@ -118,7 +118,7 @@ SqlMergeTask::~SqlMergeTask() void SqlMergeTask::loadComplete(const CommonObjects * co) { - destdb = &co->dataSource(dataSource(NULL))->getWritable(); + destdb = co->dataSource(dataSource(NULL))->getWritable(); insCmd = insertCommand(); BOOST_FOREACH(const Sources::value_type & i, sources) { attach(i, insCmd); @@ -165,8 +165,8 @@ SqlMergeTask::execute(ExecContext * ec) const auto savepoint(stringf("sqlmerge_savepoint_%p", this)); destdb->savepoint(savepoint); ScopeObject SPHandler(NULL, - boost::bind(&DB::Connection::releaseSavepoint, destdb, savepoint), - boost::bind(&DB::Connection::rollbackToSavepoint, destdb, savepoint)); + boost::bind(&DB::Connection::releaseSavepoint, destdb.get(), savepoint), + boost::bind(&DB::Connection::rollbackToSavepoint, destdb.get(), savepoint)); createTempTable(); if (earlyKeys(NULL)) { createTempKey(); diff --git a/project2/sql/sqlMergeTask.h b/project2/sql/sqlMergeTask.h index 7fe385a..34fd7aa 100644 --- a/project2/sql/sqlMergeTask.h +++ b/project2/sql/sqlMergeTask.h @@ -72,7 +72,7 @@ class SqlMergeTask : public Task { public: Sources sources; - const DB::Connection * destdb; + RdbmsDataSource::ConnectionRef destdb; const Variable dataSource; const Table dtable; const Table dtablet; diff --git a/project2/sql/sqlRows.cpp b/project2/sql/sqlRows.cpp index 6570a60..1395db9 100644 --- a/project2/sql/sqlRows.cpp +++ b/project2/sql/sqlRows.cpp @@ -51,7 +51,8 @@ void SqlRows::execute(const Glib::ustring & filter, const RowProcessorCallback & rp, ExecContext * ec) const { unsigned int offset = 0; - auto select = SelectPtr(db->getReadonly().newSelectCommand(sqlCommand.getSqlFor(filter))); + auto con = db->getReadonly(); + auto select = SelectPtr(con->newSelectCommand(sqlCommand.getSqlFor(filter))); sqlCommand.bindParams(ec, select.get(), offset); SqlState ss(select); while (ss.query->fetch()) { diff --git a/project2/sql/sqlTask.cpp b/project2/sql/sqlTask.cpp index 80ea08e..eab00d5 100644 --- a/project2/sql/sqlTask.cpp +++ b/project2/sql/sqlTask.cpp @@ -35,8 +35,9 @@ SqlTask::loadComplete(const CommonObjects * co) void SqlTask::execute(ExecContext * ec) const { + auto con = db->getWritable(); boost::shared_ptr modify = boost::shared_ptr( - db->getWritable().newModifyCommand(sqlCommand.getSqlFor(filter(NULL)))); + con->newModifyCommand(sqlCommand.getSqlFor(filter(NULL)))); unsigned int offset = 0; sqlCommand.bindParams(ec, modify.get(), offset); if (modify->execute() == 0) { diff --git a/project2/sql/sqlTest.cpp b/project2/sql/sqlTest.cpp index 2723d8e..45178ef 100644 --- a/project2/sql/sqlTest.cpp +++ b/project2/sql/sqlTest.cpp @@ -90,8 +90,9 @@ class HandleDoCompare : public DB::HandleField { bool SqlTest::passes(ExecContext * ec) const { + auto con = db->getReadonly(); boost::shared_ptr query = boost::shared_ptr( - db->getWritable().newSelectCommand(sqlCommand.getSqlFor(filter(NULL)))); + con->newSelectCommand(sqlCommand.getSqlFor(filter(NULL)))); unsigned int offset = 0; sqlCommand.bindParams(ec, query.get(), offset); HandleDoCompare h(testValue(ec), testOp(ec)); -- cgit v1.2.3