// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cmath>
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>

#include "SafeQueue.h"
#include "log/Log.h"

namespace milvus {

const int DEFAULT_CPU_NUM = 1;
const int64_t DEFAULT_HIGH_PRIORITY_THREAD_CORE_COEFFICIENT = 10;
const int64_t DEFAULT_MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT = 5;
const int64_t DEFAULT_LOW_PRIORITY_THREAD_CORE_COEFFICIENT = 1;

extern std::atomic<float> HIGH_PRIORITY_THREAD_CORE_COEFFICIENT;
extern std::atomic<float> MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT;
extern std::atomic<float> LOW_PRIORITY_THREAD_CORE_COEFFICIENT;

extern int CPU_NUM;

void
SetHighPriorityThreadCoreCoefficient(const float coefficient);

void
SetMiddlePriorityThreadCoreCoefficient(const float coefficient);

void
SetLowPriorityThreadCoreCoefficient(const float coefficient);

void
InitCpuNum(const int core);

class ThreadPool {
 public:
    explicit ThreadPool(const float thread_core_coefficient, std::string name)
        : shutdown_(false), name_(std::move(name)) {
        idle_threads_size_ = 0;
        current_threads_size_ = 0;
        min_threads_size_ = 1;
        max_threads_size_.store(std::max(
            1,
            static_cast<int>(std::round(CPU_NUM * thread_core_coefficient))));
        // Only the IO pool sets a large limit, but extra CPU threads do not
        // help IO operations, so cap the max thread num: each thread downloads
        // 16~64 MiB of data, and according to our benchmark 16 threads are
        // enough to saturate the network bandwidth.
        if (max_threads_size_.load() > 16) {
            max_threads_size_.store(16);
        }
        LOG_INFO("Init thread pool:{}", name_)
            << " with min worker num:" << min_threads_size_
            << " and max worker num:" << max_threads_size_.load();
        Init();
    }

    ~ThreadPool() {
        ShutDown();
    }

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool&
    operator=(const ThreadPool&) = delete;
    ThreadPool&
    operator=(ThreadPool&&) = delete;

    void
    Init();

    void
    ShutDown();

    size_t
    GetThreadNum() {
        std::lock_guard<std::mutex> lock(mutex_);
        return current_threads_size_;
    }

    size_t
    GetMaxThreadNum() {
        return max_threads_size_.load();
    }
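    // Usage sketch for Submit(): the callable is wrapped in a std::packaged_task
    // and its future is returned, so the caller can wait on or poll the result.
    // The pool name and coefficient below are illustrative example values only.
    //
    //   milvus::ThreadPool pool(/*thread_core_coefficient=*/1.0, "example-pool");
    //   auto fut = pool.Submit([](int x) { return x * 2; }, 21);
    //   int result = fut.get();  // blocks until the task has run; result == 42
    //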
    template <typename F, typename... Args>
    auto
    Submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
        std::function<decltype(f(args...))()> func =
            std::bind(std::forward<F>(f), std::forward<Args>(args)...);
        auto task_ptr =
            std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
        std::function<void()> wrap_func = [task_ptr]() {
            (*task_ptr)();
        };

        work_queue_.enqueue(wrap_func);

        std::lock_guard<std::mutex> lock(mutex_);
        if (idle_threads_size_ > 0) {
            condition_lock_.notify_one();
        } else if (current_threads_size_ < max_threads_size_.load()) {
            // Dynamically increase the thread number
            std::thread t(&ThreadPool::Worker, this);
            assert(threads_.find(t.get_id()) == threads_.end());
            threads_[t.get_id()] = std::move(t);
            current_threads_size_++;
        }

        return task_ptr->get_future();
    }

    void
    Worker();

    void
    FinishThreads();

    void
    Resize(int new_size) {
        // No need to hold the mutex here: we don't require max_threads_size_
        // to take effect instantly, only that the store itself is atomic.
        max_threads_size_.store(new_size);
    }

 public:
    int min_threads_size_;
    int idle_threads_size_;
    int current_threads_size_;
    std::atomic<int> max_threads_size_;
    bool shutdown_;
    static constexpr size_t WAIT_SECONDS = 2;
    SafeQueue<std::function<void()>> work_queue_;
    std::unordered_map<std::thread::id, std::thread> threads_;
    SafeQueue<std::thread::id> need_finish_threads_;
    std::mutex mutex_;
    std::condition_variable condition_lock_;
    std::string name_;
};

}  // namespace milvus