summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--project2/sql/rdbmsDataSource.cpp162
-rw-r--r--project2/sql/rdbmsDataSource.h38
-rw-r--r--project2/sql/sqlBulkLoad.cpp10
-rw-r--r--project2/sql/sqlCache.cpp22
-rw-r--r--project2/sql/sqlMergeTask.cpp6
-rw-r--r--project2/sql/sqlMergeTask.h2
-rw-r--r--project2/sql/sqlRows.cpp3
-rw-r--r--project2/sql/sqlTask.cpp3
-rw-r--r--project2/sql/sqlTest.cpp3
9 files changed, 192 insertions, 57 deletions
diff --git a/project2/sql/rdbmsDataSource.cpp b/project2/sql/rdbmsDataSource.cpp
index bf69dcd..7965a29 100644
--- a/project2/sql/rdbmsDataSource.cpp
+++ b/project2/sql/rdbmsDataSource.cpp
@@ -8,6 +8,7 @@
#include <boost/foreach.hpp>
SimpleMessageException(UnknownConnectionProvider);
+#define LOCK(l) std::lock_guard<std::mutex> _lock##l(l)
/// Specialized ElementLoader for instances of RdbmsDataSource; handles persistent DB connections
class RdbmsDataSourceLoader : public ElementLoader::For<RdbmsDataSource> {
@@ -38,7 +39,8 @@ DECLARE_CUSTOM_LOADER("rdbmsdatasource", RdbmsDataSourceLoader);
RdbmsDataSource::DBHosts RdbmsDataSource::dbhosts;
RdbmsDataSource::FailedHosts RdbmsDataSource::failedhosts;
-RdbmsDataSource::DSNSet RdbmsDataSource::changedDSNs;
+RdbmsDataSource::ChangedDSNs RdbmsDataSource::changedDSNs;
+std::mutex RdbmsDataSource::glock;
RdbmsDataSource::RdbmsDataSource(ScriptNodePtr p) :
DataSource(p),
@@ -54,24 +56,31 @@ RdbmsDataSource::~RdbmsDataSource()
{
}
-const DB::Connection &
+RdbmsDataSource::ConnectionRef
RdbmsDataSource::getWritable() const
{
+ LOCK(ilock);
ConnectionPtr master = connectTo(masterDsn);
if (!master->txOpen) {
master->connection->beginTx();
master->txOpen = true;
}
- changedDSNs.insert(name);
- return *master->connection;
+ LOCK(glock);
+ changedDSNs.insert({name, std::this_thread::get_id()});
+ return master.get();
}
-const DB::Connection &
+RdbmsDataSource::ConnectionRef
RdbmsDataSource::getReadonly() const
{
- if (changedDSNs.find(name) != changedDSNs.end()) {
- return *connectTo(masterDsn)->connection;
+ {
+ LOCK(glock);
+ if (changedDSNs.find({name, std::this_thread::get_id()}) != changedDSNs.end()) {
+ glock.unlock();
+ return connectTo(masterDsn).get();
+ }
}
+ LOCK(ilock);
if (localhost.length() == 0 && preferLocal) {
struct utsname name;
if (uname(&name)) {
@@ -89,9 +98,9 @@ RdbmsDataSource::getReadonly() const
if (ro == roDSNs.end()) {
Logger()->messagef(LOG_INFO, "%s: No database host matches local host name (%s) Will use master DSN",
__PRETTY_FUNCTION__, localhost.c_str());
- return *connectTo(masterDsn)->connection;
+ return connectTo(masterDsn).get();
}
- return *connectTo(ro->second)->connection;
+ return connectTo(ro->second).get();
}
catch (...) {
// Failed to connect to a preferred DB... carry on and try the others...
@@ -99,38 +108,55 @@ RdbmsDataSource::getReadonly() const
}
BOOST_FOREACH(ReadonlyDSNs::value_type db, roDSNs) {
try {
- return *connectTo(db.second)->connection;
+ return connectTo(db.second).get();
}
catch (...) {
}
}
- return *connectTo(masterDsn)->connection;
+ return connectTo(masterDsn).get();
}
void
RdbmsDataSource::commit()
{
- DBHosts::const_iterator m = dbhosts.find(masterDsn);
- if (m != dbhosts.end() && m->second->txOpen) {
- m->second->connection->commitTx();
- m->second->txOpen = false;
+ LOCK(ilock);
+ LOCK(glock);
+ auto masters = dbhosts.equal_range(masterDsn);
+ for (auto m = masters.first; m != masters.second; m++) {
+ if (m->second->threadId) {
+ }
+ if (m->second->txOpen && m->second->threadId && *m->second->threadId == std::this_thread::get_id()) {
+ m->second->connection->commitTx();
+ m->second->txOpen = false;
+ if (m->second->users == 0) {
+ m->second->threadId.reset();
+ }
+ }
}
}
void
RdbmsDataSource::rollback()
{
- DBHosts::const_iterator m = dbhosts.find(masterDsn);
- if (m != dbhosts.end() && m->second->txOpen) {
- m->second->connection->rollbackTx();
- m->second->txOpen = false;
+ LOCK(ilock);
+ LOCK(glock);
+ auto masters = dbhosts.equal_range(masterDsn);
+ for (auto m = masters.first; m != masters.second; m++) {
+ if (m->second->txOpen && m->second->threadId && *m->second->threadId == std::this_thread::get_id()) {
+ m->second->connection->rollbackTx();
+ m->second->txOpen = false;
+ if (m->second->users == 0) {
+ m->second->threadId.reset();
+ }
+ }
}
- changedDSNs.erase(name);
+ changedDSNs.erase({name, std::this_thread::get_id()});
}
RdbmsDataSource::ConnectionPtr
RdbmsDataSource::connectTo(const ConnectionInfo & dsn)
{
+ LOCK(glock);
FailedHosts::iterator dbf = failedhosts.find(dsn);
if (dbf != failedhosts.end()) {
if (time(NULL) - 20 > dbf->second.FailureTime) {
@@ -140,23 +166,27 @@ RdbmsDataSource::connectTo(const ConnectionInfo & dsn)
throw dbf->second;
}
}
- DBHosts::const_iterator dbi = dbhosts.find(dsn);
- if (dbi != dbhosts.end()) {
- try {
- dbi->second->connection->ping();
- dbi->second->touch();
- return dbi->second;
- }
- catch (...) {
- // Connection in failed state
- Logger()->messagef(LOG_DEBUG, "%s: Cached connection failed", __PRETTY_FUNCTION__);
+ auto dbis = dbhosts.equal_range(dsn);
+ for (auto dbi = dbis.first; dbi != dbis.second; dbi++) {
+ if (!dbi->second->threadId || *dbi->second->threadId == std::this_thread::get_id()) {
+ try {
+ dbi->second->connection->ping();
+ dbi->second->threadId = std::this_thread::get_id();
+ dbi->second->touch();
+ return dbi->second;
+ }
+ catch (...) {
+ // Connection in failed state
+ Logger()->messagef(LOG_DEBUG, "%s: Cached connection failed", __PRETTY_FUNCTION__);
+ }
}
}
try {
ConnectionPtr db = ConnectionPtr(new RdbmsConnection(dsn.connect(), 300));
- dbhosts[dsn] = db;
+ db->threadId = std::this_thread::get_id();
db->touch();
+ dbhosts.insert({dsn, db});
return db;
}
catch (const DB::ConnectionError & e) {
@@ -171,6 +201,7 @@ RdbmsDataSource::RdbmsConnection::RdbmsConnection(const DB::Connection * con, ti
connection(con),
txOpen(false),
lastUsedTime(0),
+ users(0),
keepAliveTime(kat)
{
}
@@ -187,10 +218,25 @@ RdbmsDataSource::RdbmsConnection::touch() const
time(&lastUsedTime);
}
+void
+RdbmsDataSource::RdbmsConnection::incRef()
+{
+ users += 1;
+}
+
+void
+RdbmsDataSource::RdbmsConnection::decRef()
+{
+ users -= 1;
+ if (users == 0 && !txOpen) {
+ threadId.reset();
+ }
+}
+
bool
RdbmsDataSource::RdbmsConnection::isExpired() const
{
- return (time(NULL) > lastUsedTime + keepAliveTime);
+ return ((time(NULL) > lastUsedTime + keepAliveTime) && (users == 0));
}
RdbmsDataSource::ConnectionInfo::ConnectionInfo(ScriptNodePtr node) :
@@ -211,3 +257,55 @@ RdbmsDataSource::ConnectionInfo::operator<(const RdbmsDataSource::ConnectionInfo
return ((typeId < other.typeId) || ((typeId == other.typeId) && (dsn < other.dsn)));
}
+RdbmsDataSource::ConnectionRef::ConnectionRef() :
+ conn(NULL)
+{
+}
+
+RdbmsDataSource::ConnectionRef::ConnectionRef(RdbmsConnection * c) :
+ conn(c)
+{
+ if (conn)
+ conn->incRef();
+}
+
+RdbmsDataSource::ConnectionRef::ConnectionRef(const ConnectionRef & ref) :
+ conn(ref.conn)
+{
+ if (conn)
+ conn->incRef();
+}
+
+RdbmsDataSource::ConnectionRef::~ConnectionRef()
+{
+ if (conn)
+ conn->decRef();
+}
+
+void RdbmsDataSource::ConnectionRef::operator=(const ConnectionRef & ref)
+{
+ if (conn)
+ conn->decRef();
+ conn = ref.conn;
+ if (conn)
+ conn->incRef();
+}
+
+const DB::Connection *
+RdbmsDataSource::ConnectionRef::operator->() const
+{
+ return conn->connection;
+}
+
+const DB::Connection &
+RdbmsDataSource::ConnectionRef::operator*() const
+{
+ return *conn->connection;
+}
+
+const DB::Connection *
+RdbmsDataSource::ConnectionRef::get() const
+{
+ return conn->connection;
+}
+
diff --git a/project2/sql/rdbmsDataSource.h b/project2/sql/rdbmsDataSource.h
index 0497aab..0a1ddfd 100644
--- a/project2/sql/rdbmsDataSource.h
+++ b/project2/sql/rdbmsDataSource.h
@@ -4,6 +4,8 @@
#include <boost/shared_ptr.hpp>
#include <map>
#include <set>
+#include <thread>
+#include <mutex>
#include "dataSource.h"
#include <connection.h>
#include <error.h>
@@ -13,6 +15,8 @@
/// Project2 component to provide access to transactional RDBMS data sources
class RdbmsDataSource : public DataSource {
public:
+ class ConnectionRef;
+
class RdbmsConnection {
public:
RdbmsConnection(const DB::Connection * connection, time_t kat);
@@ -22,12 +26,33 @@ class RdbmsDataSource : public DataSource {
bool isExpired() const;
const DB::Connection * const connection;
bool txOpen;
+ boost::optional<std::thread::id> threadId;
private:
+ friend class ConnectionRef;
+ friend class RdbmsDataSource;
mutable time_t lastUsedTime;
+ void incRef();
+ void decRef();
+ unsigned int users;
const time_t keepAliveTime;
};
+ class ConnectionRef {
+ public:
+ ConnectionRef();
+ ConnectionRef(RdbmsConnection *);
+ ConnectionRef(const ConnectionRef &);
+ ~ConnectionRef();
+ void operator=(const ConnectionRef &);
+
+ const DB::Connection * operator->() const;
+ const DB::Connection & operator*() const;
+ const DB::Connection * get() const;
+ private:
+ RdbmsConnection * conn;
+ };
+
class ConnectionInfo {
public:
ConnectionInfo(ScriptNodePtr);
@@ -42,14 +67,14 @@ class RdbmsDataSource : public DataSource {
typedef boost::shared_ptr<RdbmsConnection> ConnectionPtr;
typedef std::map<std::string, ConnectionInfo> ReadonlyDSNs; // Map hostname to DSN string
- typedef std::map<ConnectionInfo, ConnectionPtr> DBHosts; // Map DSN strings to connections
+ typedef std::multimap<ConnectionInfo, ConnectionPtr> DBHosts; // Map DSN strings to connections
typedef std::map<ConnectionInfo, const DB::ConnectionError> FailedHosts; // Map DSN strings to failures
RdbmsDataSource(ScriptNodePtr p);
~RdbmsDataSource();
- const DB::Connection & getReadonly() const;
- const DB::Connection & getWritable() const;
+ ConnectionRef getReadonly() const;
+ ConnectionRef getWritable() const;
virtual void commit();
virtual void rollback();
@@ -61,11 +86,14 @@ class RdbmsDataSource : public DataSource {
ReadonlyDSNs roDSNs;
private:
+ mutable std::mutex ilock;
+ static std::mutex glock;
mutable std::string localhost;
static DBHosts dbhosts;
static FailedHosts failedhosts;
- typedef std::set<std::string> DSNSet;
- static DSNSet changedDSNs;
+ typedef std::pair<std::string, std::thread::id> ChangedDSN;
+ typedef std::set<ChangedDSN> ChangedDSNs;
+ static ChangedDSNs changedDSNs;
friend class RdbmsDataSourceLoader;
};
diff --git a/project2/sql/sqlBulkLoad.cpp b/project2/sql/sqlBulkLoad.cpp
index 8787d3e..813323c 100644
--- a/project2/sql/sqlBulkLoad.cpp
+++ b/project2/sql/sqlBulkLoad.cpp
@@ -25,12 +25,12 @@ class SqlBulkLoad : public Task {
void execute(ExecContext * ec) const
{
- const DB::Connection & wdb = db->getWritable();
- wdb.beginBulkUpload(targetTable(ec), extras(ec));
+ auto wdb = db->getWritable();
+ wdb->beginBulkUpload(targetTable(ec), extras(ec));
ScopeObject tidy([]{},
- [&]{ wdb.endBulkUpload(NULL); },
- [&]{ wdb.endBulkUpload("Stack unwind in progress"); });
- stream->runStream(boost::bind(&DB::Connection::bulkUploadData, &wdb, _1, _2), ec);
+ [&]{ wdb->endBulkUpload(NULL); },
+ [&]{ wdb->endBulkUpload("Stack unwind in progress"); });
+ stream->runStream(boost::bind(&DB::Connection::bulkUploadData, wdb.get(), _1, _2), ec);
}
const Variable dataSource;
diff --git a/project2/sql/sqlCache.cpp b/project2/sql/sqlCache.cpp
index 66e9288..af59ee5 100644
--- a/project2/sql/sqlCache.cpp
+++ b/project2/sql/sqlCache.cpp
@@ -138,7 +138,8 @@ class SqlCache : public Cache {
HeaderTable.c_str(),n.c_str(), f.c_str());
applyKeys(ec, boost::bind(appendKeyAnds, &sql, _1), ps);
sql.appendf(" ORDER BY r.p2_cacheid DESC, r.p2_row");
- SelectPtr gh(db->getReadonly().newSelectCommand(sql));
+ auto con = db->getReadonly();
+ SelectPtr gh(con->newSelectCommand(sql));
unsigned int offset = 0;
gh->bindParamT(offset++, time(NULL) - CacheLife);
applyKeys(ec, boost::bind(bindKeyValues, gh.get(), &offset, _2), ps);
@@ -181,7 +182,8 @@ class SqlCache : public Cache {
sql.append(", ?");
}
sql.appendf(")");
- ModifyPtr m(db->getWritable().newModifyCommand(sql));
+ auto con = db->getWritable();
+ ModifyPtr m(con->newModifyCommand(sql));
unsigned int offset = 0;
m->bindParamI(offset++, row++);
BOOST_FOREACH(const Values::value_type & a, attrs) {
@@ -206,12 +208,13 @@ class SqlCache : public Cache {
{
Buffer sp;
sp.appendf("SAVEPOINT sp%p", this);
- ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp));
+ auto con = db->getWritable();
+ ModifyPtr s = ModifyPtr(con->newModifyCommand(sp));
s->execute();
// Header
Buffer del;
del.appendf("INSERT INTO %s(p2_time) VALUES(?)", HeaderTable.c_str());
- ModifyPtr h = ModifyPtr(db->getWritable().newModifyCommand(del));
+ ModifyPtr h = ModifyPtr(con->newModifyCommand(del));
h->bindParamT(0, time(NULL));
h->execute();
// Record set header
@@ -223,7 +226,7 @@ class SqlCache : public Cache {
offset = 0;
applyKeys(ec, boost::bind(appendKeyBinds, &sql, &offset), ps);
sql.appendf(")");
- ModifyPtr m(db->getWritable().newModifyCommand(sql));
+ ModifyPtr m(con->newModifyCommand(sql));
offset = 0;
applyKeys(ec, boost::bind(bindKeyValues, m.get(), &offset, _2), ps);
m->execute();
@@ -234,7 +237,8 @@ class SqlCache : public Cache {
{
Buffer sp;
sp.appendf("RELEASE SAVEPOINT sp%p", this);
- ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp));
+ auto con = db->getWritable();
+ ModifyPtr s = ModifyPtr(con->newModifyCommand(sp));
s->execute();
}
@@ -242,7 +246,8 @@ class SqlCache : public Cache {
{
Buffer sp;
sp.appendf("ROLLBACK TO SAVEPOINT sp%p", this);
- ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp));
+ auto con = db->getWritable();
+ ModifyPtr s = ModifyPtr(con->newModifyCommand(sp));
s->execute();
}
@@ -268,7 +273,8 @@ class CustomSqlCacheLoader : public ElementLoader::For<SqlCache> {
RdbmsDataSource * db = co->dataSource<RdbmsDataSource>(SqlCache::DataSource);
Buffer del;
del.appendf("DELETE FROM %s WHERE p2_time < ?", SqlCache::HeaderTable.c_str());
- ModifyPtr m(db->getWritable().newModifyCommand(del));
+ auto con = db->getWritable();
+ ModifyPtr m(con->newModifyCommand(del));
m->bindParamT(0, time(NULL) - SqlCache::CacheLife);
m->execute();
db->commit();
diff --git a/project2/sql/sqlMergeTask.cpp b/project2/sql/sqlMergeTask.cpp
index 4356279..880a85e 100644
--- a/project2/sql/sqlMergeTask.cpp
+++ b/project2/sql/sqlMergeTask.cpp
@@ -118,7 +118,7 @@ SqlMergeTask::~SqlMergeTask()
void
SqlMergeTask::loadComplete(const CommonObjects * co)
{
- destdb = &co->dataSource<RdbmsDataSource>(dataSource(NULL))->getWritable();
+ destdb = co->dataSource<RdbmsDataSource>(dataSource(NULL))->getWritable();
insCmd = insertCommand();
BOOST_FOREACH(const Sources::value_type & i, sources) {
attach(i, insCmd);
@@ -165,8 +165,8 @@ SqlMergeTask::execute(ExecContext * ec) const
auto savepoint(stringf("sqlmerge_savepoint_%p", this));
destdb->savepoint(savepoint);
ScopeObject SPHandler(NULL,
- boost::bind(&DB::Connection::releaseSavepoint, destdb, savepoint),
- boost::bind(&DB::Connection::rollbackToSavepoint, destdb, savepoint));
+ boost::bind(&DB::Connection::releaseSavepoint, destdb.get(), savepoint),
+ boost::bind(&DB::Connection::rollbackToSavepoint, destdb.get(), savepoint));
createTempTable();
if (earlyKeys(NULL)) {
createTempKey();
diff --git a/project2/sql/sqlMergeTask.h b/project2/sql/sqlMergeTask.h
index 7fe385a..34fd7aa 100644
--- a/project2/sql/sqlMergeTask.h
+++ b/project2/sql/sqlMergeTask.h
@@ -72,7 +72,7 @@ class SqlMergeTask : public Task {
public:
Sources sources;
- const DB::Connection * destdb;
+ RdbmsDataSource::ConnectionRef destdb;
const Variable dataSource;
const Table dtable;
const Table dtablet;
diff --git a/project2/sql/sqlRows.cpp b/project2/sql/sqlRows.cpp
index 6570a60..1395db9 100644
--- a/project2/sql/sqlRows.cpp
+++ b/project2/sql/sqlRows.cpp
@@ -51,7 +51,8 @@ void
SqlRows::execute(const Glib::ustring & filter, const RowProcessorCallback & rp, ExecContext * ec) const
{
unsigned int offset = 0;
- auto select = SelectPtr(db->getReadonly().newSelectCommand(sqlCommand.getSqlFor(filter)));
+ auto con = db->getReadonly();
+ auto select = SelectPtr(con->newSelectCommand(sqlCommand.getSqlFor(filter)));
sqlCommand.bindParams(ec, select.get(), offset);
SqlState ss(select);
while (ss.query->fetch()) {
diff --git a/project2/sql/sqlTask.cpp b/project2/sql/sqlTask.cpp
index 80ea08e..eab00d5 100644
--- a/project2/sql/sqlTask.cpp
+++ b/project2/sql/sqlTask.cpp
@@ -35,8 +35,9 @@ SqlTask::loadComplete(const CommonObjects * co)
void
SqlTask::execute(ExecContext * ec) const
{
+ auto con = db->getWritable();
boost::shared_ptr<DB::ModifyCommand> modify = boost::shared_ptr<DB::ModifyCommand>(
- db->getWritable().newModifyCommand(sqlCommand.getSqlFor(filter(NULL))));
+ con->newModifyCommand(sqlCommand.getSqlFor(filter(NULL))));
unsigned int offset = 0;
sqlCommand.bindParams(ec, modify.get(), offset);
if (modify->execute() == 0) {
diff --git a/project2/sql/sqlTest.cpp b/project2/sql/sqlTest.cpp
index 2723d8e..45178ef 100644
--- a/project2/sql/sqlTest.cpp
+++ b/project2/sql/sqlTest.cpp
@@ -90,8 +90,9 @@ class HandleDoCompare : public DB::HandleField {
bool
SqlTest::passes(ExecContext * ec) const
{
+ auto con = db->getReadonly();
boost::shared_ptr<DB::SelectCommand> query = boost::shared_ptr<DB::SelectCommand>(
- db->getWritable().newSelectCommand(sqlCommand.getSqlFor(filter(NULL))));
+ con->newSelectCommand(sqlCommand.getSqlFor(filter(NULL))));
unsigned int offset = 0;
sqlCommand.bindParams(ec, query.get(), offset);
HandleDoCompare h(testValue(ec), testOp(ec));