diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a690cb..3851b55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,9 +6,7 @@ set(CMAKE_CXX_STANDARD 23) add_executable(ThreadPool demo/main.cpp src/threadpool.h src/threadpool.cpp - tests/return_value_tests.h tests/all_tests.h - tests/task_tests.h tests/threadpool_tests.h demo/dependency_demo.h demo/demo.h diff --git a/demo/demo.h b/demo/demo.h index 1be821e..fcbfd9d 100644 --- a/demo/demo.h +++ b/demo/demo.h @@ -17,17 +17,16 @@ inline void fibonacci_example() { threadpool tp(3); - std::vector> futures = { - tp.submit( []() -> int { return recursive_fibonacci(20);} ), - tp.submit( []() -> int { return recursive_fibonacci(30);} ), - tp.submit( []() -> int { return recursive_fibonacci(40);} ), - }; - + std::vector> futures; + futures.reserve(3); + futures.emplace_back( tp.submit( []() -> int { return recursive_fibonacci(10);})); + futures.emplace_back( tp.submit( []() -> int { return recursive_fibonacci(20);})); + futures.emplace_back( tp.submit( []() -> int { return recursive_fibonacci(30);})); tp.shutdown(); for (int i=0; i opt_task = this->poll_task(); - if (opt_task.has_value()) { - auto& task = opt_task.value(); + + if (!tasks.empty()) { + auto task = std::move(tasks.front()); + tasks.pop(); lock.unlock(); task(); } @@ -30,6 +31,9 @@ threadpool::threadpool(const int& n) { } } + + + threadpool::~threadpool() { std::unique_lock lock(queue_stop_mutex); if (!(m_Stop)) { @@ -62,19 +66,7 @@ void threadpool::shutdown_now() { } [[nodiscard]] -int threadpool::queue_size() { +size_t threadpool::queue_size() { std::lock_guard lock(queue_stop_mutex); return tasks.size(); } - -// ============ THREADPOOL PRIVATE ============ - -std::optional threadpool::poll_task() { - // No lock guard as the thread would already have the guard - if (tasks.empty()) { - return std::nullopt; - } - task front = tasks.front(); - tasks.pop(); - return front; -} \ No newline at end of file diff --git a/src/threadpool.h b/src/threadpool.h index 20649e7..3156b46 100644 --- a/src/threadpool.h +++ b/src/threadpool.h @@ -6,293 +6,82 @@ #include #include #include - -// Forward Declarations -template -struct return_value_handle; -class threadpool; - - -// Wrapper for the function pointers -struct task { - task(const std::function& ptr) : m_Ptr {ptr} { - }; - - void operator()() const { - m_Ptr(); - } -private: - std::function m_Ptr; -}; - -// Simplified implementation of std::future -// return_value is shared state via the shared pointer, do not allow copy / move semantics -template -struct return_value { - - friend class return_value_handle; // Allow access to set_value and set_valid - - return_value() = default; - - return_value(const return_value&) = delete; // Copy Constructor - return_value& operator=(const return_value&) = delete; // Copy Assignment Constructor - - return_value(return_value&&) = delete; // Move Constructor - return_value& operator=(return_value&&) = delete; // Move Assignment Constructor - - bool is_valid() { - std::unique_lock lock(access_mutex); - return m_IsValid; - } - - T get() { - std::unique_lock lock(access_mutex); - if (!is_valid_unsafe()) { // unsafe to prevent deadlock as we already acquired mutex - throw std::runtime_error{"Thread Return Value is Invalid!"}; - } - m_IsValid = false; - return m_Value; - } - -private: - std::mutex access_mutex; - T m_Value; - bool m_IsValid{false}; - - bool is_valid_unsafe() { - return m_IsValid; - } - - void set_value(const T& value) { - std::unique_lock lock(access_mutex); - m_Value = value; - m_IsValid = true; - } -}; - -// Specialization of void -template<> -struct return_value { - - friend class return_value_handle; // Allow access to set_value and set_valid - - return_value() = default; - - return_value(const return_value&) = delete; // Copy Constructor - return_value& operator=(const return_value&) = delete; // Copy Assignment Constructor - - return_value(return_value&&) = delete; // Move Constructor - return_value& operator=(return_value&&) = delete; // Move Assignment Constructor - - bool is_valid() { - std::unique_lock lock(access_mutex); - return m_IsValid; - } - - void get() { - std::unique_lock lock(access_mutex); - if (!is_valid_unsafe()) { - lock.unlock(); - throw std::runtime_error{"Thread Return Value is Invalid!"}; - } - // nothing to return - } - -private: - std::mutex access_mutex; - bool m_IsValid{false}; - - bool is_valid_unsafe() { - return m_IsValid; - } - - void set_value() { - std::unique_lock lock(access_mutex); - m_IsValid = false; // Never true with - // Do nothing - } -}; - - -template -struct return_value_handle { - - friend class threadpool; - -public: - return_value_handle() : m_Handle{std::make_shared>()} { - } - - return_value_handle(const return_value_handle&) = default; // Copy Constructor - return_value_handle& operator=(const return_value_handle&) = default; // Copy Assignment Constructor - - return_value_handle(return_value_handle&&) = default; // Move Constructor - return_value_handle& operator=(return_value_handle&&) = default; // Move Assignment Constructor - - bool is_valid() const { - if (m_Handle == nullptr) - return false; - - return m_Handle -> is_valid(); - } - - T get() const { - return m_Handle.get()->get(); - } - - // Dependency DAG APIs - template - return_value_handle then(threadpool& tp, Args... args) { - // Todo: actually implement the logic - return_value_handle rv{}; - - return rv; - } - -private: - std::shared_ptr> m_Handle; - void set_value(const T& value) { - m_Handle->set_value(value); - } - void set_valid(const bool& value) { - m_Handle->set_valid(value); - } -}; - -// Specialization of void -template<> -struct return_value_handle { - - friend class threadpool; - -public: - return_value_handle() : m_Handle{std::make_shared>()} { - } - - return_value_handle(const return_value_handle&) = default; // Copy Constructor - return_value_handle& operator=(const return_value_handle&) = default; // Copy Assignment Constructor - - return_value_handle(return_value_handle&&) = default; // Move Constructor - return_value_handle& operator=(return_value_handle&&) = default; // Move Assignment Constructor - - bool is_valid() const { - if (m_Handle == nullptr) - return false; - - return m_Handle -> is_valid(); - } - - void get() const { - return m_Handle.get() -> get(); - } - - // Dependency DAG APIs - template - return_value_handle then(threadpool& tp, Args... args) { - // Todo: actually implement the logic - return_value_handle rv{}; - - return rv; - } - -private: - std::shared_ptr> m_Handle; - static void set_value() { - // Do nothing - } -}; - - - - - - - - - +#include class threadpool { private: - std::queue tasks; + std::queue> tasks; std::vector workers; bool m_Stop{false}; std::mutex queue_stop_mutex; // Used for queue operations and read/write m_Stop operations std::condition_variable cv; - - std::optional poll_task(); - void write_task(const std::function& ptr) { - // No lock guard as submit() already contains the lock - tasks.push( task{ptr} ); + void write_task(const std::function& fn) { + // Does not need the lock as submit already acquires it + tasks.push(fn); cv.notify_one(); } public: - threadpool(const int& threads); + explicit threadpool(const int& threads); ~threadpool(); - template + + template [[nodiscard]] - return_value_handle submit(const std::function& ptr) { + auto submit(Function &&F, Args &&...ArgList) { + std::unique_lock lock(queue_stop_mutex); if (m_Stop) { - lock.unlock(); throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; } - return_value_handle rv_handle{}; - write_task( - [ptr, rv_handle] () mutable { - rv_handle.set_value(ptr()); - } - ); - return rv_handle; - } + using ReturnType = std::invoke_result_t; - template - [[nodiscard]] - return_value_handle submit(const std::function& ptr, int dependency_id) { - return_value_handle rv_handle{}; - return rv_handle; + std::shared_ptr> task = std::make_shared>(( + std::bind(std::forward(F), + std::forward(ArgList)...) + )); + + auto future = task->get_future(); + + write_task([task](){ (*task)(); }); + + return future; // Return type is future } + // template + // [[nodiscard]] + // auto submit(const Fn&& fn) { + // using return_type = std::invoke_result_t; + // + // std::unique_lock lock(queue_stop_mutex); + // if (m_Stop) { + // throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; + // } + // + // std::packaged_task task{fn}; + // write_task([&task]() { task(); }); + // return task.get_future(); + // } + void shutdown(); // finish queued tasks void shutdown_now(); // cancel pending tasks [[nodiscard]] - int queue_size(); + size_t queue_size(); // Dependency DAG API - template - return_value_handle when_all(Args... args) { - // Todo: actually implement the logic - return_value_handle rv{}; - return rv; - } - - -}; - -// Void specialization -template<> -inline return_value_handle threadpool::submit(const std::function& ptr) { - - std::unique_lock lock(queue_stop_mutex); - if (m_Stop) { - lock.unlock(); - throw std::runtime_error{"ThreadPool::submit() after shutdown called"}; - } + // template + // return_value_handle when_all(Args... args) { + // // Todo: actually implement the logic + // return_value_handle rv{}; + // return rv; + // } - return_value_handle rv_handle{}; - write_task( - [ptr] (){ - ptr(); - } - ); - return rv_handle; -} \ No newline at end of file +}; \ No newline at end of file diff --git a/tests/all_tests.h b/tests/all_tests.h index 81ccdd2..ca51715 100644 --- a/tests/all_tests.h +++ b/tests/all_tests.h @@ -1,18 +1,14 @@ #pragma once -#include "return_value_tests.h" -#include "task_tests.h" #include "threadpool_tests.h" inline void all_tests() { // Use the threadpools lol - threadpool tp(3); + threadpool tp(1); - auto rv1 = tp.submit(task_tests); - auto rv2 = tp.submit(return_value_tests); - auto rv3 = tp.submit(threadpool_tests); + auto rv3 = tp.submit(threadpool_tests); tp.shutdown(); } \ No newline at end of file diff --git a/tests/return_value_tests.h b/tests/return_value_tests.h deleted file mode 100644 index af82ac3..0000000 --- a/tests/return_value_tests.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include "../src/threadpool.h" -#include -#include - -inline void return_value_tests() { - - // Handle should not be valid after creating it immediately - { - return_value_handle rv_handle1{}; - assert(!rv_handle1.is_valid()); - return_value_handle rv_handle2{}; - assert(!rv_handle2.is_valid()); - } - - // Catch exception when get is invalid - { - return_value_handle rv_handle1{}; - assert(!rv_handle1.is_valid()); - try { - auto val = rv_handle1.get(); - assert(false); - } catch (std::runtime_error& e) { - - } - assert(!rv_handle1.is_valid()); - - - return_value_handle rv_handle2{}; - assert(!rv_handle2.is_valid()); - try { - rv_handle2.get(); - assert(false); - } catch (std::runtime_error& e) { - - } - assert(!rv_handle2.is_valid()); - } - - std::cout << "return_value & return_value_handle tests passed!\n"; -} diff --git a/tests/task_tests.h b/tests/task_tests.h deleted file mode 100644 index e04b5f8..0000000 --- a/tests/task_tests.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include "../src/threadpool.h" -#include -#include - -inline void task_tests() { - - // Task operator() overload - { - int num{0}; - std::function ptr{ - [&num]() mutable { num += 5; } - }; - task t{ptr}; - t(); - - assert(num == 5); - } - - std::cout << "task test passed! \n"; -} \ No newline at end of file diff --git a/tests/threadpool_tests.h b/tests/threadpool_tests.h index 57ac7af..8f33fb6 100644 --- a/tests/threadpool_tests.h +++ b/tests/threadpool_tests.h @@ -6,69 +6,102 @@ #include #include + +inline std::string string_test() { + return "Hello world!"; +} + +inline int int_test(int input1, int input2) { + return input1 + input2; +} + + + + inline void threadpool_tests() { - // Ensure shutdown finishes all remaining tasks + // Submit syntax { threadpool tp{1}; - int i{0}; - auto f1 = []() { std::this_thread::sleep_for(std::chrono::milliseconds(100));}; - auto f2 = [&i]() mutable{ i = 5; }; + auto future = tp.submit([]() {}); + tp.shutdown(); + } - auto rv1 = tp.submit(f1); - auto rv2 = tp.submit(f2); + // Verify work is done on a submit + { + threadpool tp{1}; + int i = 0; + auto future = tp.submit([&i](){i = 42;}); tp.shutdown(); + assert(i == 42); + } - assert(i == 5); + // Return type syntax + { + threadpool tp{1}; + auto future = tp.submit([]() {return 42;}); + tp.shutdown(); + int work = future.get(); + assert(work == 42); } - // Ensure shutdown_now clears the remaining tasks + // Variadic arguments works { threadpool tp{1}; - int i{0}; - auto f1 = []() { std::this_thread::sleep_for(std::chrono::milliseconds(100));}; - auto f2 = [&i]() mutable{ i = 5; }; + auto future = tp.submit([](int num1, int num2, int num3) {return num1 + num2 + num3;}, 1, 2, 3); + tp.shutdown(); + assert(future.valid() && future.get() == 6); + } - auto rv1 = tp.submit(f1); - auto rv2 = tp.submit(f2); - tp.shutdown_now(); + // Function pointer works + { + threadpool tp{2}; + auto future1 = tp.submit(string_test); + auto future2 = tp.submit(int_test, 1, 2); + tp.shutdown(); + assert(future1.valid() && future1.get() == "Hello world!"); + assert(future2.valid() && future2.get() == 3); - assert(i == 0); } + // Function pointers with variadic arguments + { + threadpool tp{1}; + //std::function f1 = []() -> int {return 5;}; + std::future future = tp.submit([]() {return 5;}); + tp.shutdown(); + assert(future.get() == 5); + } - // Submit after shutdown + // Ensure shutdown finishes all remaining tasks { threadpool tp{1}; + int i{0}; + auto f1 = []() { std::this_thread::sleep_for(std::chrono::milliseconds(100));}; + auto f2 = [&i]() mutable{ i = 5; }; + + auto rv1 = tp.submit(f1); + auto rv2 = tp.submit(f2); tp.shutdown(); - try { - auto rv = tp.submit([]() {}); - assert(false); - } catch (std::runtime_error) { - } + assert(i == 5); } - // Submit after shutdown_now + + // Submit after shutdown throws { threadpool tp{1}; - tp.shutdown_now(); + auto future = tp.submit([]() {return 42;}); + tp.shutdown(); + try { - auto rv = tp.submit([]() {}); + auto rv = tp.submit([]() {}); assert(false); } catch (std::runtime_error) { } } - // Destructor stress tests - { - for (int i = 0; i < 10'000; ++i) { - threadpool tp{4}; - auto rv = tp.submit([]{}); - } - } - // Nested submission /* The invariant here is a little more subtle @@ -79,10 +112,10 @@ inline void threadpool_tests() { */ { threadpool tp{1}; - auto rv1 = tp.submit([&]{ + auto rv1 = tp.submit([&]{ try { - auto rv2 = tp.submit([](){ /* work */ }); + auto rv2 = tp.submit([](){ /* work */ }); assert(false); } catch (std::runtime_error) { @@ -91,46 +124,27 @@ inline void threadpool_tests() { tp.shutdown(); } - - // Shutdown now test + // get() called after shutdown_now called { threadpool tp{1}; - auto rv1 = tp.submit([]{ std::this_thread::sleep_for(std::chrono::milliseconds(50));}); - auto rv2 = tp.submit( []() { }); + auto rv1 = tp.submit([]{ std::this_thread::sleep_for(std::chrono::milliseconds(1000)); return 5; }); + auto rv2 = tp.submit([]{ std::this_thread::sleep_for(std::chrono::milliseconds(500)); return 55; }); + auto rv3 = tp.submit( []() { return 55;}); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // Wait for thread to pick up the task + std::this_thread::sleep_for(std::chrono::milliseconds(1)); tp.shutdown_now(); - // Not sure exactly why, but this fails - int i; - //assert(rv1.is_valid()); - assert(!rv2.is_valid()); + assert(rv1.valid() && rv1.get() == 5); + try { + auto error = rv3.get(); + assert(false); + } catch (const std::future_error& e) { + + } + } - // // Then() syntax - // { - // threadpool tp{1}; - // - // auto rv_1 = tp.submit([](){ return 5; }); - // - // auto rv_2 = rv_1.then(tp, []() { return 10; }); - // - // tp.shutdown(); - // assert(rv_1.is_valid()); - // assert(rv_1.get() == 5); - // assert(rv_2.is_valid()); - // assert(rv_2.get() == 10); - // } - // - // // Then() actually waits for dependencies - // { - // threadpool tp{5}; - // auto rv_1 = tp.submit([](){ std::this_thread::sleep_for(std::chrono::milliseconds(10));}); - // auto rv_2 = rv_1.then(tp, []() { return 10; }); - // assert(!rv_1.is_valid() && !rv_2.is_valid()); - // tp.shutdown(); - // } std::cout << "threadpool tests passed!\n";