diff --git a/core/src/scheduler/selector/FaissIVFFlatPass.cpp b/core/src/scheduler/selector/FaissIVFFlatPass.cpp index 9c0860f929..f6f6b99790 100644 --- a/core/src/scheduler/selector/FaissIVFFlatPass.cpp +++ b/core/src/scheduler/selector/FaissIVFFlatPass.cpp @@ -12,6 +12,8 @@ #include "scheduler/selector/FaissIVFFlatPass.h" #include "cache/GpuCacheMgr.h" #include "config/ServerConfig.h" +#include "faiss/gpu/utils/DeviceUtils.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "scheduler/SchedInst.h" #include "scheduler/Utils.h" #include "scheduler/task/SearchTask.h" @@ -54,7 +56,11 @@ FaissIVFFlatPass::Run(const TaskPtr& task) { LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: gpu disable, specify cpu to search!"); res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); } else if (search_task->nq() < threshold_) { - LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nq < gpu_search_threshold, specify cpu to search!"); + LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nq < gpu_search_threshold, specify cpu to search! "); + res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); + } else if (search_task->ExtraParam()[knowhere::IndexParams::nprobe].get() > + faiss::gpu::getMaxKSelection()) { + LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nprobe > gpu_max_nprobe_threshold, specify cpu to search!"); res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); } else { LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nq >= gpu_search_threshold, specify gpu %d to search!", diff --git a/core/src/scheduler/selector/FaissIVFPQPass.cpp b/core/src/scheduler/selector/FaissIVFPQPass.cpp index 4dae83a3ca..3bfe1a2081 100644 --- a/core/src/scheduler/selector/FaissIVFPQPass.cpp +++ b/core/src/scheduler/selector/FaissIVFPQPass.cpp @@ -12,6 +12,8 @@ #include "scheduler/selector/FaissIVFPQPass.h" #include "cache/GpuCacheMgr.h" #include "config/ServerConfig.h" +#include "faiss/gpu/utils/DeviceUtils.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "scheduler/SchedInst.h" #include "scheduler/Utils.h" #include "scheduler/task/SearchTask.h" @@ -58,6 +60,10 @@ FaissIVFPQPass::Run(const TaskPtr& task) { } else if (search_task->nq() < threshold_) { LOG_SERVER_DEBUG_ << LogOut("FaissIVFPQPass: nq < gpu_search_threshold, specify cpu to search!"); res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); + } else if (search_task->ExtraParam()[knowhere::IndexParams::nprobe].get() > + faiss::gpu::getMaxKSelection()) { + LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nprobe > gpu_max_nprobe_threshold, specify cpu to search!"); + res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); } else { LOG_SERVER_DEBUG_ << LogOut("FaissIVFPQPass: nq >= gpu_search_threshold, specify gpu %d to search!", search_gpus_[idx_]); diff --git a/core/src/scheduler/selector/FaissIVFSQ8HPass.cpp b/core/src/scheduler/selector/FaissIVFSQ8HPass.cpp index 71a9485145..5dc9fb1760 100644 --- a/core/src/scheduler/selector/FaissIVFSQ8HPass.cpp +++ b/core/src/scheduler/selector/FaissIVFSQ8HPass.cpp @@ -13,6 +13,8 @@ #include "scheduler/selector/FaissIVFSQ8HPass.h" #include "cache/GpuCacheMgr.h" #include "config/ServerConfig.h" +#include "faiss/gpu/utils/DeviceUtils.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "scheduler/SchedInst.h" #include "scheduler/Utils.h" #include "scheduler/task/SearchTask.h" @@ -55,6 +57,10 @@ FaissIVFSQ8HPass::Run(const TaskPtr& task) { } else if (search_task->nq() < threshold_) { LOG_SERVER_DEBUG_ << LogOut("FaissIVFSQ8HPass: nq < gpu_search_threshold, specify cpu to search!"); res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); + } else if (search_task->ExtraParam()[knowhere::IndexParams::nprobe].get() > + faiss::gpu::getMaxKSelection()) { + LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nprobe > gpu_max_nprobe_threshold, specify cpu to search!"); + res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); } else { LOG_SERVER_DEBUG_ << LogOut("FaissIVFSQ8HPass: nq >= gpu_search_threshold, specify gpu %d to search!", search_gpus_[idx_]); diff --git a/core/src/scheduler/selector/FaissIVFSQ8Pass.cpp b/core/src/scheduler/selector/FaissIVFSQ8Pass.cpp index 752c1a1918..0496ccd489 100644 --- a/core/src/scheduler/selector/FaissIVFSQ8Pass.cpp +++ b/core/src/scheduler/selector/FaissIVFSQ8Pass.cpp @@ -12,6 +12,8 @@ #include "scheduler/selector/FaissIVFSQ8Pass.h" #include "cache/GpuCacheMgr.h" #include "config/ServerConfig.h" +#include "faiss/gpu/utils/DeviceUtils.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "scheduler/SchedInst.h" #include "scheduler/Utils.h" #include "scheduler/task/SearchTask.h" @@ -56,6 +58,10 @@ FaissIVFSQ8Pass::Run(const TaskPtr& task) { } else if (search_task->nq() < threshold_) { LOG_SERVER_DEBUG_ << LogOut("FaissIVFSQ8Pass: nq < gpu_search_threshold, specify cpu to search!"); res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); + } else if (search_task->ExtraParam()[knowhere::IndexParams::nprobe].get() > + faiss::gpu::getMaxKSelection()) { + LOG_SERVER_DEBUG_ << LogOut("FaissIVFFlatPass: nprobe > gpu_max_nprobe_threshold, specify cpu to search!"); + res_ptr = ResMgrInst::GetInstance()->GetResource("cpu"); } else { LOG_SERVER_DEBUG_ << LogOut("FaissIVFSQ8Pass: nq >= gpu_search_threshold, specify gpu %d to search!", search_gpus_[idx_]);