/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#pragma once

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>

#define MAX_THREADS_NUM 32

namespace zilliz {
namespace vecwise {
namespace server {

/// Fixed set of worker threads consuming a bounded FIFO queue of tasks.
///
/// enqueue() blocks while the queue already holds `max_queue_size` pending
/// tasks, so producers get back-pressure instead of an unbounded queue.
/// The destructor lets the workers finish every task already queued, then
/// joins them.
class ThreadPool {
 public:
    /// Start `threads` worker threads.
    /// @param threads     number of workers to start. With 0 workers, queued
    ///                    tasks are never executed.
    /// @param queue_size  maximum number of pending (not yet started) tasks.
    ///                    0 is treated as 1, because a zero-capacity queue
    ///                    would make every enqueue() block forever.
    /// @throws whatever std::thread's constructor throws (std::system_error)
    ///         if a worker cannot be started. Workers already started are
    ///         stopped and joined before the exception propagates.
    ThreadPool(size_t threads, size_t queue_size = 1000);

    // Owns threads and synchronization primitives; not copyable.
    ThreadPool(const ThreadPool &) = delete;
    ThreadPool &operator=(const ThreadPool &) = delete;

    /// Queue `f(args...)` for execution on a worker thread.
    /// Blocks while the queue is full.
    /// @return a future that receives the call's result, or the exception
    ///         the call threw.
    /// @throws std::runtime_error if the pool is stopping.
    template<class F, class... Args>
    auto enqueue(F &&f, Args &&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>;

    /// Stop accepting work, let workers drain the queue, and join them.
    ~ThreadPool();

 private:
    /// Body run by every worker: pop and run tasks until stopped and drained.
    void worker_loop();

    /// Set `stop`, wake all waiters, and join every started worker.
    void stop_and_join();

    // need to keep track of threads so we can join them
    std::vector<std::thread> workers;
    // the task queue
    std::queue<std::function<void()> > tasks;
    size_t max_queue_size;

    // synchronization: one condition variable is shared by idle workers
    // (waiting for tasks) and blocked producers (waiting for free space)
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};

inline ThreadPool::ThreadPool(size_t threads, size_t queue_size)
    : max_queue_size(queue_size == 0 ? 1 : queue_size), stop(false) {
    workers.reserve(threads);
    try {
        for (size_t i = 0; i < threads; ++i) {
            workers.emplace_back([this] { worker_loop(); });
        }
    } catch (...) {
        // The destructor does not run for a partially-constructed object, and
        // destroying a joinable std::thread calls std::terminate, so shut down
        // the workers that did start before propagating the error.
        stop_and_join();
        throw;
    }
}

inline void ThreadPool::worker_loop() {
    for (;;) {
        std::function<void()> task;
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            condition.wait(lock, [this] { return stop || !tasks.empty(); });
            // Keep draining after stop; exit only once nothing is left.
            if (stop && tasks.empty()) {
                return;
            }
            task = std::move(tasks.front());
            tasks.pop();
        }
        // A queue slot was just freed. Producers blocked in enqueue() share
        // this condition variable, so notify_all (notify_one might only wake
        // another idle worker).
        condition.notify_all();
        task();
    }
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F &&f, Args &&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
    using return_type = typename std::result_of<F(Args...)>::type;

    // packaged_task captures the result (or exception) into the future, so
    // the queued wrapper itself never throws inside a worker.
    auto task = std::make_shared<std::packaged_task<return_type()> >(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        // Wake on stop too: otherwise a producer blocked on a full queue
        // would not observe shutdown until some worker freed a slot, which
        // never happens when there are no workers.
        condition.wait(lock, [this] {
            return stop || tasks.size() < max_queue_size;
        });

        // don't allow enqueueing after stopping the pool
        if (stop) {
            throw std::runtime_error("enqueue on stopped ThreadPool");
        }

        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_all();
    return res;
}

inline void ThreadPool::stop_and_join() {
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread &worker : workers) {
        worker.join();
    }
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool() {
    stop_and_join();
}

}
}
}