diff --git a/internal/core/src/indexbuilder/IndexWrapper.cpp b/internal/core/src/indexbuilder/IndexWrapper.cpp index 1a26404aa2..6a5db6b1c6 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.cpp +++ b/internal/core/src/indexbuilder/IndexWrapper.cpp @@ -19,6 +19,7 @@ #include "utils/EasyAssert.h" #include "IndexWrapper.h" #include "indexbuilder/utils.h" +#include "index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h" namespace milvus { namespace indexbuilder { @@ -29,14 +30,11 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria parse(); - std::map mode_map = {{"CPU", knowhere::IndexMode::MODE_CPU}, - {"GPU", knowhere::IndexMode::MODE_GPU}}; - auto mode = get_config_by_name("index_mode"); - auto index_mode = mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU; - + auto index_mode = get_index_mode(); auto index_type = get_index_type(); auto metric_type = get_metric_type(); AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type); + index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(get_index_type(), index_mode); Assert(index_ != nullptr); } @@ -157,6 +155,11 @@ IndexWrapper::dim() { void IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) { auto index_type = get_index_type(); + auto index_mode = get_index_mode(); + config_[knowhere::meta::ROWS] = dataset->Get(knowhere::meta::ROWS); + auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type); + AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!"); + if (is_in_need_id_list(index_type)) { PanicInfo(std::string(index_type) + " doesn't support build without ids yet!"); } @@ -176,6 +179,11 @@ IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) { void IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) { Assert(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end()); + auto index_type = get_index_type(); + auto index_mode = get_index_mode(); + config_[knowhere::meta::ROWS] = dataset->Get(knowhere::meta::ROWS); + auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type); + AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!"); // index_->Train(dataset, config_); // index_->Add(dataset, config_); index_->BuildAll(dataset, config_); @@ -281,6 +289,16 @@ IndexWrapper::get_metric_type() { } } +knowhere::IndexMode +IndexWrapper::get_index_mode() { + static std::map mode_map = { + {"CPU", knowhere::IndexMode::MODE_CPU}, + {"GPU", knowhere::IndexMode::MODE_GPU}, + }; + auto mode = get_config_by_name("index_mode"); + return mode.has_value() ? mode_map[mode.value()] : knowhere::IndexMode::MODE_CPU; +} + std::unique_ptr IndexWrapper::Query(const knowhere::DatasetPtr& dataset) { return std::move(QueryImpl(dataset, config_)); diff --git a/internal/core/src/indexbuilder/IndexWrapper.h b/internal/core/src/indexbuilder/IndexWrapper.h index 979a8fa960..8bf2ed881c 100644 --- a/internal/core/src/indexbuilder/IndexWrapper.h +++ b/internal/core/src/indexbuilder/IndexWrapper.h @@ -62,6 +62,9 @@ class IndexWrapper { std::string get_metric_type(); + knowhere::IndexMode + get_index_mode(); + template std::optional get_config_by_name(std::string name); diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp index 217372700b..e01d989897 100644 --- a/internal/core/src/indexbuilder/index_c.cpp +++ b/internal/core/src/indexbuilder/index_c.cpp @@ -35,7 +35,7 @@ CreateIndex(const char* serialized_type_params, const char* serialized_index_par *res_index = index.release(); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -59,7 +59,7 @@ BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* cIndex->BuildWithoutIds(ds); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -77,7 +77,7 @@ BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* ve cIndex->BuildWithoutIds(ds); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -94,7 +94,7 @@ SerializeToSlicedBuffer(CIndex index, int32_t* buffer_size, char** res_buffer) { *res_buffer = binary.data; status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -109,7 +109,7 @@ LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, in cIndex->Load(serialized_sliced_blob_buffer, size); status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -129,7 +129,7 @@ QueryOnFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -153,7 +153,7 @@ QueryOnFloatVecIndexWithParam(CIndex index, status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -173,7 +173,7 @@ QueryOnBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors, C status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -197,7 +197,7 @@ QueryOnBinaryVecIndexWithParam(CIndex index, status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -213,7 +213,7 @@ CreateQueryResult(CIndexQueryResult* res) { status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } @@ -259,7 +259,7 @@ DeleteQueryResult(CIndexQueryResult res) { status.error_code = Success; status.error_msg = ""; - } catch (std::runtime_error& e) { + } catch (std::exception& e) { status.error_code = UnexpectedException; status.error_msg = strdup(e.what()); } diff --git a/internal/indexbuilder/index.go b/internal/indexbuilder/index.go index eba049e543..16439c4b29 100644 --- a/internal/indexbuilder/index.go +++ b/internal/indexbuilder/index.go @@ -106,10 +106,13 @@ func (index *CIndex) BuildFloatVecIndexWithoutIds(vectors []float32) error { CStatus BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors); */ + fmt.Println("before BuildFloatVecIndexWithoutIds") status := C.BuildFloatVecIndexWithoutIds(index.indexPtr, (C.int64_t)(len(vectors)), (*C.float)(&vectors[0])) errorCode := status.error_code + fmt.Println("BuildFloatVecIndexWithoutIds error code: ", errorCode) if errorCode != 0 { errorMsg := C.GoString(status.error_msg) + fmt.Println("BuildFloatVecIndexWithoutIds error msg: ", errorMsg) defer C.free(unsafe.Pointer(status.error_msg)) return errors.New("BuildFloatVecIndexWithoutIds failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 425cae75cf..d676a067dd 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -440,7 +440,7 @@ func (qt *QueryTask) PostExecute() error { hits := make([][]*servicepb.Hits, 0) for _, partialSearchResult := range filterSearchResult { - if len(partialSearchResult.Hits) <= 0 { + if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 { filterReason += "nq is zero\n" continue }