diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp index a4693d8035..cbd4f4f09c 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp @@ -32,10 +32,6 @@ namespace knowhere { IndexModelPtr GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); - if (build_cfg->metric_type == knowhere::METRICTYPE::IP) { - KNOWHERE_LOG_ERROR << "PQ not support IP in GPU version!"; - throw KnowhereException("PQ not support IP in GPU version!"); - } if (build_cfg != nullptr) { build_cfg->CheckValid(); // throw exception } diff --git a/core/src/wrapper/ConfAdapter.cpp b/core/src/wrapper/ConfAdapter.cpp index 6b1667f9d7..0214025ed7 100644 --- a/core/src/wrapper/ConfAdapter.cpp +++ b/core/src/wrapper/ConfAdapter.cpp @@ -18,6 +18,7 @@ #include "wrapper/ConfAdapter.h" #include "WrapperException.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "server/Config.h" #include "utils/Log.h" #include @@ -130,6 +131,17 @@ IVFPQConfAdapter::Match(const TempMetaConf& metaconf) { conf->nbits = 8; MatchBase(conf); +#ifdef MILVUS_GPU_VERSION + Status s; + bool enable_gpu = false; + server::Config& config = server::Config::GetInstance(); + s = config.GetGpuResourceConfigEnable(enable_gpu); + if (s.ok() && conf->metric_type == knowhere::METRICTYPE::IP) { + WRAPPER_LOG_ERROR << "PQ not support IP in GPU version!"; + throw WrapperException("PQ not support IP in GPU version!"); + } +#endif + /* * Faiss 1.6 * Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with diff --git a/core/src/wrapper/gpu/GPUVecImpl.cpp b/core/src/wrapper/gpu/GPUVecImpl.cpp index 167a4d6a98..500bd61b9b 100644 --- a/core/src/wrapper/gpu/GPUVecImpl.cpp +++ b/core/src/wrapper/gpu/GPUVecImpl.cpp @@ -59,7 +59,7 @@ IVFMixIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, co } } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); - throw WrapperException(e.what()); + return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what());