Get search field id from search plan and log if loaded index when search segments (#18183)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
bigsheeper 2022-07-08 20:18:22 +08:00 committed by GitHub
parent d70a2c8796
commit 22508f36d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 5 deletions

View File

@ -78,6 +78,11 @@ GetTopK(const Plan* plan) {
return plan->plan_node_->search_info_.topk_; return plan->plan_node_->search_info_.topk_;
} }
int64_t
GetFieldID(const Plan* plan) {
return plan->plan_node_->search_info_.field_id_.get();
}
int64_t int64_t
GetNumOfQueries(const PlaceholderGroup* group) { GetNumOfQueries(const PlaceholderGroup* group) {
return group->at(0).num_of_queries_; return group->at(0).num_of_queries_;

View File

@ -48,4 +48,7 @@ CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan,
int64_t int64_t
GetTopK(const Plan*); GetTopK(const Plan*);
int64_t
GetFieldID(const Plan* plan);
} // namespace milvus::query } // namespace milvus::query

View File

@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License // or implied. See the License for the specific language governing permissions and limitations under the License
#include "common/CGoHelper.h"
#include "pb/segcore.pb.h" #include "pb/segcore.pb.h"
#include "query/Plan.h" #include "query/Plan.h"
#include "segcore/Collection.h" #include "segcore/Collection.h"
@ -109,6 +110,17 @@ GetTopK(CSearchPlan plan) {
return res; return res;
} }
CStatus
GetFieldID(CSearchPlan plan, int64_t* field_id) {
try {
auto p = static_cast<const milvus::query::Plan*>(plan);
*field_id = milvus::query::GetFieldID(p);
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, strdup(e.what()));
}
}
const char* const char*
GetMetricType(CSearchPlan plan) { GetMetricType(CSearchPlan plan) {
auto search_plan = static_cast<milvus::query::Plan*>(plan); auto search_plan = static_cast<milvus::query::Plan*>(plan);

View File

@ -42,6 +42,9 @@ GetNumOfQueries(CPlaceholderGroup placeholder_group);
int64_t int64_t
GetTopK(CSearchPlan plan); GetTopK(CSearchPlan plan);
CStatus
GetFieldID(CSearchPlan plan, int64_t* field_id);
const char* const char*
GetMetricType(CSearchPlan plan); GetMetricType(CSearchPlan plan);

View File

@ -208,6 +208,46 @@ TEST(CApiTest, SegmentTest) {
DeleteSegment(segment); DeleteSegment(segment);
} }
TEST(CApiTest, CPlan) {
std::string schema_string = generate_collection_schema("JACCARD", DIM, true);
auto collection = NewCollection(schema_string.c_str());
const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";
void* plan = nullptr;
auto status = CreateSearchPlan(collection, dsl_string, &plan);
assert(status.error_code == Success);
int64_t field_id = -1;
status = GetFieldID(plan, &field_id);
assert(status.error_code == Success);
auto col = static_cast<Collection*>(collection);
for (auto& [target_field_id, field_meta] : col->get_schema()->get_fields()) {
if (field_meta.is_vector()) {
assert(field_id == target_field_id.get());
}
}
assert(field_id != -1);
DeleteSearchPlan(plan);
}
template <typename Message> template <typename Message>
std::vector<uint8_t> std::vector<uint8_t>
serialize(const Message* msg) { serialize(const Message* msg) {
@ -1104,7 +1144,7 @@ TEST(CApiTest, ReudceNullResult) {
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(),
slice_nqs.data(), slice_topKs.data(), slice_nqs.size()); slice_nqs.data(), slice_topKs.data(), slice_nqs.size());
assert(status.error_code == Success); assert(status.error_code == Success);
auto search_result = (SearchResult*)results[0]; auto search_result = (SearchResult*)results[0];
auto size = search_result->result_offsets_.size(); auto size = search_result->result_offsets_.size();
EXPECT_EQ(size, num_queries / 2); EXPECT_EQ(size, num_queries / 2);

View File

@ -95,6 +95,7 @@ type searchRequest struct {
cPlaceholderGroup C.CPlaceholderGroup cPlaceholderGroup C.CPlaceholderGroup
timestamp Timestamp timestamp Timestamp
msgID UniqueID msgID UniqueID
searchFieldID UniqueID
} }
func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*searchRequest, error) { func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*searchRequest, error) {
@ -129,11 +130,19 @@ func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh
return nil, err return nil, err
} }
var fieldID C.int64_t
status = C.GetFieldID(plan.cSearchPlan, &fieldID)
if err = HandleCStatus(&status, "get fieldID from plan failed"); err != nil {
plan.delete()
return nil, err
}
ret := &searchRequest{ ret := &searchRequest{
plan: plan, plan: plan,
cPlaceholderGroup: cPlaceholderGroup, cPlaceholderGroup: cPlaceholderGroup,
timestamp: req.Req.GetTravelTimestamp(), timestamp: req.Req.GetTravelTimestamp(),
msgID: req.GetReq().GetBase().GetMsgID(), msgID: req.GetReq().GetBase().GetMsgID(),
searchFieldID: int64(fieldID),
} }
return ret, nil return ret, nil

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
) )
func TestPlan_Plan(t *testing.T) { func TestPlan_Plan(t *testing.T) {
@ -124,3 +125,23 @@ func TestPlan_PlaceholderGroup(t *testing.T) {
holder.delete() holder.delete()
deleteCollection(collection) deleteCollection(collection)
} }
func TestPlan_newSearchRequest(t *testing.T) {
iReq, _ := genSearchRequest(defaultNQ, IndexHNSW, genTestCollectionSchema())
collection := newCollection(defaultCollectionID, genTestCollectionSchema())
req := &querypb.SearchRequest{
Req: iReq,
DmlChannels: []string{defaultDMLChannel},
SegmentIDs: []UniqueID{defaultSegmentID},
FromShardLeader: true,
Scope: querypb.DataScope_Historical,
}
searchReq, err := newSearchRequest(collection, req, req.Req.GetPlaceholderGroup())
assert.NoError(t, err)
assert.Equal(t, simpleFloatVecField.id, searchReq.searchFieldID)
assert.EqualValues(t, defaultNQ, searchReq.getNumOfQuery())
searchReq.delete()
deleteCollection(collection)
}

View File

@ -279,9 +279,13 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
return nil, fmt.Errorf("nil search plan") return nil, fmt.Errorf("nil search plan")
} }
loadIndex := s.hasLoadIndexForIndexedField(searchReq.searchFieldID)
var searchResult SearchResult var searchResult SearchResult
log.Debug("start do search on segment", zap.Int64("msgID", searchReq.msgID), log.Debug("start do search on segment",
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String())) zap.Int64("msgID", searchReq.msgID),
zap.Int64("segmentID", s.segmentID),
zap.String("segmentType", s.segmentType.String()),
zap.Bool("loadIndex", loadIndex))
tr := timerecord.NewTimeRecorder("cgoSearch") tr := timerecord.NewTimeRecorder("cgoSearch")
status := C.Search(s.segmentPtr, searchReq.plan.cSearchPlan, searchReq.cPlaceholderGroup, status := C.Search(s.segmentPtr, searchReq.plan.cSearchPlan, searchReq.cPlaceholderGroup,
C.uint64_t(searchReq.timestamp), &searchResult.cSearchResult, C.int64_t(s.segmentID)) C.uint64_t(searchReq.timestamp), &searchResult.cSearchResult, C.int64_t(s.segmentID))
@ -289,8 +293,11 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
if err := HandleCStatus(&status, "Search failed"); err != nil { if err := HandleCStatus(&status, "Search failed"); err != nil {
return nil, err return nil, err
} }
log.Debug("do search on segment done", zap.Int64("msgID", searchReq.msgID), log.Debug("do search on segment done",
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String())) zap.Int64("msgID", searchReq.msgID),
zap.Int64("segmentID", s.segmentID),
zap.String("segmentType", s.segmentType.String()),
zap.Bool("loadIndex", loadIndex))
return &searchResult, nil return &searchResult, nil
} }