diff options
-rw-r--r-- | lib/worker.cpp | 3 | ||||
-rw-r--r-- | lib/worker.h | 87 | ||||
-rw-r--r-- | test/test-worker.cpp | 39 |
3 files changed, 116 insertions, 13 deletions
diff --git a/lib/worker.cpp b/lib/worker.cpp index 4f1352d..7e7f296 100644 --- a/lib/worker.cpp +++ b/lib/worker.cpp @@ -1,5 +1,4 @@ #include "worker.h" -#include "work.h" #include <algorithm> #include <iterator> #include <mutex> @@ -19,7 +18,7 @@ Worker::~Worker() } void -Worker::addWork(WorkPtr j) +Worker::addWorkPtr(WorkPtr j) { std::lock_guard<std::mutex> lck {todoMutex}; todoLen.release(); diff --git a/lib/worker.h b/lib/worker.h index 96593d9..5356606 100644 --- a/lib/worker.h +++ b/lib/worker.h @@ -1,6 +1,8 @@ #pragma once #include <deque> +#include <functional> +#include <future> #include <memory> #include <mutex> #include <semaphore> @@ -9,28 +11,93 @@ #include <utility> #include <vector> -class Work; class Worker { +public: + class WorkItem { + public: + WorkItem() = default; + virtual ~WorkItem() = default; + NO_MOVE(WorkItem); + NO_COPY(WorkItem); + + virtual void doWork() = 0; + }; + + template<typename T> class WorkItemT : public WorkItem { + public: + T + get() + { + return future.get(); + } + + protected: + std::promise<T> promise; + std::future<T> future {promise.get_future()}; + friend Worker; + }; + + template<typename... Params> + static auto + addWork(Params &&... params) + { + return instance.addWorkImpl(std::forward<Params>(params)...); + } + template<typename T> using WorkPtrT = std::shared_ptr<WorkItemT<T>>; + private: + template<typename T, typename... Params> class WorkItemTImpl : public WorkItemT<T> { + public: + WorkItemTImpl(Params &&... params) : params {std::forward<Params>(params)...} { } + + private: + void + doWork() override + { + try { + if constexpr (std::is_void_v<T>) { + std::apply( + [](auto &&... p) { + return std::invoke(p...); + }, + params); + WorkItemT<T>::promise.set_value(); + } + else { + WorkItemT<T>::promise.set_value(std::apply( + [](auto &&... p) { + return std::invoke(p...); + }, + params)); + } + } + catch (...) { + WorkItemT<T>::promise.set_exception(std::current_exception()); + } + } + + std::tuple<Params...> params; + }; + Worker(); ~Worker(); NO_COPY(Worker); NO_MOVE(Worker); - using WorkPtr = std::unique_ptr<Work>; + using WorkPtr = std::shared_ptr<WorkItem>; - template<typename T, typename... Params> - void - addWork(Params &&... params) - requires std::is_base_of_v<Work, T> + template<typename... Params> + auto + addWorkImpl(Params &&... params) { - addWork(std::make_unique<T>(std::forward<Params>(params)...)); + using T = decltype(std::invoke(std::forward<Params>(params)...)); + auto work = std::make_shared<WorkItemTImpl<T, Params...>>(std::forward<Params>(params)...); + addWorkPtr(work); + return work; } - void addWork(WorkPtr w); - -private: + void addWorkPtr(WorkPtr w); void worker(); using Threads = std::vector<std::jthread>; diff --git a/test/test-worker.cpp b/test/test-worker.cpp index 3c5ed7e..c542020 100644 --- a/test/test-worker.cpp +++ b/test/test-worker.cpp @@ -2,6 +2,43 @@ #include "testHelpers.h" #include <boost/test/unit_test.hpp> +#include <set> #include <stream_support.hpp> +#include <worker.h> -BOOST_AUTO_TEST_CASE(exists) { } +uint32_t +workCounter() +{ + static std::atomic_uint32_t n; + usleep(1000); + return n++; +} + +BOOST_AUTO_TEST_CASE(basic_slow_counter) +{ + std::vector<Worker::WorkPtrT<uint32_t>> ps; + for (int i {}; i < 30; ++i) { + ps.push_back(Worker::addWork(workCounter)); + } + std::set<uint32_t> out; + std::transform(ps.begin(), ps.end(), std::inserter(out, out.end()), [](auto && p) { + return p->get(); + }); + BOOST_REQUIRE_EQUAL(out.size(), ps.size()); + BOOST_CHECK_EQUAL(*out.begin(), 0); + BOOST_CHECK_EQUAL(*out.rbegin(), ps.size() - 1); +} + +BOOST_AUTO_TEST_CASE(basic_error_handler) +{ + auto workitem = Worker::addWork([]() { + throw std::runtime_error {"test"}; + }); + BOOST_CHECK_THROW(workitem->get(), std::runtime_error); +} + +BOOST_AUTO_TEST_CASE(basic_void_work) +{ + auto workitem = Worker::addWork([]() {}); + BOOST_CHECK_NO_THROW(workitem->get()); +} |