From 2474d06a102597c4a88c3c203d22bf44371822bd Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Wed, 16 Sep 2020 18:19:29 +0800 Subject: [PATCH] Add index building of segment Signed-off-by: bigsheeper --- core/src/dog_segment/SegmentNaive.cpp | 106 +++++++++++------- core/src/dog_segment/SegmentNaive.h | 22 ++-- core/src/index/thirdparty/faiss/utils/Heap.h | 6 +- .../thirdparty/faiss/utils/distances.cpp | 21 ++-- .../index/thirdparty/faiss/utils/distances.h | 3 +- core/unittest/test_c_api.cpp | 2 +- go.mod | 2 +- pkg/master/grpc/message/message.proto | 4 +- reader/index.go | 28 ++++- reader/segment.go | 7 +- reader/segment_test.go | 4 +- 11 files changed, 130 insertions(+), 75 deletions(-) diff --git a/core/src/dog_segment/SegmentNaive.cpp b/core/src/dog_segment/SegmentNaive.cpp index 08327eebe7..e7bb86c62f 100644 --- a/core/src/dog_segment/SegmentNaive.cpp +++ b/core/src/dog_segment/SegmentNaive.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace milvus::dog_segment { @@ -67,7 +68,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times int64_t insert_barrier, bool force) -> std::shared_ptr { auto old = deleted_record_.get_lru_entry(); - if(!force || old->bitmap_ptr->capacity() == insert_barrier) { + if (!force || old->bitmap_ptr->capacity() == insert_barrier) { if (old->del_barrier == del_barrier) { return old; } @@ -286,15 +287,17 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe auto bitmap = bitmap_holder->bitmap_ptr; auto topK = query_info->topK; auto num_queries = query_info->num_queries; - auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_[0]); + auto the_offset_opt = schema_->get_offset(query_info->field_name); + assert(the_offset_opt.has_value()); + auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); auto index_entry = index_meta_->lookup_by_field(query_info->field_name); auto conf = index_entry.config; conf[milvus::knowhere::meta::TOPK] = query_info->topK; { auto count = 0; - for(int i = 0; i < bitmap->capacity(); ++i) { - if(bitmap->test(i)) { + for (int i = 0; i < bitmap->capacity(); ++i) { + if (bitmap->test(i)) { ++count; } } @@ -306,8 +309,8 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe auto ds = knowhere::GenDataset(query_info->num_queries, dim, query_info->query_raw_data.data()); auto final = indexing->Query(ds, conf); - auto ids = final->Get(knowhere::meta::IDS); - auto distances = final->Get(knowhere::meta::DISTANCE); + auto ids = final->Get(knowhere::meta::IDS); + auto distances = final->Get(knowhere::meta::DISTANCE); auto total_num = num_queries * topK; result.result_ids_.resize(total_num); @@ -320,38 +323,38 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe std::copy_n(ids, total_num, result.result_ids_.data()); std::copy_n(distances, total_num, result.result_distances_.data()); - for(auto& id: result.result_ids_) { + for (auto &id: result.result_ids_) { id = record_.uids_[id]; } return Status::OK(); } +Status +SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) { + auto ins_barrier = get_barrier(record_, timestamp); + auto del_barrier = get_barrier(deleted_record_, timestamp); + auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); + assert(bitmap_holder); + + auto &field = schema_->operator[](query_info->field_name); + assert(field.get_data_type() == DataType::VECTOR_FLOAT); + auto dim = field.get_dim(); + auto bitmap = bitmap_holder->bitmap_ptr; + auto topK = query_info->topK; + auto num_queries = query_info->num_queries; + // TODO: optimize + + auto the_offset_opt = schema_->get_offset(query_info->field_name); + assert(the_offset_opt.has_value()); + auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); + std::vector>> records(num_queries); + +} Status -SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { - // TODO: enable delete - // TODO: enable index - // TODO: remove mock - if (query_info == nullptr) { - query_info = std::make_shared(); - query_info->field_name = "fakevec"; - query_info->topK = 10; - query_info->num_queries = 1; - - auto dim = schema_->operator[]("fakevec").get_dim(); - std::default_random_engine e(42); - std::uniform_real_distribution<> dis(0.0, 1.0); - query_info->query_raw_data.resize(query_info->num_queries * dim); - for (auto &x: query_info->query_raw_data) { - x = dis(e); - } - } - - if(index_ready_) { - return QueryImpl(query_info, timestamp, result); - } +SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { auto ins_barrier = get_barrier(record_, timestamp); auto del_barrier = get_barrier(deleted_record_, timestamp); @@ -364,6 +367,11 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult auto bitmap = bitmap_holder->bitmap_ptr; auto topK = query_info->topK; auto num_queries = query_info->num_queries; + // TODO: optimize + auto the_offset_opt = schema_->get_offset(query_info->field_name); + assert(the_offset_opt.has_value()); + auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); + std::vector>> records(num_queries); auto get_L2_distance = [dim](const float *a, const float *b) { float L2_distance = 0; @@ -374,10 +382,6 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult return L2_distance; }; - - // TODO: optimize - std::vector>> records(num_queries); - auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_[0]); for (int64_t i = 0; i < ins_barrier; ++i) { if (i < bitmap->capacity() && bitmap->test(i)) { continue; @@ -417,15 +421,33 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult } return Status::OK(); -// find end of binary -// throw std::runtime_error("unimplemented"); -// auto record_ptr = GetMutableRecord(); -// if (record_ptr) { -// return QueryImpl(*record_ptr, query, timestamp, result); -// } else { -// assert(ready_immutable_); -// return QueryImpl(*record_immutable_, query, timestamp, result); -// } +} + +Status +SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { + // TODO: enable delete + // TODO: enable index + // TODO: remove mock + if (query_info == nullptr) { + query_info = std::make_shared(); + query_info->field_name = "fakevec"; + query_info->topK = 10; + query_info->num_queries = 1; + + auto dim = schema_->operator[]("fakevec").get_dim(); + std::default_random_engine e(42); + std::uniform_real_distribution<> dis(0.0, 1.0); + query_info->query_raw_data.resize(query_info->num_queries * dim); + for (auto &x: query_info->query_raw_data) { + x = dis(e); + } + } + + if (index_ready_) { + return QueryImpl(query_info, timestamp, result); + } else { + return QuerySlowImpl(query_info, timestamp, result); + } } Status diff --git a/core/src/dog_segment/SegmentNaive.h b/core/src/dog_segment/SegmentNaive.h index b9b62d4389..308ed2c6f6 100644 --- a/core/src/dog_segment/SegmentNaive.h +++ b/core/src/dog_segment/SegmentNaive.h @@ -116,14 +116,14 @@ public: } private: - struct MutableRecord { - ConcurrentVector uids_; - tbb::concurrent_vector timestamps_; - std::vector> entity_vecs_; - - MutableRecord(int entity_size) : entity_vecs_(entity_size) { - } - }; +// struct MutableRecord { +// ConcurrentVector uids_; +// tbb::concurrent_vector timestamps_; +// std::vector> entity_vecs_; +// +// MutableRecord(int entity_size) : entity_vecs_(entity_size) { +// } +// }; struct Record { std::atomic reserved = 0; @@ -147,6 +147,12 @@ private: Status QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results); + Status + QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results); + + Status + QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results); + template knowhere::IndexPtr BuildVecIndexImpl(const IndexMeta::Entry &entry); diff --git a/core/src/index/thirdparty/faiss/utils/Heap.h b/core/src/index/thirdparty/faiss/utils/Heap.h index 9962cbc112..b37cfedd91 100644 --- a/core/src/index/thirdparty/faiss/utils/Heap.h +++ b/core/src/index/thirdparty/faiss/utils/Heap.h @@ -16,10 +16,7 @@ * small. More complex functions are implemented in Heaps.cpp * */ - - -#ifndef FAISS_Heap_h -#define FAISS_Heap_h +#pragma once #include #include @@ -540,4 +537,3 @@ void indirect_heap_push (size_t k, } // namespace faiss -#endif /* FAISS_Heap_h */ diff --git a/core/src/index/thirdparty/faiss/utils/distances.cpp b/core/src/index/thirdparty/faiss/utils/distances.cpp index e97e873614..ef31d36228 100644 --- a/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -252,7 +252,7 @@ static void knn_L2sqr_sse ( const float * y, size_t d, size_t nx, size_t ny, float_maxheap_array_t * res, - ConcurrentBitsetPtr bitset = nullptr) + ConcurrentBitsetPtr bitset_base, uint64_t offset) { size_t k = res->k; size_t thread_max_num = omp_get_max_threads(); @@ -279,7 +279,7 @@ static void knn_L2sqr_sse ( #pragma omp parallel for schedule(static) for (size_t j = 0; j < ny; j++) { - if(!bitset || !bitset->test(j)) { + if(!bitset_base || !bitset_base->test(j + offset)) { size_t thread_no = omp_get_thread_num(); const float *y_j = y + j * d; const float *x_i = x + x_from * d; @@ -343,7 +343,7 @@ static void knn_L2sqr_sse ( } for (size_t j = 0; j < ny; j++) { - if (!bitset || !bitset->test(j)) { + if (!bitset_base || !bitset_base->test(j + offset)) { float disij = fvec_L2sqr (x_i, y_j, d); if (disij < val_[0]) { maxheap_swap_top (k, val_, ids_, disij, j); @@ -427,7 +427,7 @@ static void knn_L2sqr_blas (const float * x, size_t d, size_t nx, size_t ny, float_maxheap_array_t * res, const DistanceCorrection &corr, - ConcurrentBitsetPtr bitset = nullptr) + ConcurrentBitsetPtr bitset_base, int64_t offset) { res->heapify (); @@ -473,7 +473,7 @@ static void knn_L2sqr_blas (const float * x, const float *ip_line = ip_block + (i - i0) * (j1 - j0); for (size_t j = j0; j < j1; j++) { - if(!bitset || !bitset->test(j)){ + if(!bitset_base || !bitset_base->test(j + offset)){ float ip = *ip_line; float dis = x_norms[i] + y_norms[j] - 2 * ip; @@ -609,13 +609,16 @@ void knn_L2sqr (const float * x, const float * y, size_t d, size_t nx, size_t ny, float_maxheap_array_t * res, - ConcurrentBitsetPtr bitset) + ConcurrentBitsetPtr bitset, + int64_t offset + ) + { if (nx < distance_compute_blas_threshold) { - knn_L2sqr_sse (x, y, d, nx, ny, res, bitset); + knn_L2sqr_sse (x, y, d, nx, ny, res, bitset, offset); } else { NopDistanceCorrection nop; - knn_L2sqr_blas (x, y, d, nx, ny, res, nop, bitset); + knn_L2sqr_blas (x, y, d, nx, ny, res, nop, bitset, offset); } } @@ -649,7 +652,7 @@ void knn_L2sqr_base_shift ( const float *base_shift) { BaseShiftDistanceCorrection corr = {base_shift}; - knn_L2sqr_blas (x, y, d, nx, ny, res, corr); + knn_L2sqr_blas (x, y, d, nx, ny, res, corr, nullptr, 0); } diff --git a/core/src/index/thirdparty/faiss/utils/distances.h b/core/src/index/thirdparty/faiss/utils/distances.h index b4311d09c6..b5774eef48 100644 --- a/core/src/index/thirdparty/faiss/utils/distances.h +++ b/core/src/index/thirdparty/faiss/utils/distances.h @@ -13,6 +13,7 @@ #pragma once #include +#include #include #include @@ -181,7 +182,7 @@ void knn_L2sqr ( const float * y, size_t d, size_t nx, size_t ny, float_maxheap_array_t * res, - ConcurrentBitsetPtr bitset = nullptr); + ConcurrentBitsetPtr bitset = nullptr, int64_t offset = 0); void knn_jaccard ( const float * x, diff --git a/core/unittest/test_c_api.cpp b/core/unittest/test_c_api.cpp index c62a3549d9..981b1c0101 100644 --- a/core/unittest/test_c_api.cpp +++ b/core/unittest/test_c_api.cpp @@ -211,7 +211,7 @@ auto generate_data(int N) { } -TEST(CApiTest, TestQuery) { +TEST(CApiTest, TestSearchWithIndex) { auto collection_name = "collection0"; auto schema_tmp_conf = "null_schema"; auto collection = NewCollection(collection_name, schema_tmp_conf); diff --git a/go.mod b/go.mod index bad3b5bc54..1fdf9adafe 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/dvsekhvalnov/jose2go v0.0.0-20200901110807-248326c1351b // indirect github.com/frankban/quicktest v1.10.2 // indirect github.com/gogo/protobuf v1.3.1 - github.com/golang/protobuf v1.4.2 + github.com/golang/protobuf v1.3.3 github.com/google/btree v1.0.0 github.com/json-iterator/go v1.1.10 github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d // indirect diff --git a/pkg/master/grpc/message/message.proto b/pkg/master/grpc/message/message.proto index 4bea26acc5..eaa5e4c49a 100644 --- a/pkg/master/grpc/message/message.proto +++ b/pkg/master/grpc/message/message.proto @@ -691,11 +691,13 @@ message InsertOrDeleteMsg { message SearchMsg { string collection_name = 1; VectorRowRecord records = 2; - string partition_tag = 3; + repeated string partition_tag = 3; int64 uid = 4; uint64 timestamp =5; int64 client_id = 6; repeated KeyValuePair extra_params = 7; + repeated string json = 8; + string dsl = 9; } enum SyncType { diff --git a/reader/index.go b/reader/index.go index 28d376c040..b3b286561c 100644 --- a/reader/index.go +++ b/reader/index.go @@ -1,15 +1,37 @@ package reader +/* + +#cgo CFLAGS: -I../core/include + +#cgo LDFLAGS: -L../core/lib -lmilvus_dog_segment -Wl,-rpath=../core/lib + +#include "collection_c.h" +#include "partition_c.h" +#include "segment_c.h" + +*/ +import "C" import ( msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" ) -type IndexConfig struct {} +type IndexConfig struct{} -func buildIndex(config IndexConfig) msgPb.Status { +func (s *Segment) buildIndex() msgPb.Status { + /*C.BuildIndex + int + BuildIndex(CSegmentBase c_segment); + */ + var status = C.BuildIndex(s.SegmentPtr) + if status != 0 { + return msgPb.Status{ErrorCode: msgPb.ErrorCode_BUILD_INDEX_ERROR} + } return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS} } -func dropIndex(fieldName string) msgPb.Status { +func (s *Segment) dropIndex(fieldName string) msgPb.Status { + // WARN: Not support yet + return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS} } diff --git a/reader/segment.go b/reader/segment.go index 273212aea9..02d252b265 100644 --- a/reader/segment.go +++ b/reader/segment.go @@ -15,7 +15,7 @@ import "C" import ( "fmt" "github.com/czs007/suvlim/errors" - schema "github.com/czs007/suvlim/pkg/master/grpc/message" + msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" "strconv" "unsafe" ) @@ -74,6 +74,9 @@ func (s *Segment) Close() error { if status != 0 { return errors.New("Close segment failed, error code = " + strconv.Itoa(int(status))) } + + // Build index after closing segment + s.buildIndex() return nil } @@ -169,7 +172,7 @@ func (s *Segment) SegmentDelete(offset int64, entityIDs *[]int64, timestamps *[] return nil } -func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord *schema.VectorRowRecord) (*SearchResult, error) { +func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) { /*C.Search int Search(CSegmentBase c_segment, diff --git a/reader/segment_test.go b/reader/segment_test.go index 170b512f49..00f8bf3008 100644 --- a/reader/segment_test.go +++ b/reader/segment_test.go @@ -3,7 +3,7 @@ package reader import ( "encoding/binary" "fmt" - schema "github.com/czs007/suvlim/pkg/master/grpc/message" + msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" "github.com/stretchr/testify/assert" "math" "testing" @@ -137,7 +137,7 @@ func TestSegment_SegmentSearch(t *testing.T) { for i := 0; i < 16; i ++ { queryRawData = append(queryRawData, float32(i)) } - var vectorRecord = schema.VectorRowRecord { + var vectorRecord = msgPb.VectorRowRecord { FloatData: queryRawData, } var searchRes, searchErr = segment.SegmentSearch(queryJson, timestamps[0], &vectorRecord)