From 83cbe0f490093bb18cb169dbc4bd28dfaf8c9421 Mon Sep 17 00:00:00 2001 From: "shengjun.li" Date: Thu, 20 Aug 2020 09:42:04 +0800 Subject: [PATCH] fix hamming (#3338) Signed-off-by: shengjun.li --- .../index/vector_index/IndexBinaryIDMAP.cpp | 93 +++---------------- .../index/vector_index/IndexBinaryIDMAP.h | 10 -- .../index/vector_index/IndexBinaryIVF.cpp | 31 +++---- sdk/examples/binary_vector/src/ClientTest.cpp | 57 ++++-------- sdk/examples/utils/Utils.cpp | 6 +- sdk/examples/utils/Utils.h | 3 +- sdk/include/MilvusApi.h | 1 + 7 files changed, 56 insertions(+), 145 deletions(-) diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index 5a3ebb923a..deeff712eb 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -53,64 +53,13 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = (int64_t*)malloc(p_id_size); auto p_dist = (float*)malloc(p_dist_size); - // QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config()); QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config); auto ret_ds = std::make_shared(); - if (index_->metric_type == faiss::METRIC_Hamming) { - auto pf_dist = (float*)malloc(p_dist_size); - int32_t* pi_dist = (int32_t*)p_dist; - for (int i = 0; i < elems; i++) { - *(pf_dist + i) = (float)(*(pi_dist + i)); - } - ret_ds->Set(meta::IDS, p_id); - ret_ds->Set(meta::DISTANCE, pf_dist); - free(p_dist); - } else { - ret_ds->Set(meta::IDS, p_id); - ret_ds->Set(meta::DISTANCE, p_dist); - } - return ret_ds; -} - -#if 0 -DatasetPtr -BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { - if (!index_) { - KNOWHERE_THROW_MSG("index not initialize"); - } - - auto dim = dataset_ptr->Get(meta::DIM); - auto rows = dataset_ptr->Get(meta::ROWS); - auto p_data = dataset_ptr->Get(meta::IDS); - - int64_t k = config[meta::TOPK].get(); - auto elems = rows * k; - size_t p_id_size = sizeof(int64_t) * elems; - size_t p_dist_size = sizeof(float) * elems; - auto p_id = (int64_t*)malloc(p_id_size); - auto p_dist = (float*)malloc(p_dist_size); - - auto* pdistances = (int32_t*)p_dist; - index_->search_by_id(rows, p_data, k, pdistances, p_id, bitset_); - - auto ret_ds = std::make_shared(); - if (index_->metric_type == faiss::METRIC_Hamming) { - auto pf_dist = (float*)malloc(p_dist_size); - int32_t* pi_dist = (int32_t*)p_dist; - for (int i = 0; i < elems; i++) { - *(pf_dist + i) = (float)(*(pi_dist + i)); - } - ret_ds->Set(meta::IDS, p_id); - ret_ds->Set(meta::DISTANCE, pf_dist); - free(p_dist); - } else { - ret_ds->Set(meta::IDS, p_id); - ret_ds->Set(meta::DISTANCE, p_dist); - } + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); return ret_ds; } -#endif int64_t BinaryIDMAP::Count() { @@ -187,39 +136,25 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) index_->add_with_ids(rows, (uint8_t*)p_data, new_ids.data()); } -#if 0 -DatasetPtr -BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { - if (!index_) { - KNOWHERE_THROW_MSG("index not initialize"); - } - - // GETBINARYTENSOR(dataset_ptr) - // auto rows = dataset_ptr->Get(meta::ROWS); - auto p_data = dataset_ptr->Get(meta::IDS); - auto elems = dataset_ptr->Get(meta::DIM); - - size_t p_x_size = sizeof(uint8_t) * elems; - auto p_x = (uint8_t*)malloc(p_x_size); - - index_->get_vector_by_id(1, p_data, p_x, bitset_); - - auto ret_ds = std::make_shared(); - ret_ds->Set(meta::TENSOR, p_x); - return ret_ds; -} -#endif - void BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config) { - int32_t* pdistances = (int32_t*)distances; - auto flat_index = dynamic_cast(index_.get())->index; auto default_type = flat_index->metric_type; if (config.contains(Metric::TYPE)) flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); - index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_); + + int32_t* i_distances = reinterpret_cast(distances); + index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_); + + // if hamming, it need transform int32 to float + if (flat_index->metric_type == faiss::METRIC_Hamming) { + int64_t num = n * k; + for (int64_t i = 0; i < num; i++) { + distances[i] = static_cast(i_distances[i]); + } + } + flat_index->metric_type = default_type; } diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h index ce7da9bf04..41545be6fc 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -50,11 +50,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { DatasetPtr Query(const DatasetPtr&, const Config&) override; -#if 0 - DatasetPtr - QueryById(const DatasetPtr& dataset_ptr, const Config& config) override; -#endif - int64_t Count() override; @@ -66,11 +61,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { return Count() * Dim() / 8; } -#if 0 - DatasetPtr - GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override; -#endif - virtual const uint8_t* GetRawVectors(); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp index bc06346e63..3153afd14a 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp @@ -62,19 +62,9 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config); auto ret_ds = std::make_shared(); - if (index_->metric_type == faiss::METRIC_Hamming) { - auto pf_dist = (float*)malloc(p_dist_size); - int32_t* pi_dist = (int32_t*)p_dist; - for (int i = 0; i < elems; i++) { - *(pf_dist + i) = (float)(*(pi_dist + i)); - } - ret_ds->Set(meta::IDS, p_id); - ret_ds->Set(meta::DISTANCE, pf_dist); - free(p_dist); - } else { - ret_ds->Set(meta::IDS, p_id); - ret_ds->Set(meta::DISTANCE, p_dist); - } + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; } catch (faiss::FaissException& e) { KNOWHERE_THROW_MSG(e.what()); @@ -215,11 +205,10 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances auto params = GenParams(config); auto ivf_index = dynamic_cast(index_.get()); ivf_index->nprobe = params->nprobe; - int32_t* pdistances = (int32_t*)distances; - stdclock::time_point before = stdclock::now(); - // todo: remove static cast (zhiru) - static_cast(index_.get())->search(n, (uint8_t*)data, k, pdistances, labels, bitset_); + stdclock::time_point before = stdclock::now(); + int32_t* i_distances = reinterpret_cast(distances); + index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_); stdclock::time_point after = stdclock::now(); double search_cost = (std::chrono::duration(after - before)).count(); @@ -228,6 +217,14 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances << ", data search cost: " << faiss::indexIVF_stats.search_time; faiss::indexIVF_stats.quantization_time = 0; faiss::indexIVF_stats.search_time = 0; + + // if hamming, it need transform int32 to float + if (ivf_index->metric_type == faiss::METRIC_Hamming) { + int64_t num = n * k; + for (int64_t i = 0; i < num; i++) { + distances[i] = static_cast(i_distances[i]); + } + } } } // namespace knowhere diff --git a/sdk/examples/binary_vector/src/ClientTest.cpp b/sdk/examples/binary_vector/src/ClientTest.cpp index 6079237529..439c0182a4 100644 --- a/sdk/examples/binary_vector/src/ClientTest.cpp +++ b/sdk/examples/binary_vector/src/ClientTest.cpp @@ -109,7 +109,26 @@ TestProcess(std::shared_ptr connection, TOP_K, NPROBE, search_entity_array, - topk_query_result); + topk_query_result, + milvus::MetricType::HAMMING); + + milvus_sdk::Utils::DoSearch(connection, + collection_param.collection_name, + partition_tags, + TOP_K, + NPROBE, + search_entity_array, + topk_query_result, + milvus::MetricType::SUBSTRUCTURE); + + milvus_sdk::Utils::DoSearch(connection, + collection_param.collection_name, + partition_tags, + TOP_K, + NPROBE, + search_entity_array, + topk_query_result, + milvus::MetricType::SUPERSTRUCTURE); } { // wait unit build index finish @@ -170,41 +189,5 @@ ClientTest::Test(const std::string& address, const std::string& port) { TestProcess(connection, collection_param, index_param); } - { - milvus::CollectionParam collection_param = { - "collection_2", - 512, // dimension - 512, // index file size - milvus::MetricType::SUBSTRUCTURE - }; - - JSON json_params = {}; - milvus::IndexParam index_param = { - collection_param.collection_name, - milvus::IndexType::FLAT, - json_params.dump() - }; - - TestProcess(connection, collection_param, index_param); - } - - { - milvus::CollectionParam collection_param = { - "collection_3", - 128, // dimension - 1024, // index file size - milvus::MetricType::SUPERSTRUCTURE - }; - - JSON json_params = {}; - milvus::IndexParam index_param = { - collection_param.collection_name, - milvus::IndexType::FLAT, - json_params.dump() - }; - - TestProcess(connection, collection_param, index_param); - } - milvus::Connection::Destroy(connection); } diff --git a/sdk/examples/utils/Utils.cpp b/sdk/examples/utils/Utils.cpp index 4428af52a5..aa7fba0358 100644 --- a/sdk/examples/utils/Utils.cpp +++ b/sdk/examples/utils/Utils.cpp @@ -202,7 +202,7 @@ void Utils::DoSearch(std::shared_ptr conn, const std::string& collection_name, const std::vector& partition_tags, int64_t top_k, int64_t nprobe, const std::vector>& entity_array, - milvus::TopKQueryResult& topk_query_result) { + milvus::TopKQueryResult& topk_query_result, milvus::MetricType metric_type) { topk_query_result.clear(); std::vector temp_entity_array; @@ -213,6 +213,10 @@ Utils::DoSearch(std::shared_ptr conn, const std::string& col { BLOCK_SPLITER JSON json_params = {{"nprobe", nprobe}}; + if (metric_type != milvus::MetricType::INVALID) { + json_params["metric_type"] = metric_type; + } + milvus_sdk::TimeRecorder rc("Search"); milvus::Status stat = conn->Search(collection_name, diff --git a/sdk/examples/utils/Utils.h b/sdk/examples/utils/Utils.h index a70e97a956..5f60e33140 100644 --- a/sdk/examples/utils/Utils.h +++ b/sdk/examples/utils/Utils.h @@ -69,7 +69,8 @@ class Utils { DoSearch(std::shared_ptr conn, const std::string& collection_name, const std::vector& partition_tags, int64_t top_k, int64_t nprobe, const std::vector>& entity_array, - milvus::TopKQueryResult& topk_query_result); + milvus::TopKQueryResult& topk_query_result, + milvus::MetricType metric_type = milvus::MetricType::INVALID); static std::vector GenLeafQuery(); diff --git a/sdk/include/MilvusApi.h b/sdk/include/MilvusApi.h index c07e466041..40345cdfeb 100644 --- a/sdk/include/MilvusApi.h +++ b/sdk/include/MilvusApi.h @@ -42,6 +42,7 @@ enum class IndexType { }; enum class MetricType { + INVALID = 0, L2 = 1, // Euclidean Distance IP = 2, // Cosine Similarity HAMMING = 3, // Hamming Distance