diff --git a/build/docker/env/cpu/ubuntu18.04/Dockerfile b/build/docker/env/cpu/ubuntu18.04/Dockerfile index df32dabe73..5c1f53ea92 100644 --- a/build/docker/env/cpu/ubuntu18.04/Dockerfile +++ b/build/docker/env/cpu/ubuntu18.04/Dockerfile @@ -17,10 +17,10 @@ SHELL ["/bin/bash", "-o", "pipefail", "-c"] ENV DEBIAN_FRONTEND noninteractive -RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 && \ +RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 clang-format-10 && \ wget -qO- "https://cmake.org/files/v3.14/cmake-3.14.3-Linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \ apt-get update && apt-get install -y --no-install-recommends \ - g++ gcc gfortran git make ccache libssl-dev zlib1g-dev libboost-regex-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev python3-dev libboost-python-dev libcurl4-openssl-dev libtbb-dev clang-format-10 clang-tidy-10 lcov && \ + g++ gcc gfortran git make ccache libssl-dev zlib1g-dev libboost-regex-dev libboost-program-options-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev python3-dev libboost-python-dev libcurl4-openssl-dev libtbb-dev clang-format-7 clang-tidy-7 lcov && \ apt-get remove --purge -y && \ rm -rf /var/lib/apt/lists/* @@ -28,7 +28,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce RUN wget https://github.com/xianyi/OpenBLAS/archive/v0.3.9.tar.gz && \ tar zxvf v0.3.9.tar.gz && cd OpenBLAS-0.3.9 && \ make TARGET=CORE2 DYNAMIC_ARCH=1 DYNAMIC_OLDER=1 USE_THREAD=0 USE_OPENMP=0 FC=gfortran CC=gcc COMMON_OPT="-O3 -g -fPIC" FCOMMON_OPT="-O3 -g -fPIC -frecursive" NMAX="NUM_THREADS=128" LIBPREFIX="libopenblas" LAPACKE="NO_LAPACKE=1" INTERFACE64=0 NO_STATIC=1 && \ - make PREFIX=/usr NO_STATIC=1 install && \ + make PREFIX=/usr install && \ cd .. 
&& rm -rf OpenBLAS-0.3.9 && rm v0.3.9.tar.gz ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib" diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index 82d628cd21..865b4d52a8 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -10,7 +10,8 @@ set(SEGCORE_FILES IndexingEntry.cpp InsertRecord.cpp Reduce.cpp - plan_c.cpp) + plan_c.cpp + reduce_c.cpp) add_library(milvus_segcore SHARED ${SEGCORE_FILES} ) diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/Reduce.cpp index af98897969..a6d83cbe3f 100644 --- a/internal/core/src/segcore/Reduce.cpp +++ b/internal/core/src/segcore/Reduce.cpp @@ -1,8 +1,11 @@ #include #include #include + +#include "Reduce.h" + namespace milvus::segcore { -void +Status merge_into(int64_t queries, int64_t topk, float* distances, @@ -37,5 +40,6 @@ merge_into(int64_t queries, std::copy_n(buf_dis.data(), topk, src2_dis); std::copy_n(buf_uids.data(), topk, src2_uids); } + return Status::OK(); } } // namespace milvus::segcore diff --git a/internal/core/src/segcore/Reduce.h b/internal/core/src/segcore/Reduce.h index 9c769c2810..65c6798a54 100644 --- a/internal/core/src/segcore/Reduce.h +++ b/internal/core/src/segcore/Reduce.h @@ -2,8 +2,11 @@ #include #include #include + +#include "utils/Status.h" + namespace milvus::segcore { -void +Status merge_into(int64_t num_queries, int64_t topk, float* distances, diff --git a/internal/core/src/segcore/plan_c.cpp b/internal/core/src/segcore/plan_c.cpp index d4f6d356e7..a3fb3b5f19 100644 --- a/internal/core/src/segcore/plan_c.cpp +++ b/internal/core/src/segcore/plan_c.cpp @@ -20,8 +20,8 @@ ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, long int blob_ } long int -GetNumOfQueries(CPlaceholderGroup placeholderGroup) { - auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholderGroup); +GetNumOfQueries(CPlaceholderGroup placeholder_group) { + auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholder_group); return res; } @@ -41,8 +41,8 @@ DeletePlan(CPlan cPlan) { } void -DeletePlaceholderGroup(CPlaceholderGroup cPlaceholderGroup) { - auto placeHolderGroup = (milvus::query::PlaceholderGroup*)cPlaceholderGroup; - delete placeHolderGroup; +DeletePlaceholderGroup(CPlaceholderGroup cPlaceholder_group) { + auto placeHolder_group = (milvus::query::PlaceholderGroup*)cPlaceholder_group; + delete placeHolder_group; std::cout << "delete placeholder" << std::endl; } diff --git a/internal/core/src/segcore/plan_c.h b/internal/core/src/segcore/plan_c.h index d757fa94c7..436d78e996 100644 --- a/internal/core/src/segcore/plan_c.h +++ b/internal/core/src/segcore/plan_c.h @@ -15,7 +15,7 @@ CPlaceholderGroup ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size); long int -GetNumOfQueries(CPlaceholderGroup placeholderGroup); +GetNumOfQueries(CPlaceholderGroup placeholder_group); long int GetTopK(CPlan plan); @@ -24,7 +24,7 @@ void DeletePlan(CPlan plan); void -DeletePlaceholderGroup(CPlaceholderGroup placeholderGroup); +DeletePlaceholderGroup(CPlaceholderGroup placeholder_group); #ifdef __cplusplus } diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp new file mode 100644 index 0000000000..10963bbb1d --- /dev/null +++ b/internal/core/src/segcore/reduce_c.cpp @@ -0,0 +1,9 @@ +#include "reduce_c.h" +#include "Reduce.h" + +int +MergeInto( + long int num_queries, long int topk, float* distances, long 
int* uids, float* new_distances, long int* new_uids) { + auto status = milvus::segcore::merge_into(num_queries, topk, distances, uids, new_distances, new_uids); + return status.code(); +} diff --git a/internal/core/src/segcore/reduce_c.h b/internal/core/src/segcore/reduce_c.h new file mode 100644 index 0000000000..862f8bb55a --- /dev/null +++ b/internal/core/src/segcore/reduce_c.h @@ -0,0 +1,13 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#include + +int +MergeInto( + long int num_queries, long int topk, float* distances, long int* uids, float* new_distances, long int* new_uids); + +#ifdef __cplusplus +} +#endif diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 62dff9c310..4b28f4fa03 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -6,6 +6,7 @@ #include "segcore/collection_c.h" #include "segcore/segment_c.h" #include "pb/service_msg.pb.h" +#include "segcore/reduce_c.h" #include namespace chrono = std::chrono; @@ -510,4 +511,33 @@ TEST(CApiTest, GetRowCountTest) { // auto segment = NewSegment(collection, 0); // DeleteCollection(collection); // DeleteSegment(segment); -//} \ No newline at end of file +//} + +TEST(CApiTest, MergeInto) { + std::vector uids; + std::vector distance; + + std::vector new_uids; + std::vector new_distance; + + int64_t num_queries = 1; + int64_t topk = 2; + + uids.push_back(1); + uids.push_back(2); + distance.push_back(5); + distance.push_back(1000); + + new_uids.push_back(3); + new_uids.push_back(4); + new_distance.push_back(2); + new_distance.push_back(6); + + auto res = MergeInto(num_queries, topk, distance.data(), uids.data(), new_distance.data(), new_uids.data()); + + ASSERT_EQ(res, 0); + ASSERT_EQ(uids[0], 3); + ASSERT_EQ(distance[0], 2); + ASSERT_EQ(uids[1], 1); + ASSERT_EQ(distance[1], 5); +} diff --git a/internal/reader/search_service.go b/internal/reader/search_service.go index e66ed8684a..464d7952eb 100644 --- a/internal/reader/search_service.go +++ b/internal/reader/search_service.go @@ -6,7 +6,8 @@ import ( "errors" "fmt" "log" - "sort" + "math" + "sync" "github.com/golang/protobuf/proto" @@ -17,8 +18,11 @@ import ( ) type searchService struct { - ctx context.Context - cancel context.CancelFunc + ctx context.Context + wait sync.WaitGroup + cancel context.CancelFunc + msgBuffer chan msgstream.TsMsg + unsolvedMsg []msgstream.TsMsg replica *collectionReplica tSafeWatcher *tSafeWatcher @@ -29,11 +33,6 @@ type searchService struct { type ResultEntityIds []UniqueID -type SearchResult struct { - ResultIds []UniqueID - ResultDistances []float32 -} - func newSearchService(ctx context.Context, replica *collectionReplica) *searchService { receiveBufSize := Params.searchReceiveBufSize() pulsarBufSize := Params.searchPulsarBufSize() @@ -58,9 +57,13 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe var outputStream msgstream.MsgStream = searchResultStream searchServiceCtx, searchServiceCancel := context.WithCancel(ctx) + msgBuffer := make(chan msgstream.TsMsg, receiveBufSize) + unsolvedMsg := make([]msgstream.TsMsg, 0) return &searchService{ - ctx: searchServiceCtx, - cancel: searchServiceCancel, + ctx: searchServiceCtx, + cancel: searchServiceCancel, + msgBuffer: msgBuffer, + unsolvedMsg: unsolvedMsg, replica: replica, tSafeWatcher: newTSafeWatcher(), @@ -73,27 +76,10 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe func (ss *searchService) start() { (*ss.searchMsgStream).Start() 
(*ss.searchResultMsgStream).Start() - - go func() { - for { - select { - case <-ss.ctx.Done(): - return - default: - msgPack := (*ss.searchMsgStream).Consume() - if msgPack == nil || len(msgPack.Msgs) <= 0 { - continue - } - // TODO: add serviceTime check - err := ss.search(msgPack.Msgs) - if err != nil { - fmt.Println("search Failed") - ss.publishFailedSearchResult() - } - fmt.Println("Do search done") - } - } - }() + ss.wait.Add(2) + go ss.receiveSearchMsg() + go ss.startSearchService() + ss.wait.Wait() } func (ss *searchService) close() { @@ -114,12 +100,68 @@ func (ss *searchService) waitNewTSafe() Timestamp { return timestamp } -func (ss *searchService) search(searchMessages []msgstream.TsMsg) error { - - type SearchResult struct { - ResultID int64 - ResultDistance float32 +func (ss *searchService) receiveSearchMsg() { + defer ss.wait.Done() + for { + select { + case <-ss.ctx.Done(): + return + default: + msgPack := (*ss.searchMsgStream).Consume() + if msgPack == nil || len(msgPack.Msgs) <= 0 { + continue + } + for i := range msgPack.Msgs { + ss.msgBuffer <- msgPack.Msgs[i] + //fmt.Println("receive a search msg") + } + } } +} + +func (ss *searchService) startSearchService() { + defer ss.wait.Done() + for { + select { + case <-ss.ctx.Done(): + return + default: + serviceTimestamp := (*(*ss.replica).getTSafe()).get() + searchMsg := make([]msgstream.TsMsg, 0) + tempMsg := make([]msgstream.TsMsg, 0) + tempMsg = append(tempMsg, ss.unsolvedMsg...) + ss.unsolvedMsg = ss.unsolvedMsg[:0] + for _, msg := range tempMsg { + if msg.BeginTs() > serviceTimestamp { + searchMsg = append(searchMsg, msg) + continue + } + ss.unsolvedMsg = append(ss.unsolvedMsg, msg) + } + + msgBufferLength := len(ss.msgBuffer) + for i := 0; i < msgBufferLength; i++ { + msg := <-ss.msgBuffer + if msg.BeginTs() > serviceTimestamp { + searchMsg = append(searchMsg, msg) + continue + } + ss.unsolvedMsg = append(ss.unsolvedMsg, msg) + } + if len(searchMsg) <= 0 { + continue + } + err := ss.search(searchMsg) + if err != nil { + fmt.Println("search Failed") + ss.publishFailedSearchResult() + } + fmt.Println("Do search done") + } + } +} + +func (ss *searchService) search(searchMessages []msgstream.TsMsg) error { // TODO:: cache map[dsl]plan // TODO: reBatched search requests for _, msg := range searchMessages { @@ -129,8 +171,6 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error { } searchTimestamp := searchMsg.Timestamp - - // TODO:: add serviceable time var queryBlob = searchMsg.Query.Value query := servicepb.Query{} err := proto.Unmarshal(queryBlob, &query) @@ -162,9 +202,11 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error { for _, pg := range placeholderGroups { numQueries += pg.GetNumOfQuery() } - var searchResults = make([][]SearchResult, numQueries) - for i := 0; i < int(numQueries); i++ { - searchResults[i] = make([]SearchResult, 0) + + resultIds := make([]IntPrimaryKey, topK*numQueries) + resultDistances := make([]float32, topK*numQueries) + for i := range resultDistances { + resultDistances[i] = math.MaxFloat32 } // 3. 
Do search in all segments @@ -174,42 +216,27 @@ func (ss *searchService) search(searchMessages []msgstream.TsMsg) error { return err } for _, segment := range partition.segments { - res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK) + err := segment.segmentSearch(plan, + placeholderGroups, + []Timestamp{searchTimestamp}, + resultIds, + resultDistances, + numQueries, + topK) if err != nil { return err } - for i := 0; int64(i) < numQueries; i++ { - for j := int64(i) * topK; j < int64(i+1)*topK; j++ { - searchResults[i] = append(searchResults[i], SearchResult{ - ResultID: res.ResultIds[j], - ResultDistance: res.ResultDistances[j], - }) - } - } - } - } - - // 4. Reduce results - // TODO::reduce in c++ merge_into func - for _, temp := range searchResults { - sort.Slice(temp, func(i, j int) bool { - return temp[i].ResultDistance < temp[j].ResultDistance - }) - } - - for i, tmp := range searchResults { - if int64(len(tmp)) > topK { - searchResults[i] = searchResults[i][:topK] } } + // 4. return results hits := make([]*servicepb.Hits, 0) - for _, value := range searchResults { + for i := int64(0); i < numQueries; i++ { hit := servicepb.Hits{} score := servicepb.Score{} - for j := 0; int64(j) < topK; j++ { - hit.IDs = append(hit.IDs, value[j].ResultID) - score.Values = append(score.Values, value[j].ResultDistance) + for j := i * topK; j < (i+1)*topK; j++ { + hit.IDs = append(hit.IDs, resultIds[j]) + score.Values = append(score.Values, resultDistances[j]) } hit.Scores = append(hit.Scores, &score) hits = append(hits, &hit) diff --git a/internal/reader/search_service_test.go b/internal/reader/search_service_test.go index 5a26a7e8d1..8f1cdd01a4 100644 --- a/internal/reader/search_service_test.go +++ b/internal/reader/search_service_test.go @@ -175,8 +175,9 @@ func TestSearch_Search(t *testing.T) { searchStream.SetPulsarClient(pulsarURL) searchStream.CreatePulsarProducers(searchProducerChannels) + var vecSearch = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17} var searchRawData []byte - for _, ele := range vec { + for _, ele := range vecSearch { buf := make([]byte, 4) binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) searchRawData = append(searchRawData, buf...) 
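The search_service.go rewrite above splits the old single consume-and-search goroutine into two: receiveSearchMsg only drains the Pulsar stream into msgBuffer, while startSearchService gates buffered and previously unsolved messages against the current tSafe timestamp before searching. Below is a minimal, dependency-free sketch of that gating step; the tsMsg type, the gateMessages helper, and the direction of the timestamp comparison (a message is taken as serviceable once tSafe has reached its begin timestamp) are assumptions of the sketch, not the actual msgstream API.

// Sketch only: one pass of the gate inside startSearchService. Previously
// deferred messages and newly buffered ones are split into a serviceable
// batch and a new unsolved list, based on the current tSafe timestamp.
package main

import "fmt"

type tsMsg struct {
	beginTs uint64
	query   string
}

// gateMessages re-checks the old unsolved list, drains whatever the consumer
// goroutine has buffered so far, and returns the messages that can be
// searched now plus the ones that must wait for a larger tSafe.
func gateMessages(serviceTs uint64, unsolved []tsMsg, msgBuffer chan tsMsg) (ready, stillUnsolved []tsMsg) {
	check := func(m tsMsg) {
		if m.beginTs <= serviceTs {
			ready = append(ready, m)
		} else {
			stillUnsolved = append(stillUnsolved, m)
		}
	}
	for _, m := range unsolved {
		check(m)
	}
	// Only drain what is already buffered, so the receiving goroutine can keep
	// appending concurrently without blocking this pass.
	for n := len(msgBuffer); n > 0; n-- {
		check(<-msgBuffer)
	}
	return ready, stillUnsolved
}

func main() {
	buf := make(chan tsMsg, 4)
	buf <- tsMsg{beginTs: 5, query: "q1"}
	buf <- tsMsg{beginTs: 20, query: "q2"}

	ready, deferred := gateMessages(10, nil, buf)
	fmt.Println("serviceable:", ready) // q1 (beginTs 5 <= tSafe 10)
	fmt.Println("deferred:", deferred) // q2 (beginTs 20 > tSafe 10)
}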
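The per-message search now writes into two flat buffers, resultIds and resultDistances, each holding numQueries consecutive blocks of topK entries; distances start at math.MaxFloat32 so that the first segment merged in always displaces the sentinels, and the final hits are built by slicing the buffers per query. A small self-contained sketch of that layout follows; Hit stands in for servicepb.Hits, and newResultBuffers/assembleHits are illustrative helpers, not functions from this patch.

// Sketch only: flat result-buffer layout and per-query hit assembly.
package main

import (
	"fmt"
	"math"
)

type Hit struct {
	IDs    []int64
	Scores []float32
}

func newResultBuffers(numQueries, topK int64) ([]int64, []float32) {
	ids := make([]int64, topK*numQueries)
	distances := make([]float32, topK*numQueries)
	for i := range distances {
		distances[i] = math.MaxFloat32 // sentinel: any real distance is smaller
	}
	return ids, distances
}

// assembleHits slices the flat buffers back into one Hit per query.
func assembleHits(numQueries, topK int64, ids []int64, distances []float32) []Hit {
	hits := make([]Hit, 0, numQueries)
	for i := int64(0); i < numQueries; i++ {
		h := Hit{}
		for j := i * topK; j < (i+1)*topK; j++ {
			h.IDs = append(h.IDs, ids[j])
			h.Scores = append(h.Scores, distances[j])
		}
		hits = append(hits, h)
	}
	return hits
}

func main() {
	ids, distances := newResultBuffers(2, 2)
	copy(ids, []int64{7, 9, 3, 4})
	copy(distances, []float32{0.1, 0.8, 0.2, 0.5})
	for i, h := range assembleHits(2, 2, ids, distances) {
		fmt.Printf("query %d: ids=%v scores=%v\n", i, h.IDs, h.Scores)
	}
}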
diff --git a/internal/reader/segment.go b/internal/reader/segment.go index 81e30b3601..eabd69daf4 100644 --- a/internal/reader/segment.go +++ b/internal/reader/segment.go @@ -9,6 +9,7 @@ package reader #include "collection_c.h" #include "segment_c.h" #include "plan_c.h" +#include "reduce_c.h" */ import "C" @@ -178,14 +179,24 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps return nil } -func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGroup, timestamp []Timestamp, numQueries int64, topK int64) (*SearchResult, error) { +func (s *Segment) segmentSearch(plan *Plan, + placeHolderGroups []*PlaceholderGroup, + timestamp []Timestamp, + resultIds []IntPrimaryKey, + resultDistances []float32, + numQueries int64, + topK int64) error { /* - void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, long int* result_ids, - float* result_distances) + void* Search(void* plan, + void* placeholder_groups, + uint64_t* timestamps, + int num_groups, + long int* result_ids, + float* result_distances); */ - resultIds := make([]IntPrimaryKey, topK*numQueries) - resultDistances := make([]float32, topK*numQueries) + newResultIds := make([]IntPrimaryKey, topK*numQueries) + NewResultDistances := make([]float32, topK*numQueries) cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) for _, pg := range placeHolderGroups { cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) @@ -194,16 +205,22 @@ func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGrou var cTimestamp = (*C.ulong)(×tamp[0]) var cResultIds = (*C.long)(&resultIds[0]) var cResultDistances = (*C.float)(&resultDistances[0]) + var cNewResultIds = (*C.long)(&newResultIds[0]) + var cNewResultDistances = (*C.float)(&NewResultDistances[0]) var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) var cNumGroups = C.int(len(placeHolderGroups)) - var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cResultIds, cResultDistances) - + var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cNewResultIds, cNewResultDistances) if status != 0 { - return nil, errors.New("search failed, error code = " + strconv.Itoa(int(status))) + return errors.New("search failed, error code = " + strconv.Itoa(int(status))) } - //fmt.Println("search Result---- Ids =", resultIds, ", Distances =", resultDistances) - - return &SearchResult{ResultIds: resultIds, ResultDistances: resultDistances}, nil + cNumQueries := C.long(numQueries) + cTopK := C.long(topK) + // reduce search result + status = C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds) + if status != 0 { + return errors.New("merge search result failed, error code = " + strconv.Itoa(int(status))) + } + return nil } diff --git a/internal/reader/segment_test.go b/internal/reader/segment_test.go index 961ca247e9..c76641d692 100644 --- a/internal/reader/segment_test.go +++ b/internal/reader/segment_test.go @@ -661,8 +661,13 @@ func TestSegment_segmentSearch(t *testing.T) { for _, pg := range placeholderGroups { numQueries += pg.GetNumOfQuery() } + resultIds := make([]IntPrimaryKey, topK*numQueries) + resultDistances := make([]float32, topK*numQueries) + for i := range resultDistances { + resultDistances[i] = math.MaxFloat32 + } - _, err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK) + err = segment.segmentSearch(cPlan, placeholderGroups, 
[]Timestamp{searchTimestamp}, resultIds, resultDistances, numQueries, topK) assert.NoError(t, err) }
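With this change, segmentSearch lets each segment write its candidates into fresh newResultIds/NewResultDistances buffers and immediately reduces them into the caller's running resultIds/resultDistances via C.MergeInto, the new C wrapper around milvus::segcore::merge_into (which now returns a Status so the wrapper can surface status.code() as an int). Per query, the merge assumes both topK lists are sorted by ascending distance and keeps the topK best pairs in the destination. Below is a pure-Go sketch of that reduce step, checked against the same vectors as the CApiTest.MergeInto unit test above; mergeInto here is illustrative, not the exported cgo symbol.

// Sketch only: per-query top-K merge of a new candidate list (src) into the
// running result list (dst), both sorted by ascending distance.
package main

import "fmt"

func mergeInto(numQueries, topK int64, dstDistances []float32, dstIDs []int64, srcDistances []float32, srcIDs []int64) {
	for q := int64(0); q < numQueries; q++ {
		base := q * topK
		bufDis := make([]float32, topK)
		bufIDs := make([]int64, topK)
		var di, si int64
		for k := int64(0); k < topK; k++ {
			// Take the smaller head of the two sorted lists (ties prefer src).
			if di >= topK || (si < topK && srcDistances[base+si] <= dstDistances[base+di]) {
				bufDis[k] = srcDistances[base+si]
				bufIDs[k] = srcIDs[base+si]
				si++
			} else {
				bufDis[k] = dstDistances[base+di]
				bufIDs[k] = dstIDs[base+di]
				di++
			}
		}
		// Write the merged top-K back into the destination block.
		copy(dstDistances[base:base+topK], bufDis)
		copy(dstIDs[base:base+topK], bufIDs)
	}
}

func main() {
	// Same vectors as the CApiTest.MergeInto unit test.
	dstIDs := []int64{1, 2}
	dstDis := []float32{5, 1000}
	srcIDs := []int64{3, 4}
	srcDis := []float32{2, 6}

	mergeInto(1, 2, dstDis, dstIDs, srcDis, srcIDs)
	fmt.Println(dstIDs, dstDis) // expected: [3 1] [2 5]
}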