diff options
| -rw-r--r-- | project2/sql/rdbmsDataSource.cpp | 162 | ||||
| -rw-r--r-- | project2/sql/rdbmsDataSource.h | 38 | ||||
| -rw-r--r-- | project2/sql/sqlBulkLoad.cpp | 10 | ||||
| -rw-r--r-- | project2/sql/sqlCache.cpp | 22 | ||||
| -rw-r--r-- | project2/sql/sqlMergeTask.cpp | 6 | ||||
| -rw-r--r-- | project2/sql/sqlMergeTask.h | 2 | ||||
| -rw-r--r-- | project2/sql/sqlRows.cpp | 3 | ||||
| -rw-r--r-- | project2/sql/sqlTask.cpp | 3 | ||||
| -rw-r--r-- | project2/sql/sqlTest.cpp | 3 | 
9 files changed, 192 insertions, 57 deletions
diff --git a/project2/sql/rdbmsDataSource.cpp b/project2/sql/rdbmsDataSource.cpp index bf69dcd..7965a29 100644 --- a/project2/sql/rdbmsDataSource.cpp +++ b/project2/sql/rdbmsDataSource.cpp @@ -8,6 +8,7 @@  #include <boost/foreach.hpp>  SimpleMessageException(UnknownConnectionProvider); +#define LOCK(l) std::lock_guard<std::mutex> _lock##l(l)  /// Specialized ElementLoader for instances of RdbmsDataSource; handles persistent DB connections  class RdbmsDataSourceLoader : public ElementLoader::For<RdbmsDataSource> { @@ -38,7 +39,8 @@ DECLARE_CUSTOM_LOADER("rdbmsdatasource", RdbmsDataSourceLoader);  RdbmsDataSource::DBHosts RdbmsDataSource::dbhosts;  RdbmsDataSource::FailedHosts RdbmsDataSource::failedhosts; -RdbmsDataSource::DSNSet RdbmsDataSource::changedDSNs; +RdbmsDataSource::ChangedDSNs RdbmsDataSource::changedDSNs; +std::mutex RdbmsDataSource::glock;  RdbmsDataSource::RdbmsDataSource(ScriptNodePtr p) :  	DataSource(p), @@ -54,24 +56,31 @@ RdbmsDataSource::~RdbmsDataSource()  {  } -const DB::Connection & +RdbmsDataSource::ConnectionRef  RdbmsDataSource::getWritable() const  { +	LOCK(ilock);  	ConnectionPtr master = connectTo(masterDsn);  	if (!master->txOpen) {  		master->connection->beginTx();  		master->txOpen = true;  	} -	changedDSNs.insert(name); -	return *master->connection; +	LOCK(glock); +	changedDSNs.insert({name, std::this_thread::get_id()}); +	return master.get();  } -const DB::Connection & +RdbmsDataSource::ConnectionRef  RdbmsDataSource::getReadonly() const  { -	if (changedDSNs.find(name) != changedDSNs.end()) { -		return *connectTo(masterDsn)->connection; +	{ +		LOCK(glock); +		if (changedDSNs.find({name, std::this_thread::get_id()}) != changedDSNs.end()) { +			glock.unlock(); +			return connectTo(masterDsn).get(); +		}  	} +	LOCK(ilock);  	if (localhost.length() == 0 && preferLocal) {  		struct utsname name;  		if (uname(&name)) { @@ -89,9 +98,9 @@ RdbmsDataSource::getReadonly() const  			if (ro == roDSNs.end()) {  				Logger()->messagef(LOG_INFO, "%s: No database host matches local host name (%s) Will use master DSN",  						__PRETTY_FUNCTION__, localhost.c_str()); -				return *connectTo(masterDsn)->connection; +				return connectTo(masterDsn).get();  			} -			return *connectTo(ro->second)->connection; +			return connectTo(ro->second).get();  		}  		catch (...) {  			// Failed to connect to a preferred DB... carry on and try the others... @@ -99,38 +108,55 @@ RdbmsDataSource::getReadonly() const  	}  	BOOST_FOREACH(ReadonlyDSNs::value_type db, roDSNs) {  		try { -			return *connectTo(db.second)->connection; +			return connectTo(db.second).get();  		}  		catch (...) {  		}  	} -	return *connectTo(masterDsn)->connection; +	return connectTo(masterDsn).get();  }  void  RdbmsDataSource::commit()  { -	DBHosts::const_iterator m = dbhosts.find(masterDsn); -	if (m != dbhosts.end() && m->second->txOpen) { -		m->second->connection->commitTx(); -		m->second->txOpen = false; +	LOCK(ilock); +	LOCK(glock); +	auto masters = dbhosts.equal_range(masterDsn); +	for (auto m = masters.first; m != masters.second; m++) { +		if (m->second->threadId) { +		} +		if (m->second->txOpen && m->second->threadId && *m->second->threadId == std::this_thread::get_id()) { +			m->second->connection->commitTx(); +			m->second->txOpen = false; +			if (m->second->users == 0) { +				m->second->threadId.reset(); +			} +		}  	}  }  void  RdbmsDataSource::rollback()  { -	DBHosts::const_iterator m = dbhosts.find(masterDsn); -	if (m != dbhosts.end() && m->second->txOpen) { -		m->second->connection->rollbackTx(); -		m->second->txOpen = false; +	LOCK(ilock); +	LOCK(glock); +	auto masters = dbhosts.equal_range(masterDsn); +	for (auto m = masters.first; m != masters.second; m++) { +		if (m->second->txOpen && m->second->threadId && *m->second->threadId == std::this_thread::get_id()) { +			m->second->connection->rollbackTx(); +			m->second->txOpen = false; +			if (m->second->users == 0) { +				m->second->threadId.reset(); +			} +		}  	} -	changedDSNs.erase(name); +	changedDSNs.erase({name, std::this_thread::get_id()});  }  RdbmsDataSource::ConnectionPtr  RdbmsDataSource::connectTo(const ConnectionInfo & dsn)  { +	LOCK(glock);  	FailedHosts::iterator dbf = failedhosts.find(dsn);  	if (dbf != failedhosts.end()) {  		if (time(NULL) - 20 > dbf->second.FailureTime) { @@ -140,23 +166,27 @@ RdbmsDataSource::connectTo(const ConnectionInfo & dsn)  			throw dbf->second;  		}  	} -	DBHosts::const_iterator dbi = dbhosts.find(dsn); -	if (dbi != dbhosts.end()) { -		try { -			dbi->second->connection->ping(); -			dbi->second->touch(); -			return dbi->second; -		} -		catch (...) { -			// Connection in failed state -			Logger()->messagef(LOG_DEBUG, "%s: Cached connection failed", __PRETTY_FUNCTION__); +	auto dbis = dbhosts.equal_range(dsn); +	for (auto dbi = dbis.first; dbi != dbis.second; dbi++) { +		if (!dbi->second->threadId || *dbi->second->threadId == std::this_thread::get_id()) { +			try { +				dbi->second->connection->ping(); +				dbi->second->threadId = std::this_thread::get_id(); +				dbi->second->touch(); +				return dbi->second; +			} +			catch (...) { +				// Connection in failed state +				Logger()->messagef(LOG_DEBUG, "%s: Cached connection failed", __PRETTY_FUNCTION__); +			}  		}  	}  	try {  		ConnectionPtr db = ConnectionPtr(new RdbmsConnection(dsn.connect(), 300)); -		dbhosts[dsn] = db; +		db->threadId = std::this_thread::get_id();  		db->touch(); +		dbhosts.insert({dsn, db});  		return db;  	}  	catch (const DB::ConnectionError & e) { @@ -171,6 +201,7 @@ RdbmsDataSource::RdbmsConnection::RdbmsConnection(const DB::Connection * con, ti  	connection(con),  	txOpen(false),  	lastUsedTime(0), +	users(0),  	keepAliveTime(kat)  {  } @@ -187,10 +218,25 @@ RdbmsDataSource::RdbmsConnection::touch() const  	time(&lastUsedTime);  } +void +RdbmsDataSource::RdbmsConnection::incRef() +{ +	users += 1; +} + +void +RdbmsDataSource::RdbmsConnection::decRef() +{ +	users -= 1; +	if (users == 0 && !txOpen) { +		threadId.reset(); +	} +} +  bool  RdbmsDataSource::RdbmsConnection::isExpired() const  { -	return (time(NULL) > lastUsedTime + keepAliveTime); +	return ((time(NULL) > lastUsedTime + keepAliveTime) && (users == 0));  }  RdbmsDataSource::ConnectionInfo::ConnectionInfo(ScriptNodePtr node) : @@ -211,3 +257,55 @@ RdbmsDataSource::ConnectionInfo::operator<(const RdbmsDataSource::ConnectionInfo  	return ((typeId < other.typeId) || ((typeId == other.typeId) && (dsn < other.dsn)));  } +RdbmsDataSource::ConnectionRef::ConnectionRef() : +	conn(NULL) +{ +} + +RdbmsDataSource::ConnectionRef::ConnectionRef(RdbmsConnection * c) : +	conn(c) +{ +	if (conn) +		conn->incRef(); +} + +RdbmsDataSource::ConnectionRef::ConnectionRef(const ConnectionRef & ref) : +	conn(ref.conn) +{ +	if (conn) +		conn->incRef(); +} + +RdbmsDataSource::ConnectionRef::~ConnectionRef() +{ +	if (conn) +		conn->decRef(); +} + +void RdbmsDataSource::ConnectionRef::operator=(const ConnectionRef & ref) +{ +	if (conn) +		conn->decRef(); +	conn = ref.conn; +	if (conn) +		conn->incRef(); +} + +const DB::Connection * +RdbmsDataSource::ConnectionRef::operator->() const +{ +	return conn->connection; +} + +const DB::Connection & +RdbmsDataSource::ConnectionRef::operator*() const +{ +	return *conn->connection; +} + +const DB::Connection * +RdbmsDataSource::ConnectionRef::get() const +{ +	return conn->connection; +} + diff --git a/project2/sql/rdbmsDataSource.h b/project2/sql/rdbmsDataSource.h index 0497aab..0a1ddfd 100644 --- a/project2/sql/rdbmsDataSource.h +++ b/project2/sql/rdbmsDataSource.h @@ -4,6 +4,8 @@  #include <boost/shared_ptr.hpp>  #include <map>  #include <set> +#include <thread> +#include <mutex>  #include "dataSource.h"  #include <connection.h>  #include <error.h> @@ -13,6 +15,8 @@  /// Project2 component to provide access to transactional RDBMS data sources  class RdbmsDataSource : public DataSource {  	public: +		class ConnectionRef; +  		class RdbmsConnection {  			public:  				RdbmsConnection(const DB::Connection * connection, time_t kat); @@ -22,12 +26,33 @@ class RdbmsDataSource : public DataSource {  				bool isExpired() const;  				const DB::Connection * const connection;  				bool txOpen; +				boost::optional<std::thread::id> threadId;  			private: +				friend class ConnectionRef; +				friend class RdbmsDataSource;  				mutable time_t lastUsedTime; +				void incRef(); +				void decRef(); +				unsigned int users;  				const time_t keepAliveTime;  		}; +		class ConnectionRef { +			public: +				ConnectionRef(); +				ConnectionRef(RdbmsConnection *); +				ConnectionRef(const ConnectionRef &); +				~ConnectionRef(); +				void operator=(const ConnectionRef &); + +				const DB::Connection * operator->() const; +				const DB::Connection & operator*() const; +				const DB::Connection * get() const; +			private: +				RdbmsConnection * conn; +		}; +  		class ConnectionInfo {  			public:  				ConnectionInfo(ScriptNodePtr); @@ -42,14 +67,14 @@ class RdbmsDataSource : public DataSource {  		typedef boost::shared_ptr<RdbmsConnection> ConnectionPtr;  		typedef std::map<std::string, ConnectionInfo> ReadonlyDSNs; // Map hostname to DSN string -		typedef std::map<ConnectionInfo, ConnectionPtr> DBHosts; // Map DSN strings to connections +		typedef std::multimap<ConnectionInfo, ConnectionPtr> DBHosts; // Map DSN strings to connections  		typedef std::map<ConnectionInfo, const DB::ConnectionError> FailedHosts; // Map DSN strings to failures  		RdbmsDataSource(ScriptNodePtr p);  		~RdbmsDataSource(); -		const DB::Connection & getReadonly() const; -		const DB::Connection & getWritable() const; +		ConnectionRef getReadonly() const; +		ConnectionRef getWritable() const;  		virtual void commit();  		virtual void rollback(); @@ -61,11 +86,14 @@ class RdbmsDataSource : public DataSource {  		ReadonlyDSNs roDSNs;  	private: +		mutable std::mutex ilock; +		static std::mutex glock;  		mutable std::string localhost;  		static DBHosts dbhosts;  		static FailedHosts failedhosts; -		typedef std::set<std::string> DSNSet; -		static DSNSet changedDSNs; +		typedef std::pair<std::string, std::thread::id> ChangedDSN; +		typedef std::set<ChangedDSN> ChangedDSNs; +		static ChangedDSNs changedDSNs;  		friend class RdbmsDataSourceLoader;  }; diff --git a/project2/sql/sqlBulkLoad.cpp b/project2/sql/sqlBulkLoad.cpp index 8787d3e..813323c 100644 --- a/project2/sql/sqlBulkLoad.cpp +++ b/project2/sql/sqlBulkLoad.cpp @@ -25,12 +25,12 @@ class SqlBulkLoad : public Task {  		void execute(ExecContext * ec) const  		{ -			const DB::Connection & wdb = db->getWritable(); -			wdb.beginBulkUpload(targetTable(ec), extras(ec)); +			auto wdb = db->getWritable(); +			wdb->beginBulkUpload(targetTable(ec), extras(ec));  			ScopeObject tidy([]{}, -					[&]{ wdb.endBulkUpload(NULL); }, -					[&]{ wdb.endBulkUpload("Stack unwind in progress"); }); -			stream->runStream(boost::bind(&DB::Connection::bulkUploadData, &wdb, _1, _2), ec); +					[&]{ wdb->endBulkUpload(NULL); }, +					[&]{ wdb->endBulkUpload("Stack unwind in progress"); }); +			stream->runStream(boost::bind(&DB::Connection::bulkUploadData, wdb.get(), _1, _2), ec);  		}  		const Variable dataSource; diff --git a/project2/sql/sqlCache.cpp b/project2/sql/sqlCache.cpp index 66e9288..af59ee5 100644 --- a/project2/sql/sqlCache.cpp +++ b/project2/sql/sqlCache.cpp @@ -138,7 +138,8 @@ class SqlCache : public Cache {  					HeaderTable.c_str(),n.c_str(), f.c_str());  			applyKeys(ec, boost::bind(appendKeyAnds, &sql, _1), ps);  			sql.appendf(" ORDER BY r.p2_cacheid DESC, r.p2_row"); -			SelectPtr gh(db->getReadonly().newSelectCommand(sql)); +			auto con = db->getReadonly(); +			SelectPtr gh(con->newSelectCommand(sql));  			unsigned int offset = 0;  			gh->bindParamT(offset++, time(NULL) - CacheLife);  			applyKeys(ec, boost::bind(bindKeyValues, gh.get(), &offset, _2), ps); @@ -181,7 +182,8 @@ class SqlCache : public Cache {  						sql.append(", ?");  					}  					sql.appendf(")"); -					ModifyPtr m(db->getWritable().newModifyCommand(sql)); +					auto con = db->getWritable(); +					ModifyPtr m(con->newModifyCommand(sql));  					unsigned int offset = 0;  					m->bindParamI(offset++, row++);  					BOOST_FOREACH(const Values::value_type & a, attrs) { @@ -206,12 +208,13 @@ class SqlCache : public Cache {  		{  			Buffer sp;  			sp.appendf("SAVEPOINT sp%p", this); -			ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp)); +			auto con = db->getWritable(); +			ModifyPtr s = ModifyPtr(con->newModifyCommand(sp));  			s->execute();  			// Header  			Buffer del;  			del.appendf("INSERT INTO %s(p2_time) VALUES(?)", HeaderTable.c_str()); -			ModifyPtr h = ModifyPtr(db->getWritable().newModifyCommand(del)); +			ModifyPtr h = ModifyPtr(con->newModifyCommand(del));  			h->bindParamT(0, time(NULL));  			h->execute();  			// Record set header @@ -223,7 +226,7 @@ class SqlCache : public Cache {  			offset = 0;  			applyKeys(ec, boost::bind(appendKeyBinds, &sql, &offset), ps);  			sql.appendf(")"); -			ModifyPtr m(db->getWritable().newModifyCommand(sql)); +			ModifyPtr m(con->newModifyCommand(sql));  			offset = 0;  			applyKeys(ec, boost::bind(bindKeyValues, m.get(), &offset, _2), ps);  			m->execute(); @@ -234,7 +237,8 @@ class SqlCache : public Cache {  		{  			Buffer sp;  			sp.appendf("RELEASE SAVEPOINT sp%p", this); -			ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp)); +			auto con = db->getWritable(); +			ModifyPtr s = ModifyPtr(con->newModifyCommand(sp));  			s->execute();  		} @@ -242,7 +246,8 @@ class SqlCache : public Cache {  		{  			Buffer sp;  			sp.appendf("ROLLBACK TO SAVEPOINT sp%p", this); -			ModifyPtr s = ModifyPtr(db->getWritable().newModifyCommand(sp)); +			auto con = db->getWritable(); +			ModifyPtr s = ModifyPtr(con->newModifyCommand(sp));  			s->execute();  		} @@ -268,7 +273,8 @@ class CustomSqlCacheLoader : public ElementLoader::For<SqlCache> {  				RdbmsDataSource * db = co->dataSource<RdbmsDataSource>(SqlCache::DataSource);  				Buffer del;  				del.appendf("DELETE FROM %s WHERE p2_time < ?", SqlCache::HeaderTable.c_str()); -				ModifyPtr m(db->getWritable().newModifyCommand(del)); +				auto con = db->getWritable(); +				ModifyPtr m(con->newModifyCommand(del));  				m->bindParamT(0, time(NULL) - SqlCache::CacheLife);  				m->execute();  				db->commit(); diff --git a/project2/sql/sqlMergeTask.cpp b/project2/sql/sqlMergeTask.cpp index 4356279..880a85e 100644 --- a/project2/sql/sqlMergeTask.cpp +++ b/project2/sql/sqlMergeTask.cpp @@ -118,7 +118,7 @@ SqlMergeTask::~SqlMergeTask()  void  SqlMergeTask::loadComplete(const CommonObjects * co)  { -	destdb = &co->dataSource<RdbmsDataSource>(dataSource(NULL))->getWritable(); +	destdb = co->dataSource<RdbmsDataSource>(dataSource(NULL))->getWritable();  	insCmd = insertCommand();  	BOOST_FOREACH(const Sources::value_type & i, sources) {  		attach(i, insCmd); @@ -165,8 +165,8 @@ SqlMergeTask::execute(ExecContext * ec) const  	auto savepoint(stringf("sqlmerge_savepoint_%p", this));  	destdb->savepoint(savepoint);  	ScopeObject SPHandler(NULL, -			boost::bind(&DB::Connection::releaseSavepoint, destdb, savepoint), -			boost::bind(&DB::Connection::rollbackToSavepoint, destdb, savepoint)); +			boost::bind(&DB::Connection::releaseSavepoint, destdb.get(), savepoint), +			boost::bind(&DB::Connection::rollbackToSavepoint, destdb.get(), savepoint));  	createTempTable();  	if (earlyKeys(NULL)) {  		createTempKey(); diff --git a/project2/sql/sqlMergeTask.h b/project2/sql/sqlMergeTask.h index 7fe385a..34fd7aa 100644 --- a/project2/sql/sqlMergeTask.h +++ b/project2/sql/sqlMergeTask.h @@ -72,7 +72,7 @@ class SqlMergeTask : public Task {  	public:  		Sources sources; -		const DB::Connection * destdb; +		RdbmsDataSource::ConnectionRef destdb;  		const Variable dataSource;  		const Table dtable;  		const Table dtablet; diff --git a/project2/sql/sqlRows.cpp b/project2/sql/sqlRows.cpp index 6570a60..1395db9 100644 --- a/project2/sql/sqlRows.cpp +++ b/project2/sql/sqlRows.cpp @@ -51,7 +51,8 @@ void  SqlRows::execute(const Glib::ustring & filter, const RowProcessorCallback & rp, ExecContext * ec) const  {  	unsigned int offset = 0; -	auto select = SelectPtr(db->getReadonly().newSelectCommand(sqlCommand.getSqlFor(filter))); +	auto con = db->getReadonly(); +	auto select = SelectPtr(con->newSelectCommand(sqlCommand.getSqlFor(filter)));  	sqlCommand.bindParams(ec, select.get(), offset);  	SqlState ss(select);  	while (ss.query->fetch()) { diff --git a/project2/sql/sqlTask.cpp b/project2/sql/sqlTask.cpp index 80ea08e..eab00d5 100644 --- a/project2/sql/sqlTask.cpp +++ b/project2/sql/sqlTask.cpp @@ -35,8 +35,9 @@ SqlTask::loadComplete(const CommonObjects * co)  void  SqlTask::execute(ExecContext * ec) const  { +	auto con = db->getWritable();  	boost::shared_ptr<DB::ModifyCommand> modify = boost::shared_ptr<DB::ModifyCommand>( -			db->getWritable().newModifyCommand(sqlCommand.getSqlFor(filter(NULL)))); +			con->newModifyCommand(sqlCommand.getSqlFor(filter(NULL))));  	unsigned int offset = 0;  	sqlCommand.bindParams(ec, modify.get(), offset);  	if (modify->execute() == 0) { diff --git a/project2/sql/sqlTest.cpp b/project2/sql/sqlTest.cpp index 2723d8e..45178ef 100644 --- a/project2/sql/sqlTest.cpp +++ b/project2/sql/sqlTest.cpp @@ -90,8 +90,9 @@ class HandleDoCompare : public DB::HandleField {  bool  SqlTest::passes(ExecContext * ec) const  { +	auto con = db->getReadonly();  	boost::shared_ptr<DB::SelectCommand> query = boost::shared_ptr<DB::SelectCommand>( -			db->getWritable().newSelectCommand(sqlCommand.getSqlFor(filter(NULL)))); +			con->newSelectCommand(sqlCommand.getSqlFor(filter(NULL))));  	unsigned int offset = 0;  	sqlCommand.bindParams(ec, query.get(), offset);  	HandleDoCompare h(testValue(ec), testOp(ec));  | 
