From fcdf038ce2cf1fd9d18386b1564db75926e9dae3 Mon Sep 17 00:00:00 2001
From: Dan Goodliffe <dan@randomdan.homeip.net>
Date: Tue, 5 Jan 2016 21:29:23 +0000
Subject: Have cache elements keep a const shared_ptr<const T> for items and
 return a shared pointer to clients (minor interface change, fixes usage of
 pointer after removal race condition)

---
 libadhocutil/cache.h                 | 51 ++++++++++++++++++++++++------
 libadhocutil/cache.impl.h            | 60 ++++++++++++++++++++++++++++--------
 libadhocutil/unittests/testCache.cpp | 42 +++++++++++++++++++++++--
 3 files changed, 129 insertions(+), 24 deletions(-)

diff --git a/libadhocutil/cache.h b/libadhocutil/cache.h
index a394e92..4e66e55 100644
--- a/libadhocutil/cache.h
+++ b/libadhocutil/cache.h
@@ -17,35 +17,50 @@ namespace AdHoc {
 template <typename T, typename K>
 class DLL_PUBLIC Cacheable {
 	public:
+		typedef const boost::shared_ptr<const T> Value;
 		Cacheable(const K & k, time_t validUntil);
 
 		const K key;
 		const time_t validUntil;
 
-		virtual const T & item() const = 0;
+		virtual Value item() const = 0;
 };
 
 template <typename T, typename K>
 class DLL_PUBLIC ObjectCacheable : public Cacheable<T, K> {
 	public:
 		ObjectCacheable(const T & t, const K & k, time_t validUtil);
+		ObjectCacheable(typename Cacheable<T, K>::Value & t, const K & k, time_t validUtil);
 
-		virtual const T & item() const override;
+		virtual typename Cacheable<T, K>::Value item() const override;
 
 	private:
-		const T value;
+		typename Cacheable<T, K>::Value value;
 };
 
 template <typename T, typename K>
 class DLL_PUBLIC CallCacheable : public Cacheable<T, K> {
 	public:
-		CallCacheable(const T & t, const K & k, time_t validUtil);
-		CallCacheable(const boost::function<T()> & t, const K & k, time_t validUtil);
+		typedef boost::function<T()> Factory;
+		CallCacheable(const Factory & t, const K & k, time_t validUtil);
 
-		virtual const T & item() const override;
+		virtual typename Cacheable<T, K>::Value item() const override;
 
 	private:
-		mutable boost::variant<T, boost::function<T()>> value;
+		mutable boost::variant<boost::shared_ptr<const T>, Factory> value;
+		mutable boost::shared_mutex lock;
+};
+
+template <typename T, typename K>
+class DLL_PUBLIC PointerCallCacheable : public Cacheable<T, K> {
+	public:
+		typedef boost::function<typename Cacheable<T, K>::Value()> Factory;
+		PointerCallCacheable(const Factory & t, const K & k, time_t validUtil);
+
+		virtual typename Cacheable<T, K>::Value item() const override;
+
+	private:
+		mutable boost::variant<boost::shared_ptr<const T>, Factory> value;
 		mutable boost::shared_mutex lock;
 };
 
@@ -59,7 +74,9 @@ class DLL_PUBLIC Cache {
 	public:
 		/// @cond
 		typedef K Key;
-		typedef T Value;
+		typedef const boost::shared_ptr<const T> Value;
+		typedef boost::function<T()> Factory;
+		typedef boost::function<Value()> PointerFactory;
 		typedef Cacheable<T, K> Item;
 		typedef boost::shared_ptr<Item> Element;
 		/// @endcond
@@ -73,6 +90,12 @@ class DLL_PUBLIC Cache {
 		 * @param validUntil The absolute time the cache item should expire.
 		 */
 		void add(const K & k, const T & t, time_t validUntil);
+		/** Add a known item to the cache.
+		 * @param k The key of the cache item.
+		 * @param t The item to cache.
+		 * @param validUntil The absolute time the cache item should expire.
+		 */
+		void add(const K & k, Value & t, time_t validUntil);
 		/** Add a callback item to the cache.
 		 * The callback will be called on first hit of the cache item, at which
 		 * point the return value of the function will be cached.
@@ -80,14 +103,22 @@ class DLL_PUBLIC Cache {
 		 * @param tf The callback function to cache.
 		 * @param validUntil The absolute time the cache item should expire.
 		 */
-		void add(const K & k, const boost::function<T()> & tf, time_t validUntil);
+		void add(const K & k, const Factory & tf, time_t validUntil);
+		/** Add a pointer callback item to the cache.
+		 * The callback will be called on first hit of the cache item, at which
+		 * point the return value of the function will be cached.
+		 * @param k The key of the cache item.
+		 * @param tf The callback function to cache.
+		 * @param validUntil The absolute time the cache item should expire.
+		 */
+		void add(const K & k, const PointerFactory & tf, time_t validUntil);
 		/** Get an Element from the cache. The element represents the key, item and expiry time.
 		 * Returns null on cache-miss.
 		 * @param k Cache key to get. */
 		Element getItem(const K & k) const;
 		/** Get an Item from the cache. Returns null on cache-miss.
 		 * @param k Cache key to get. */
-		const T * get(const K & k) const;
+		Value get(const K & k) const;
 		/** Get the size of the cache (number of items). @warning This cannot be reliably used to
 		 * determine or estimate the amount of memory used by items in the cache without further
 		 * knowledge of the items themselves. */
diff --git a/libadhocutil/cache.impl.h b/libadhocutil/cache.impl.h
index 0a2b0b3..dffed7f 100644
--- a/libadhocutil/cache.impl.h
+++ b/libadhocutil/cache.impl.h
@@ -17,44 +17,64 @@ Cacheable<T, K>::Cacheable(const K & k, time_t vu) :
 
 template<typename T, typename K>
 ObjectCacheable<T, K>::ObjectCacheable(const T & t, const K & k, time_t vu) :
+	Cacheable<T, K>(k, vu),
+	value(new T(t))
+{
+}
+
+template<typename T, typename K>
+ObjectCacheable<T, K>::ObjectCacheable(typename Cacheable<T, K>::Value & t, const K & k, time_t vu) :
 	Cacheable<T, K>(k, vu),
 	value(t)
 {
 }
 
 template<typename T, typename K>
-const T &
+typename Cacheable<T, K>::Value
 ObjectCacheable<T, K>::item() const
 {
 	return value;
 }
 
 template<typename T, typename K>
-CallCacheable<T, K>::CallCacheable(const T & t, const K & k, time_t vu) :
+CallCacheable<T, K>::CallCacheable(const Factory & t, const K & k, time_t vu) :
 	Cacheable<T, K>(k, vu),
 	value(t)
 {
 }
 
 template<typename T, typename K>
-CallCacheable<T, K>::CallCacheable(const boost::function<T()> & t, const K & k, time_t vu) :
+typename Cacheable<T, K>::Value
+CallCacheable<T, K>::item() const
+{
+	Lock(lock);
+	if (auto t = boost::get<typename Cacheable<T, K>::Value>(&value)) {
+		return *t;
+	}
+	const Factory & f = boost::get<Factory>(value);
+	value = typename Cacheable<T, K>::Value(new T(f()));
+	return boost::get<typename Cacheable<T, K>::Value>(value);
+}
+
+
+template<typename T, typename K>
+PointerCallCacheable<T, K>::PointerCallCacheable(const Factory & t, const K & k, time_t vu) :
 	Cacheable<T, K>(k, vu),
 	value(t)
 {
 }
 
 template<typename T, typename K>
-const T &
-CallCacheable<T, K>::item() const
+typename Cacheable<T, K>::Value
+PointerCallCacheable<T, K>::item() const
 {
 	Lock(lock);
-	const T * t = boost::get<T>(&value);
-	if (t) {
+	if (auto t = boost::get<typename Cacheable<T, K>::Value>(&value)) {
 		return *t;
 	}
-	const boost::function<T()> & f = boost::get<boost::function<T()>>(value);
+	const Factory & f = boost::get<Factory>(value);
 	value = f();
-	return boost::get<T>(value);
+	return boost::get<typename Cacheable<T, K>::Value>(value);
 }
 
 
@@ -74,12 +94,28 @@ Cache<T, K>::add(const K & k, const T & t, time_t validUntil)
 
 template<typename T, typename K>
 void
-Cache<T, K>::add(const K & k, const boost::function<T()> & tf, time_t validUntil)
+Cache<T, K>::add(const K & k, Value & t, time_t validUntil)
+{
+	Lock(lock);
+	cached.insert(Element(new ObjectCacheable<T, K>(t, k, validUntil)));
+}
+
+template<typename T, typename K>
+void
+Cache<T, K>::add(const K & k, const Factory & tf, time_t validUntil)
 {
 	Lock(lock);
 	cached.insert(Element(new CallCacheable<T, K>(tf, k, validUntil)));
 }
 
+template<typename T, typename K>
+void
+Cache<T, K>::add(const K & k, const PointerFactory & tf, time_t validUntil)
+{
+	Lock(lock);
+	cached.insert(Element(new PointerCallCacheable<T, K>(tf, k, validUntil)));
+}
+
 template<typename T, typename K>
 typename Cache<T, K>::Element
 Cache<T, K>::getItem(const K & k) const
@@ -100,12 +136,12 @@ Cache<T, K>::getItem(const K & k) const
 }
 
 template<typename T, typename K>
-const T *
+typename Cache<T, K>::Value
 Cache<T, K>::get(const K & k) const
 {
 	auto i = getItem(k);
 	if (i) {
-		return &i->item();
+		return i->item();
 	}
 	return nullptr;
 }
diff --git a/libadhocutil/unittests/testCache.cpp b/libadhocutil/unittests/testCache.cpp
index 914d282..89ae556 100644
--- a/libadhocutil/unittests/testCache.cpp
+++ b/libadhocutil/unittests/testCache.cpp
@@ -12,6 +12,7 @@ BOOST_TEST_DONT_PRINT_LOG_VALUE(std::nullptr_t);
 class Obj {
 	public:
 		Obj(int i) : v(i) { }
+		void operator=(const Obj &) = delete;
 		bool operator==(const int & i) const {
 			return v == i;
 		}
@@ -37,6 +38,7 @@ namespace AdHoc {
 	template class Cacheable<Obj, std::string>;
 	template class ObjectCacheable<Obj, std::string>;
 	template class CallCacheable<Obj, std::string>;
+	template class PointerCallCacheable<Obj, std::string>;
 }
 
 using namespace AdHoc;
@@ -59,7 +61,7 @@ BOOST_AUTO_TEST_CASE( hit )
 	tc.add("key", 3, vu);
 	BOOST_REQUIRE_EQUAL(1, tc.size());
 	BOOST_REQUIRE_EQUAL(3, *tc.get("key"));
-	BOOST_REQUIRE_EQUAL(3, tc.getItem("key")->item());
+	BOOST_REQUIRE_EQUAL(3, *tc.getItem("key")->item());
 	BOOST_REQUIRE_EQUAL(vu, tc.getItem("key")->validUntil);
 	BOOST_REQUIRE_EQUAL("key", tc.getItem("key")->key);
 	BOOST_REQUIRE_EQUAL(1, tc.size());
@@ -114,7 +116,7 @@ BOOST_AUTO_TEST_CASE( callcache )
 	int callCount = 0;
 	auto vu = time(NULL) + 5;
 	BOOST_REQUIRE_EQUAL(nullptr, tc.get("key"));
-	tc.add("key", [&callCount]{ callCount++; return 3; }, vu);
+	tc.add("key", TestCache::Factory([&callCount]{ callCount++; return 3; }), vu);
 	BOOST_REQUIRE_EQUAL(0, callCount);
 	BOOST_REQUIRE_EQUAL(3, *tc.get("key"));
 	BOOST_REQUIRE_EQUAL(1, callCount);
@@ -122,3 +124,39 @@ BOOST_AUTO_TEST_CASE( callcache )
 	BOOST_REQUIRE_EQUAL(1, callCount);
 }
 
+BOOST_AUTO_TEST_CASE( pointercallcache )
+{
+	TestCache tc;
+	int callCount = 0;
+	auto vu = time(NULL) + 5;
+	BOOST_REQUIRE_EQUAL(nullptr, tc.get("key"));
+	tc.add("key", TestCache::PointerFactory([&callCount]{ callCount++; return TestCache::Value(new Obj(3)); }), vu);
+	BOOST_REQUIRE_EQUAL(0, callCount);
+	BOOST_REQUIRE_EQUAL(3, *tc.get("key"));
+	BOOST_REQUIRE_EQUAL(1, callCount);
+	BOOST_REQUIRE_EQUAL(3, *tc.get("key"));
+	BOOST_REQUIRE_EQUAL(1, callCount);
+}
+
+BOOST_AUTO_TEST_CASE( hitThenRenove )
+{
+	TestCache tc;
+	tc.add("key", 3, time(NULL) + 5);
+	auto h = tc.get("key");
+	BOOST_REQUIRE(h);
+	BOOST_REQUIRE_EQUAL(3, *h);
+	tc.remove("key");
+	BOOST_REQUIRE(!tc.get("key"));
+	BOOST_REQUIRE_EQUAL(3, *h);
+}
+
+BOOST_AUTO_TEST_CASE( addPointer )
+{
+	TestCache tc;
+	auto v = TestCache::Value(new Obj(3));
+	tc.add("key", v, time(NULL) + 1);
+	auto h = tc.get("key");
+	BOOST_REQUIRE(h);
+	BOOST_REQUIRE_EQUAL(3, *h);
+}
+
-- 
cgit v1.2.3