mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
Get search field id from search plan and log if loaded index when search segments (#18183)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
parent
d70a2c8796
commit
22508f36d3
@ -78,6 +78,11 @@ GetTopK(const Plan* plan) {
|
||||
return plan->plan_node_->search_info_.topk_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
GetFieldID(const Plan* plan) {
|
||||
return plan->plan_node_->search_info_.field_id_.get();
|
||||
}
|
||||
|
||||
int64_t
|
||||
GetNumOfQueries(const PlaceholderGroup* group) {
|
||||
return group->at(0).num_of_queries_;
|
||||
|
||||
@ -48,4 +48,7 @@ CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan,
|
||||
int64_t
|
||||
GetTopK(const Plan*);
|
||||
|
||||
int64_t
|
||||
GetFieldID(const Plan* plan);
|
||||
|
||||
} // namespace milvus::query
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "common/CGoHelper.h"
|
||||
#include "pb/segcore.pb.h"
|
||||
#include "query/Plan.h"
|
||||
#include "segcore/Collection.h"
|
||||
@ -109,6 +110,17 @@ GetTopK(CSearchPlan plan) {
|
||||
return res;
|
||||
}
|
||||
|
||||
CStatus
|
||||
GetFieldID(CSearchPlan plan, int64_t* field_id) {
|
||||
try {
|
||||
auto p = static_cast<const milvus::query::Plan*>(plan);
|
||||
*field_id = milvus::query::GetFieldID(p);
|
||||
return milvus::SuccessCStatus();
|
||||
} catch (std::exception& e) {
|
||||
return milvus::FailureCStatus(UnexpectedError, strdup(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
const char*
|
||||
GetMetricType(CSearchPlan plan) {
|
||||
auto search_plan = static_cast<milvus::query::Plan*>(plan);
|
||||
|
||||
@ -42,6 +42,9 @@ GetNumOfQueries(CPlaceholderGroup placeholder_group);
|
||||
int64_t
|
||||
GetTopK(CSearchPlan plan);
|
||||
|
||||
CStatus
|
||||
GetFieldID(CSearchPlan plan, int64_t* field_id);
|
||||
|
||||
const char*
|
||||
GetMetricType(CSearchPlan plan);
|
||||
|
||||
|
||||
@ -208,6 +208,46 @@ TEST(CApiTest, SegmentTest) {
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
TEST(CApiTest, CPlan) {
|
||||
std::string schema_string = generate_collection_schema("JACCARD", DIM, true);
|
||||
auto collection = NewCollection(schema_string.c_str());
|
||||
|
||||
const char* dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
|
||||
void* plan = nullptr;
|
||||
auto status = CreateSearchPlan(collection, dsl_string, &plan);
|
||||
assert(status.error_code == Success);
|
||||
|
||||
int64_t field_id = -1;
|
||||
status = GetFieldID(plan, &field_id);
|
||||
assert(status.error_code == Success);
|
||||
|
||||
auto col = static_cast<Collection*>(collection);
|
||||
for (auto& [target_field_id, field_meta] : col->get_schema()->get_fields()) {
|
||||
if (field_meta.is_vector()) {
|
||||
assert(field_id == target_field_id.get());
|
||||
}
|
||||
}
|
||||
assert(field_id != -1);
|
||||
|
||||
DeleteSearchPlan(plan);
|
||||
}
|
||||
|
||||
template <typename Message>
|
||||
std::vector<uint8_t>
|
||||
serialize(const Message* msg) {
|
||||
@ -1104,7 +1144,7 @@ TEST(CApiTest, ReudceNullResult) {
|
||||
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(),
|
||||
slice_nqs.data(), slice_topKs.data(), slice_nqs.size());
|
||||
assert(status.error_code == Success);
|
||||
|
||||
|
||||
auto search_result = (SearchResult*)results[0];
|
||||
auto size = search_result->result_offsets_.size();
|
||||
EXPECT_EQ(size, num_queries / 2);
|
||||
|
||||
@ -95,6 +95,7 @@ type searchRequest struct {
|
||||
cPlaceholderGroup C.CPlaceholderGroup
|
||||
timestamp Timestamp
|
||||
msgID UniqueID
|
||||
searchFieldID UniqueID
|
||||
}
|
||||
|
||||
func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*searchRequest, error) {
|
||||
@ -129,11 +130,19 @@ func newSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fieldID C.int64_t
|
||||
status = C.GetFieldID(plan.cSearchPlan, &fieldID)
|
||||
if err = HandleCStatus(&status, "get fieldID from plan failed"); err != nil {
|
||||
plan.delete()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &searchRequest{
|
||||
plan: plan,
|
||||
cPlaceholderGroup: cPlaceholderGroup,
|
||||
timestamp: req.Req.GetTravelTimestamp(),
|
||||
msgID: req.GetReq().GetBase().GetMsgID(),
|
||||
searchFieldID: int64(fieldID),
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
|
||||
@ -27,6 +27,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
)
|
||||
|
||||
func TestPlan_Plan(t *testing.T) {
|
||||
@ -124,3 +125,23 @@ func TestPlan_PlaceholderGroup(t *testing.T) {
|
||||
holder.delete()
|
||||
deleteCollection(collection)
|
||||
}
|
||||
|
||||
func TestPlan_newSearchRequest(t *testing.T) {
|
||||
iReq, _ := genSearchRequest(defaultNQ, IndexHNSW, genTestCollectionSchema())
|
||||
collection := newCollection(defaultCollectionID, genTestCollectionSchema())
|
||||
req := &querypb.SearchRequest{
|
||||
Req: iReq,
|
||||
DmlChannels: []string{defaultDMLChannel},
|
||||
SegmentIDs: []UniqueID{defaultSegmentID},
|
||||
FromShardLeader: true,
|
||||
Scope: querypb.DataScope_Historical,
|
||||
}
|
||||
searchReq, err := newSearchRequest(collection, req, req.Req.GetPlaceholderGroup())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, simpleFloatVecField.id, searchReq.searchFieldID)
|
||||
assert.EqualValues(t, defaultNQ, searchReq.getNumOfQuery())
|
||||
|
||||
searchReq.delete()
|
||||
deleteCollection(collection)
|
||||
}
|
||||
|
||||
@ -279,9 +279,13 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
|
||||
return nil, fmt.Errorf("nil search plan")
|
||||
}
|
||||
|
||||
loadIndex := s.hasLoadIndexForIndexedField(searchReq.searchFieldID)
|
||||
var searchResult SearchResult
|
||||
log.Debug("start do search on segment", zap.Int64("msgID", searchReq.msgID),
|
||||
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String()))
|
||||
log.Debug("start do search on segment",
|
||||
zap.Int64("msgID", searchReq.msgID),
|
||||
zap.Int64("segmentID", s.segmentID),
|
||||
zap.String("segmentType", s.segmentType.String()),
|
||||
zap.Bool("loadIndex", loadIndex))
|
||||
tr := timerecord.NewTimeRecorder("cgoSearch")
|
||||
status := C.Search(s.segmentPtr, searchReq.plan.cSearchPlan, searchReq.cPlaceholderGroup,
|
||||
C.uint64_t(searchReq.timestamp), &searchResult.cSearchResult, C.int64_t(s.segmentID))
|
||||
@ -289,8 +293,11 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
|
||||
if err := HandleCStatus(&status, "Search failed"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debug("do search on segment done", zap.Int64("msgID", searchReq.msgID),
|
||||
zap.Int64("segmentID", s.segmentID), zap.String("segmentType", s.segmentType.String()))
|
||||
log.Debug("do search on segment done",
|
||||
zap.Int64("msgID", searchReq.msgID),
|
||||
zap.Int64("segmentID", s.segmentID),
|
||||
zap.String("segmentType", s.segmentType.String()),
|
||||
zap.Bool("loadIndex", loadIndex))
|
||||
return &searchResult, nil
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user