diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 43878cb8c4..b34b677da9 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -78,6 +78,11 @@ GetTopK(const Plan* plan) { return plan->plan_node_->search_info_.topk_; } +int64_t +GetFieldID(const Plan* plan) { + return plan->plan_node_->search_info_.field_id_.get(); +} + int64_t GetNumOfQueries(const PlaceholderGroup* group) { return group->at(0).num_of_queries_; diff --git a/internal/core/src/query/Plan.h b/internal/core/src/query/Plan.h index 68848385a6..ba4eacc15a 100644 --- a/internal/core/src/query/Plan.h +++ b/internal/core/src/query/Plan.h @@ -48,4 +48,7 @@ CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, int64_t GetTopK(const Plan*); +int64_t +GetFieldID(const Plan* plan); + } // namespace milvus::query diff --git a/internal/core/src/segcore/plan_c.cpp b/internal/core/src/segcore/plan_c.cpp index 1301676c4f..cb2dbaaf00 100644 --- a/internal/core/src/segcore/plan_c.cpp +++ b/internal/core/src/segcore/plan_c.cpp @@ -9,6 +9,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include "common/CGoHelper.h" #include "pb/segcore.pb.h" #include "query/Plan.h" #include "segcore/Collection.h" @@ -109,6 +110,17 @@ GetTopK(CSearchPlan plan) { return res; } +CStatus +GetFieldID(CSearchPlan plan, int64_t* field_id) { + try { + auto p = static_cast(plan); + *field_id = milvus::query::GetFieldID(p); + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(UnexpectedError, strdup(e.what())); + } +} + const char* GetMetricType(CSearchPlan plan) { auto search_plan = static_cast(plan); diff --git a/internal/core/src/segcore/plan_c.h b/internal/core/src/segcore/plan_c.h index 30dc4df42a..c56ac7bbec 100644 --- a/internal/core/src/segcore/plan_c.h +++ b/internal/core/src/segcore/plan_c.h @@ -42,6 +42,9 @@ GetNumOfQueries(CPlaceholderGroup placeholder_group); int64_t GetTopK(CSearchPlan plan); +CStatus +GetFieldID(CSearchPlan plan, int64_t* field_id); + const char* GetMetricType(CSearchPlan plan); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index c354a9f087..2f9572e187 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -208,6 +208,46 @@ TEST(CApiTest, SegmentTest) { DeleteSegment(segment); } +TEST(CApiTest, CPlan) { + std::string schema_string = generate_collection_schema("JACCARD", DIM, true); + auto collection = NewCollection(schema_string.c_str()); + + const char* dsl_string = R"( + { + "bool": { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10, + "round_decimal": 3 + } + } + } + })"; + + void* plan = nullptr; + auto status = CreateSearchPlan(collection, dsl_string, &plan); + assert(status.error_code == Success); + + int64_t field_id = -1; + status = GetFieldID(plan, &field_id); + assert(status.error_code == Success); + + auto col = static_cast(collection); + for (auto& [target_field_id, field_meta] : col->get_schema()->get_fields()) { + if (field_meta.is_vector()) { + assert(field_id == target_field_id.get()); + } + } + assert(field_id != -1); + + DeleteSearchPlan(plan); +} + template std::vector serialize(const Message* msg) { @@ -1104,7 +1144,7 @@ TEST(CApiTest, ReudceNullResult) { status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(), slice_topKs.data(), slice_nqs.size()); assert(status.error_code == Success); - + auto search_result = (SearchResult*)results[0]; auto size = search_result->result_offsets_.size(); EXPECT_EQ(size, num_queries / 2); diff --git a/internal/querynode/plan.go b/internal/querynode/plan.go index 5dbf406f58..3a63c08ceb 100644 --- a/internal/querynode/plan.go +++ b/internal/querynode/plan.go @@ -95,6 +95,7 @@ type searchRequest struct { cPlaceholderGroup C.CPlaceholderGroup timestamp Timestamp msgID UniqueID + searchFieldID UniqueID } func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*searchRequest, error) { @@ -129,11 +130,19 @@ func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh return nil, err } + var fieldID C.int64_t + status = C.GetFieldID(plan.cSearchPlan, &fieldID) + if err = HandleCStatus(&status, "get fieldID from plan failed"); err != nil { + plan.delete() + return nil, err + } + ret := &searchRequest{ plan: plan, cPlaceholderGroup: cPlaceholderGroup, timestamp: req.Req.GetTravelTimestamp(), msgID: req.GetReq().GetBase().GetMsgID(), + searchFieldID: int64(fieldID), } return ret, nil diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index 0549afeb9b..94a386951b 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" ) func TestPlan_Plan(t *testing.T) { @@ -124,3 +125,23 @@ func TestPlan_PlaceholderGroup(t *testing.T) { holder.delete() deleteCollection(collection) } + +func TestPlan_newSearchRequest(t *testing.T) { + iReq, _ := genSearchRequest(defaultNQ, IndexHNSW, genTestCollectionSchema()) + collection := newCollection(defaultCollectionID, genTestCollectionSchema()) + req := &querypb.SearchRequest{ + Req: iReq, + DmlChannels: []string{defaultDMLChannel}, + SegmentIDs: []UniqueID{defaultSegmentID}, + FromShardLeader: true, + Scope: querypb.DataScope_Historical, + } + searchReq, err := newSearchRequest(collection, req, req.Req.GetPlaceholderGroup()) + assert.NoError(t, err) + + assert.Equal(t, simpleFloatVecField.id, searchReq.searchFieldID) + assert.EqualValues(t, defaultNQ, searchReq.getNumOfQuery()) + + searchReq.delete() + deleteCollection(collection) +} diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index 8095a932a8..9028a0e3e2 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -279,9 +279,13 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) { return nil, fmt.Errorf("nil search plan") } + loadIndex := s.hasLoadIndexForIndexedField(searchReq.searchFieldID) var searchResult SearchResult - log.Debug("start do search on segment", zap.Int64("msgID", searchReq.msgID), - zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String())) + log.Debug("start do search on segment", + zap.Int64("msgID", searchReq.msgID), + zap.Int64("segmentID", s.segmentID), + zap.String("segmentType", s.segmentType.String()), + zap.Bool("loadIndex", loadIndex)) tr := timerecord.NewTimeRecorder("cgoSearch") status := C.Search(s.segmentPtr, searchReq.plan.cSearchPlan, searchReq.cPlaceholderGroup, C.uint64_t(searchReq.timestamp), &searchResult.cSearchResult, C.int64_t(s.segmentID)) @@ -289,8 +293,11 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) { if err := HandleCStatus(&status, "Search failed"); err != nil { return nil, err } - log.Debug("do search on segment done", zap.Int64("msgID", searchReq.msgID), - zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String())) + log.Debug("do search on segment done", + zap.Int64("msgID", searchReq.msgID), + zap.Int64("segmentID", s.segmentID), + zap.String("segmentType", s.segmentType.String()), + zap.Bool("loadIndex", loadIndex)) return &searchResult, nil }