fix: Support mvcc with hybrid search (#30114)
issue: https://github.com/milvus-io/milvus/issues/29656
/kind bug

Signed-off-by: xige-16 <xi.ge@zilliz.com>
Parent: 32914a3ddf
Commit: 060c8603a3
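The change, in short: the proxy now packs every ANN sub-request of a hybrid search into a single internalpb.HybridSearchRequest and sends it to each shard leader over a new QueryNode HybridSearch RPC, so all sub-searches are answered under one MVCC view and the per-channel MVCC timestamps come back in channels_mvcc. The sketch below is an editorial illustration of that flow, not code from the commit; the Go field and getter names (Reqs, DmlChannels, TotalChannelNum, GetChannelsMvcc) are assumed to follow the standard protoc-gen-go naming for the messages added in this diff.

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
)

// buildShardHybridSearch mirrors hybridSearchTask.hybridSearchShard further
// down in this diff: all sub-searches are bundled into one request per DML channel.
func buildShardHybridSearch(subReqs []*internalpb.SearchRequest, channel string) *querypb.HybridSearchRequest {
	return &querypb.HybridSearchRequest{
		Req:             &internalpb.HybridSearchRequest{Reqs: subReqs},
		DmlChannels:     []string{channel},
		TotalChannelNum: 1,
	}
}

// callShard sends the bundle to one QueryNode and returns the per-channel
// MVCC timestamps it reports, which the proxy later reuses for requery.
func callShard(ctx context.Context, qn querypb.QueryNodeClient, req *querypb.HybridSearchRequest) (map[string]uint64, error) {
	resp, err := qn.HybridSearch(ctx, req)
	if err != nil {
		return nil, err
	}
	return resp.GetChannelsMvcc(), nil
}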
@@ -329,3 +329,10 @@ func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...gr
 		return client.Delete(ctx, req)
 	})
 }
+
+// HybridSearch performs replica hybrid search tasks in QueryNode.
+func (c *Client) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest, _ ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
+	return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*querypb.HybridSearchResult, error) {
+		return client.HybridSearch(ctx, req)
+	})
+}
@@ -374,3 +374,8 @@ func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
 func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
 	return s.querynode.Delete(ctx, req)
 }
+
+// HybridSearch performs hybrid search of streaming/historical replica on QueryNode.
+func (s *Server) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
+	return s.querynode.HybridSearch(ctx, req)
+}
@@ -511,6 +511,61 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.C
 	return _c
 }
 
+// HybridSearch provides a mock function with given fields: _a0, _a1
+func (_m *MockQueryNode) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
+	ret := _m.Called(_a0, _a1)
+
+	var r0 *querypb.HybridSearchResult
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
+		return rf(_a0, _a1)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
+		r0 = rf(_a0, _a1)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*querypb.HybridSearchResult)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
+		r1 = rf(_a0, _a1)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockQueryNode_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
+type MockQueryNode_HybridSearch_Call struct {
+	*mock.Call
+}
+
+// HybridSearch is a helper method to define mock.On call
+//   - _a0 context.Context
+//   - _a1 *querypb.HybridSearchRequest
+func (_e *MockQueryNode_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNode_HybridSearch_Call {
+	return &MockQueryNode_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
+}
+
+func (_c *MockQueryNode_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNode_HybridSearch_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
+	})
+	return _c
+}
+
+func (_c *MockQueryNode_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNode_HybridSearch_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockQueryNode_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNode_HybridSearch_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // Init provides a mock function with given fields:
 func (_m *MockQueryNode) Init() error {
 	ret := _m.Called()
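For context, the generated expecter above is what the new proxy tests later in this diff rely on. Below is a minimal, illustrative test of the pattern; it is not part of the commit, and it assumes the mockery-generated NewMockQueryNode constructor that normally accompanies a file like this one.

package mocks

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/proto/querypb"
)

// TestHybridSearchMockUsage wires a canned HybridSearch result the same way
// TestHybridSearchTask_ErrExecute / _PostExecute do further down in this diff.
func TestHybridSearchMockUsage(t *testing.T) {
	qn := NewMockQueryNode(t) // assumed mockery constructor in this package
	qn.EXPECT().
		HybridSearch(mock.Anything, mock.Anything).
		Return(&querypb.HybridSearchResult{}, nil)

	res, err := qn.HybridSearch(context.Background(), &querypb.HybridSearchRequest{})
	if err != nil || res == nil {
		t.Fatalf("expected canned result, got %v, %v", res, err)
	}
}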
@@ -632,6 +632,76 @@ func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(con
 	return _c
 }
 
+// HybridSearch provides a mock function with given fields: ctx, in, opts
+func (_m *MockQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
+	_va := make([]interface{}, len(opts))
+	for _i := range opts {
+		_va[_i] = opts[_i]
+	}
+	var _ca []interface{}
+	_ca = append(_ca, ctx, in)
+	_ca = append(_ca, _va...)
+	ret := _m.Called(_ca...)
+
+	var r0 *querypb.HybridSearchResult
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)); ok {
+		return rf(ctx, in, opts...)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) *querypb.HybridSearchResult); ok {
+		r0 = rf(ctx, in, opts...)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*querypb.HybridSearchResult)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) error); ok {
+		r1 = rf(ctx, in, opts...)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockQueryNodeClient_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
+type MockQueryNodeClient_HybridSearch_Call struct {
+	*mock.Call
+}
+
+// HybridSearch is a helper method to define mock.On call
+//   - ctx context.Context
+//   - in *querypb.HybridSearchRequest
+//   - opts ...grpc.CallOption
+func (_e *MockQueryNodeClient_Expecter) HybridSearch(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_HybridSearch_Call {
+	return &MockQueryNodeClient_HybridSearch_Call{Call: _e.mock.On("HybridSearch",
+		append([]interface{}{ctx, in}, opts...)...)}
+}
+
+func (_c *MockQueryNodeClient_HybridSearch_Call) Run(run func(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_HybridSearch_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		variadicArgs := make([]grpc.CallOption, len(args)-2)
+		for i, a := range args[2:] {
+			if a != nil {
+				variadicArgs[i] = a.(grpc.CallOption)
+			}
+		}
+		run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest), variadicArgs...)
+	})
+	return _c
+}
+
+func (_c *MockQueryNodeClient_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeClient_HybridSearch_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockQueryNodeClient_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest, ...grpc.CallOption) (*querypb.HybridSearchResult, error)) *MockQueryNodeClient_HybridSearch_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // LoadPartitions provides a mock function with given fields: ctx, in, opts
 func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
 	_va := make([]interface{}, len(opts))
@@ -104,6 +104,18 @@ message SearchRequest {
   string username = 18;
 }
 
+message HybridSearchRequest {
+  common.MsgBase base = 1;
+  int64 reqID = 2;
+  int64 dbID = 3;
+  int64 collectionID = 4;
+  repeated int64 partitionIDs = 5;
+  repeated SearchRequest reqs = 6;
+  uint64 mvcc_timestamp = 11;
+  uint64 guarantee_timestamp = 12;
+  uint64 timeout_timestamp = 13;
+}
+
 message SearchResults {
   common.MsgBase base = 1;
   common.Status status = 2;
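A short note on the three timestamps this message carries, since they are the heart of the fix: guarantee_timestamp tells the QueryNode how fresh its data must be before it may answer, mvcc_timestamp pins the snapshot the search actually reads, and timeout_timestamp is the request deadline encoded as a TSO timestamp. The snippet below only illustrates the deadline encoding the proxy uses elsewhere in this diff (tsoutil.ComposeTSByTime); it is a sketch, and the TimeoutTimestamp field name is the assumed generated Go name for timeout_timestamp.

package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/util/tsoutil"
)

// withTimeoutTS fills timeout_timestamp from the context deadline, mirroring
// how this diff's initSearchRequest does it for the per-vector sub-searches.
func withTimeoutTS(ctx context.Context, req *internalpb.HybridSearchRequest) {
	if deadline, ok := ctx.Deadline(); ok {
		req.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
	}
}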
@@ -71,6 +71,7 @@ service QueryNode {
 
   rpc GetStatistics(GetStatisticsRequest) returns (internal.GetStatisticsResponse) {}
   rpc Search(SearchRequest) returns (internal.SearchResults) {}
+  rpc HybridSearch(HybridSearchRequest) returns (HybridSearchResult) {}
   rpc SearchSegments(SearchRequest) returns (internal.SearchResults) {}
   rpc Query(QueryRequest) returns (internal.RetrieveResults) {}
   rpc QueryStream(QueryRequest) returns (stream internal.RetrieveResults){}
@@ -328,6 +329,20 @@ message SearchRequest {
   int32 total_channel_num = 6;
 }
 
+message HybridSearchRequest {
+  internal.HybridSearchRequest req = 1;
+  repeated string dml_channels = 2;
+  int32 total_channel_num = 3;
+}
+
+message HybridSearchResult {
+  common.MsgBase base = 1;
+  common.Status status = 2;
+  repeated internal.SearchResults results = 3;
+  internal.CostAggregation costAggregation = 4;
+  map<string, uint64> channels_mvcc = 5;
+}
+
 message QueryRequest {
   internal.RetrieveRequest req = 1;
   repeated string dml_channels = 2;
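channels_mvcc is the piece that actually carries MVCC back to the proxy: each shard reports the timestamp it searched at, per DML channel, and the proxy merges these maps so a follow-up requery can read the very same snapshot (see hybridSearchTask.PostExecute further down). Below is a hedged sketch of that merge, written for illustration only, with the getter name assumed from the generated code.

package example

import "github.com/milvus-io/milvus/internal/proto/querypb"

// mergeChannelsMvcc folds per-shard channel timestamps into one map, the same
// shape PostExecute stores in queryChannelsTs before calling doRequery.
func mergeChannelsMvcc(results []*querypb.HybridSearchResult) map[string]uint64 {
	merged := make(map[string]uint64)
	for _, r := range results {
		for ch, ts := range r.GetChannelsMvcc() {
			merged[ch] = ts
		}
	}
	return merged
}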
@@ -2784,6 +2784,13 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
 	qt := &hybridSearchTask{
 		ctx:       ctx,
 		Condition: NewTaskCondition(ctx),
+		HybridSearchRequest: &internalpb.HybridSearchRequest{
+			Base: commonpbutil.NewMsgBase(
+				commonpbutil.WithMsgType(commonpb.MsgType_Search),
+				commonpbutil.WithSourceID(paramtable.GetNodeID()),
+			),
+			ReqID: paramtable.GetNodeID(),
+		},
 		request:   request,
 		tr:        timerecord.NewTimeRecorder(method),
 		qc:        node.queryCoord,
@@ -2831,7 +2838,7 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
 
 	log.Debug(
 		rpcEnqueued(method),
-		zap.Uint64("timestamp", qt.request.Base.Timestamp),
+		zap.Uint64("timestamp", qt.Base.Timestamp),
 	)
 
 	if err := qt.WaitToFinish(); err != nil {
@@ -120,7 +120,7 @@ func NewReScorer(reqs []*milvuspb.SearchRequest, rankParams []*commonpb.KeyValue
 			return nil, errors.New("The type of rank param k should be float")
 		}
 		if k <= 0 || k >= maxRRFParamsValue {
-			return nil, errors.New("The rank params k should be in range (0, 16384)")
+			return nil, errors.New(fmt.Sprintf("The rank params k should be in range (0, %d)", maxRRFParamsValue))
 		}
 		log.Debug("rrf params", zap.Float64("k", k))
 		for i := range reqs {
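For readers unfamiliar with the k being validated here: the RRF reranker scores a document by summing 1/(k + rank) over the individual ANN result lists, so k must be a positive constant (and this code additionally caps it at maxRRFParamsValue). Below is a minimal sketch of that scoring rule, written independently of the commit.

package example

// rrfScore computes a Reciprocal Rank Fusion score for one document given its
// 1-based rank in each sub-search result list (0 means "not returned").
func rrfScore(k float64, ranks []int) float64 {
	score := 0.0
	for _, r := range ranks {
		if r > 0 {
			score += 1.0 / (k + float64(r))
		}
	}
	return score
}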
internal/proxy/search_util.go (new file, 160 lines)
@@ -0,0 +1,160 @@
+package proxy
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+
+	"github.com/cockroachdb/errors"
+	"github.com/golang/protobuf/proto"
+	"go.opentelemetry.io/otel"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus/internal/parser/planparserv2"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/funcutil"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/tsoutil"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+func initSearchRequest(ctx context.Context, t *searchTask) error {
+	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
+	defer sp.End()
+
+	log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
+	// fetch search_growing from search param
+	var ignoreGrowing bool
+	var err error
+	for i, kv := range t.request.GetSearchParams() {
+		if kv.GetKey() == IgnoreGrowingKey {
+			ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
+			if err != nil {
+				return errors.New("parse search growing failed")
+			}
+			t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
+			break
+		}
+	}
+	t.SearchRequest.IgnoreGrowing = ignoreGrowing
+
+	// Manually update nq if not set.
+	nq, err := getNq(t.request)
+	if err != nil {
+		log.Warn("failed to get nq", zap.Error(err))
+		return err
+	}
+	// Check if nq is valid:
+	// https://milvus.io/docs/limitations.md
+	if err := validateNQLimit(nq); err != nil {
+		return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
+	}
+	t.SearchRequest.Nq = nq
+	log = log.With(zap.Int64("nq", nq))
+
+	outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
+	if err != nil {
+		log.Warn("fail to get output field ids", zap.Error(err))
+		return err
+	}
+	t.SearchRequest.OutputFieldsId = outputFieldIDs
+
+	partitionNames := t.request.GetPartitionNames()
+	if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
+		annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
+		if err != nil || len(annsField) == 0 {
+			vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
+			if len(vecFields) == 0 {
+				return errors.New(AnnsFieldKey + " not found in schema")
+			}
+
+			if enableMultipleVectorFields && len(vecFields) > 1 {
+				return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
+			}
+
+			annsField = vecFields[0].Name
+		}
+		queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
+		if err != nil {
+			return err
+		}
+		if queryInfo.GroupByFieldId != 0 {
+			t.SearchRequest.IgnoreGrowing = true
+			// for group by operation, currently, we ignore growing segments
+		}
+		t.offset = offset
+
+		plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
+		if err != nil {
+			log.Warn("failed to create query plan", zap.Error(err),
+				zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
+				zap.String("anns field", annsField), zap.Any("query info", queryInfo))
+			return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
+		}
+		log.Debug("create query plan",
+			zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
+			zap.String("anns field", annsField), zap.Any("query info", queryInfo))
+
+		if t.partitionKeyMode {
+			expr, err := ParseExprFromPlan(plan)
+			if err != nil {
+				log.Warn("failed to parse expr", zap.Error(err))
+				return err
+			}
+			partitionKeys := ParsePartitionKeys(expr)
+			hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
+			if err != nil {
+				log.Warn("failed to assign partition keys", zap.Error(err))
+				return err
+			}
+
+			partitionNames = append(partitionNames, hashedPartitionNames...)
+		}
+
+		plan.OutputFieldIds = outputFieldIDs
+
+		t.SearchRequest.Topk = queryInfo.GetTopk()
+		t.SearchRequest.MetricType = queryInfo.GetMetricType()
+		t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
+
+		estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
+		if err != nil {
+			log.Warn("failed to estimate result size", zap.Error(err))
+			return err
+		}
+		if estimateSize >= requeryThreshold {
+			t.requery = true
+			plan.OutputFieldIds = nil
+		}
+
+		t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
+		if err != nil {
+			return err
+		}
+
+		log.Debug("proxy init search request",
+			zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
+			zap.Stringer("plan", plan)) // may be very large if large term passed.
+	}
+
+	// translate partition name to partition ids. Use regex-pattern to match partition name.
+	t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames)
+	if err != nil {
+		log.Warn("failed to get partition ids", zap.Error(err))
+		return err
+	}
+
+	if deadline, ok := t.TraceCtx().Deadline(); ok {
+		t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
+	}
+
+	t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
+
+	// Set username of this search request for feature like task scheduling.
+	if username, _ := GetCurUserFromContext(ctx); username != "" {
+		t.SearchRequest.Username = username
+	}
+
+	return nil
+}
@@ -14,10 +14,11 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
-	"github.com/milvus-io/milvus/pkg/util/conc"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
@@ -32,9 +33,11 @@ const (
 type hybridSearchTask struct {
 	Condition
 	ctx context.Context
+	*internalpb.HybridSearchRequest
+
 	result  *milvuspb.SearchResults
 	request *milvuspb.HybridSearchRequest
+	searchTasks []*searchTask
+
 	tr     *timerecord.TimeRecorder
 	schema *schemaInfo
@@ -45,12 +48,11 @@ type hybridSearchTask struct {
 	qc   types.QueryCoordClient
 	node types.ProxyComponent
 	lb   LBPolicy
-	queryChannelsTs map[string]Timestamp
-
-	collectionID UniqueID
 
+	resultBuf             *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
 	multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
 	reScorers             []reScorer
+	queryChannelsTs       map[string]Timestamp
 	rankParams            *rankParams
 }
 
@@ -63,7 +65,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
 	}
 
 	if len(t.request.Requests) > defaultMaxSearchRequest {
-		return errors.New("maximum of ann search requests is 1024")
+		return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))
 	}
 	for _, req := range t.request.GetRequests() {
 		nq, err := getNq(req)
@@ -78,12 +80,15 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
 		}
 	}
 
+	t.Base.MsgType = commonpb.MsgType_Search
+	t.Base.SourceID = paramtable.GetNodeID()
+
 	collectionName := t.request.CollectionName
 	collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
 	if err != nil {
 		return err
 	}
-	t.collectionID = collID
+	t.CollectionID = collID
 
 	log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
 	t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
@@ -113,6 +118,82 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
 		t.requery = true
 	}
 
+	collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
+	if err2 != nil {
+		log.Warn("Proxy::hybridSearchTask::PreExecute failed to GetCollectionInfo from cache",
+			zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
+		return err2
+	}
+	guaranteeTs := t.request.GetGuaranteeTimestamp()
+	var consistencyLevel commonpb.ConsistencyLevel
+	useDefaultConsistency := t.request.GetUseDefaultConsistency()
+	if useDefaultConsistency {
+		consistencyLevel = collectionInfo.consistencyLevel
+		guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
+	} else {
+		consistencyLevel = t.request.GetConsistencyLevel()
+		// Compatibility logic, parse guarantee timestamp
+		if consistencyLevel == 0 && guaranteeTs > 0 {
+			guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
+		} else {
+			// parse from guarantee timestamp and user input consistency level
+			guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
+		}
+	}
+
+	t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
+	if err != nil {
+		log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
+		return err
+	}
+
+	t.searchTasks = make([]*searchTask, len(t.request.GetRequests()))
+	for index := range t.request.Requests {
+		searchReq := t.request.Requests[index]
+
+		if len(searchReq.GetCollectionName()) == 0 {
+			searchReq.CollectionName = t.request.GetCollectionName()
+		} else if searchReq.GetCollectionName() != t.request.GetCollectionName() {
+			return errors.New(fmt.Sprintf("inconsistent collection name in hybrid search request, "+
+				"expect %s, actual %s", searchReq.GetCollectionName(), t.request.GetCollectionName()))
+		}
+
+		searchReq.PartitionNames = t.request.GetPartitionNames()
+		searchReq.ConsistencyLevel = consistencyLevel
+		searchReq.GuaranteeTimestamp = guaranteeTs
+		searchReq.UseDefaultConsistency = useDefaultConsistency
+		searchReq.OutputFields = nil
+
+		t.searchTasks[index] = &searchTask{
+			ctx:            ctx,
+			Condition:      NewTaskCondition(ctx),
+			collectionName: collectionName,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: commonpbutil.NewMsgBase(
+					commonpbutil.WithMsgType(commonpb.MsgType_Search),
+					commonpbutil.WithSourceID(paramtable.GetNodeID()),
+				),
+				ReqID:        paramtable.GetNodeID(),
+				DbID:         0, // todo
+				CollectionID: collID,
+			},
+			request: searchReq,
+			schema:  t.schema,
+			tr:      timerecord.NewTimeRecorder("hybrid search"),
+			qc:      t.qc,
+			node:    t.node,
+			lb:      t.lb,
+
+			partitionKeyMode: partitionKeyMode,
+			resultBuf:        typeutil.NewConcurrentSet[*internalpb.SearchResults](),
+		}
+		err := initSearchRequest(ctx, t.searchTasks[index])
+		if err != nil {
+			log.Debug("init hybrid search request failed", zap.Error(err))
+			return err
+		}
+	}
+
 	log.Debug("hybrid search preExecute done.",
 		zap.Uint64("guarantee_ts", t.request.GetGuaranteeTimestamp()),
 		zap.Bool("use_default_consistency", t.request.GetUseDefaultConsistency()),
@@ -121,56 +202,65 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
 	return nil
 }
 
+func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
+	for _, searchTask := range t.searchTasks {
+		t.HybridSearchRequest.Reqs = append(t.HybridSearchRequest.Reqs, searchTask.SearchRequest)
+	}
+	hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
+	hybridSearchReq.GetBase().TargetID = nodeID
+	req := &querypb.HybridSearchRequest{
+		Req:             hybridSearchReq,
+		DmlChannels:     []string{channel},
+		TotalChannelNum: int32(1),
+	}
+
+	log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
+		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
+		zap.Int64("nodeID", nodeID),
+		zap.String("channel", channel))
+
+	var result *querypb.HybridSearchResult
+	var err error
+
+	result, err = qn.HybridSearch(ctx, req)
+	if err != nil {
+		log.Warn("QueryNode hybrid search return error", zap.Error(err))
+		return err
+	}
+	if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
+		log.Warn("QueryNode is not shardLeader")
+		return errInvalidShardLeaders
+	}
+	if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
+		log.Warn("QueryNode hybrid search result error",
+			zap.String("reason", result.GetStatus().GetReason()))
+		return errors.Wrapf(merr.Error(result.GetStatus()), "fail to hybrid search on QueryNode %d", nodeID)
+	}
+	t.resultBuf.Insert(result)
+	t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
+
+	return nil
+}
+
 func (t *hybridSearchTask) Execute(ctx context.Context) error {
 	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-Execute")
 	defer sp.End()
 
-	log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
+	log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
 	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute hybrid search %d", t.ID()))
 	defer tr.CtxElapse(ctx, "done")
 
-	futures := make([]*conc.Future[*milvuspb.SearchResults], len(t.request.Requests))
-	for index := range t.request.Requests {
-		searchReq := t.request.Requests[index]
-		future := conc.Go(func() (*milvuspb.SearchResults, error) {
-			searchReq.TravelTimestamp = t.request.GetTravelTimestamp()
-			searchReq.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
-			searchReq.NotReturnAllMeta = t.request.GetNotReturnAllMeta()
-			searchReq.ConsistencyLevel = t.request.GetConsistencyLevel()
-			searchReq.UseDefaultConsistency = t.request.GetUseDefaultConsistency()
-			searchReq.OutputFields = nil
-
-			return t.node.Search(ctx, searchReq)
-		})
-		futures[index] = future
-	}
-
-	err := conc.AwaitAll(futures...)
-	if err != nil {
-		return err
-	}
-
-	t.reScorers, err = NewReScorer(t.request.GetRequests(), t.request.GetRankParams())
-	if err != nil {
-		log.Info("generate reScorer failed", zap.Any("rank params", t.request.GetRankParams()), zap.Error(err))
-		return err
-	}
-	t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
-	for i, future := range futures {
-		err = future.Err()
-		if err != nil {
-			log.Debug("QueryNode search result error", zap.Error(err))
-			return err
-		}
-		result := futures[i].Value()
-		if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
-			log.Debug("QueryNode search result error",
-				zap.String("reason", result.GetStatus().GetReason()))
-			return merr.Error(result.GetStatus())
-		}
-
-		t.reScorers[i].reScore(result)
-		t.multipleRecallResults.Insert(result)
-	}
+	t.resultBuf = typeutil.NewConcurrentSet[*querypb.HybridSearchResult]()
+	err := t.lb.Execute(ctx, CollectionWorkLoad{
+		db:             t.request.GetDbName(),
+		collectionID:   t.CollectionID,
+		collectionName: t.request.GetCollectionName(),
+		nq:             1,
+		exec:           t.hybridSearchShard,
+	})
+	if err != nil {
+		log.Warn("hybrid search execute failed", zap.Error(err))
+		return errors.Wrap(err, "failed to hybrid search")
+	}
 
 	log.Debug("hybrid search execute done.")
@@ -194,7 +284,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
 
 	limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
 	if err != nil {
-		return nil, errors.New(LimitKey + " not found in search_params")
+		return nil, errors.New(LimitKey + " not found in rank_params")
 	}
 	limit, err = strconv.ParseInt(limitStr, 0, 64)
 	if err != nil {
@@ -235,16 +325,59 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
 	}, nil
 }
 
+func (t *hybridSearchTask) collectHybridSearchResults(ctx context.Context) error {
+	select {
+	case <-t.TraceCtx().Done():
+		log.Ctx(ctx).Warn("hybrid search task wait to finish timeout!")
+		return fmt.Errorf("hybrid search task wait to finish timeout, msgID=%d", t.ID())
+	default:
+		log.Ctx(ctx).Debug("all hybrid searches are finished or canceled")
+		t.resultBuf.Range(func(res *querypb.HybridSearchResult) bool {
+			for index, searchResult := range res.GetResults() {
+				t.searchTasks[index].resultBuf.Insert(searchResult)
+			}
+			log.Ctx(ctx).Debug("proxy receives one hybrid search result",
+				zap.Int64("sourceID", res.GetBase().GetSourceID()))
+			return true
+		})
+
+		t.multipleRecallResults = typeutil.NewConcurrentSet[*milvuspb.SearchResults]()
+		for i, searchTask := range t.searchTasks {
+			err := searchTask.PostExecute(ctx)
+			if err != nil {
+				return err
+			}
+			t.reScorers[i].reScore(searchTask.result)
+			t.multipleRecallResults.Insert(searchTask.result)
+		}
+
+		return nil
+	}
+}
+
 func (t *hybridSearchTask) PostExecute(ctx context.Context) error {
 	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-HybridSearch-PostExecute")
 	defer sp.End()
 
-	log := log.Ctx(ctx).With(zap.Int64("collID", t.collectionID), zap.String("collName", t.request.GetCollectionName()))
+	log := log.Ctx(ctx).With(zap.Int64("collID", t.CollectionID), zap.String("collName", t.request.GetCollectionName()))
 	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy postExecute hybrid search %d", t.ID()))
 	defer func() {
 		tr.CtxElapse(ctx, "done")
 	}()
 
+	err := t.collectHybridSearchResults(ctx)
+	if err != nil {
+		log.Warn("failed to collect hybrid search results", zap.Error(err))
+		return err
+	}
+
+	t.queryChannelsTs = make(map[string]uint64)
+	for _, r := range t.resultBuf.Collect() {
+		for ch, ts := range r.GetChannelsMvcc() {
+			t.queryChannelsTs[ch] = ts
+		}
+	}
+
 	primaryFieldSchema, err := t.schema.GetPkField()
 	if err != nil {
 		log.Warn("failed to get primary field schema", zap.Error(err))
@@ -304,9 +437,8 @@ func (t *hybridSearchTask) Requery() error {
 		},
 	}
 
-	// TODO:Xige-16 refine the mvcc functionality of hybrid search
 	// TODO:silverxia move partitionIDs to hybrid search level
-	return doRequery(t.ctx, t.collectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
+	return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
 }
 
 func rankSearchResultData(ctx context.Context,
@@ -436,11 +568,11 @@ func (t *hybridSearchTask) TraceCtx() context.Context {
 }
 
 func (t *hybridSearchTask) ID() UniqueID {
-	return t.request.Base.MsgID
+	return t.Base.MsgID
 }
 
 func (t *hybridSearchTask) SetID(uid UniqueID) {
-	t.request.Base.MsgID = uid
+	t.Base.MsgID = uid
 }
 
 func (t *hybridSearchTask) Name() string {
@@ -448,24 +580,24 @@ func (t *hybridSearchTask) Name() string {
 }
 
 func (t *hybridSearchTask) Type() commonpb.MsgType {
-	return t.request.Base.MsgType
+	return t.Base.MsgType
 }
 
 func (t *hybridSearchTask) BeginTs() Timestamp {
-	return t.request.Base.Timestamp
+	return t.Base.Timestamp
 }
 
 func (t *hybridSearchTask) EndTs() Timestamp {
-	return t.request.Base.Timestamp
+	return t.Base.Timestamp
 }
 
 func (t *hybridSearchTask) SetTs(ts Timestamp) {
-	t.request.Base.Timestamp = ts
+	t.Base.Timestamp = ts
 }
 
 func (t *hybridSearchTask) OnEnqueue() error {
-	t.request.Base = commonpbutil.NewMsgBase()
-	t.request.Base.MsgType = commonpb.MsgType_Search
-	t.request.Base.SourceID = paramtable.GetNodeID()
+	t.Base = commonpbutil.NewMsgBase()
+	t.Base.MsgType = commonpb.MsgType_Search
+	t.Base.SourceID = paramtable.GetNodeID()
 	return nil
 }
@@ -20,6 +20,7 @@ import (
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/pkg/common"
+	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
@@ -69,6 +70,7 @@ func TestHybridSearchTask_PreExecute(t *testing.T) {
 	task := &hybridSearchTask{
 		ctx:       ctx,
 		Condition: NewTaskCondition(ctx),
+		HybridSearchRequest: &internalpb.HybridSearchRequest{},
 		request: &milvuspb.HybridSearchRequest{
 			CollectionName: collName,
 			Requests:       reqs,
@@ -225,6 +227,7 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) {
 		result: &milvuspb.SearchResults{
 			Status: merr.Success(),
 		},
+		HybridSearchRequest: &internalpb.HybridSearchRequest{},
 		request: &milvuspb.HybridSearchRequest{
 			CollectionName: collectionName,
 			Requests: []*milvuspb.SearchRequest{
@@ -266,12 +269,12 @@ func TestHybridSearchTask_ErrExecute(t *testing.T) {
 	task.ctx = ctx
 	assert.NoError(t, task.PreExecute(ctx))
 
-	qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
+	qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
 	assert.Error(t, task.Execute(ctx))
 
 	qn.ExpectedCalls = nil
 	qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
-	qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
+	qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
 		Status: &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_UnexpectedError,
 		},
@@ -291,6 +294,10 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
 	mgr := NewMockShardClientManager(t)
 	mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
 	mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
+	qn.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&querypb.HybridSearchResult{
+		Base:   commonpbutil.NewMsgBase(),
+		Status: merr.Success(),
+	}, nil)
 
 	t.Run("Test empty result", func(t *testing.T) {
 		ctx, cancel := context.WithCancel(context.Background())
@@ -313,6 +320,9 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
 			qc:     nil,
 			tr:     timerecord.NewTimeRecorder("search"),
 			schema: schema,
+			HybridSearchRequest: &internalpb.HybridSearchRequest{
+				Base: commonpbutil.NewMsgBase(),
+			},
 			request: &milvuspb.HybridSearchRequest{
 				Base: &commonpb.MsgBase{
 					MsgType: commonpb.MsgType_Search,
@@ -320,6 +330,7 @@ func TestHybridSearchTask_PostExecute(t *testing.T) {
 				CollectionName: collectionName,
 				RankParams:     rankParams,
 			},
+			resultBuf:             typeutil.NewConcurrentSet[*querypb.HybridSearchResult](),
 			multipleRecallResults: typeutil.NewConcurrentSet[*milvuspb.SearchResults](),
 		}
 
@@ -30,7 +30,6 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/metric"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/timerecord"
-	"github.com/milvus-io/milvus/pkg/util/tsoutil"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
@@ -57,6 +56,7 @@ type searchTask struct {
 	collectionName string
 	schema         *schemaInfo
 	requery        bool
+	partitionKeyMode bool
 
 	userOutputFields []string
 
@@ -250,22 +250,21 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 		return err
 	}
 
-	log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
-
 	t.SearchRequest.DbID = 0 // todo
 	t.SearchRequest.CollectionID = collID
+	log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))
 	t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
 	if err != nil {
 		log.Warn("get collection schema failed", zap.Error(err))
 		return err
 	}
 
-	partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
+	t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
 	if err != nil {
 		log.Warn("is partition key mode failed", zap.Error(err))
 		return err
 	}
-	if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
+	if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
 		return errors.New("not support manually specifying the partition names if partition key mode is used")
 	}
 
@@ -277,123 +276,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 	log.Debug("translate output fields",
 		zap.Strings("output fields", t.request.GetOutputFields()))
 
-	// fetch search_growing from search param
-	var ignoreGrowing bool
-	for i, kv := range t.request.GetSearchParams() {
-		if kv.GetKey() == IgnoreGrowingKey {
-			ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
-			if err != nil {
-				return errors.New("parse search growing failed")
-			}
-			t.request.SearchParams = append(t.request.GetSearchParams()[:i], t.request.GetSearchParams()[i+1:]...)
-			break
-		}
-	}
-	t.SearchRequest.IgnoreGrowing = ignoreGrowing
-
-	// Manually update nq if not set.
-	nq, err := getNq(t.request)
-	if err != nil {
-		log.Warn("failed to get nq", zap.Error(err))
-		return err
-	}
-	// Check if nq is valid:
-	// https://milvus.io/docs/limitations.md
-	if err := validateNQLimit(nq); err != nil {
-		return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
-	}
-	t.SearchRequest.Nq = nq
-	log = log.With(zap.Int64("nq", nq))
-
-	outputFieldIDs, err := getOutputFieldIDs(t.schema, t.request.GetOutputFields())
-	if err != nil {
-		log.Warn("fail to get output field ids", zap.Error(err))
-		return err
-	}
-	t.SearchRequest.OutputFieldsId = outputFieldIDs
-
-	partitionNames := t.request.GetPartitionNames()
-	if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
-		annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
-		if err != nil || len(annsField) == 0 {
-			vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
-			if len(vecFields) == 0 {
-				return errors.New(AnnsFieldKey + " not found in schema")
-			}
-
-			if enableMultipleVectorFields && len(vecFields) > 1 {
-				return errors.New("multiple anns_fields exist, please specify a anns_field in search_params")
-			}
-
-			annsField = vecFields[0].Name
-		}
-		queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams(), t.schema.CollectionSchema)
-		if err != nil {
-			return err
-		}
-		if queryInfo.GroupByFieldId != 0 {
-			t.SearchRequest.IgnoreGrowing = true
-			// for group by operation, currently, we ignore growing segments
-		}
-		t.offset = offset
-
-		plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
-		if err != nil {
-			log.Warn("failed to create query plan", zap.Error(err),
-				zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
-				zap.String("anns field", annsField), zap.Any("query info", queryInfo))
-			return merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err)
-		}
-		log.Debug("create query plan",
-			zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
-			zap.String("anns field", annsField), zap.Any("query info", queryInfo))
-
-		if partitionKeyMode {
-			expr, err := ParseExprFromPlan(plan)
-			if err != nil {
-				log.Warn("failed to parse expr", zap.Error(err))
-				return err
-			}
-			partitionKeys := ParsePartitionKeys(expr)
-			hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), collectionName, partitionKeys)
-			if err != nil {
-				log.Warn("failed to assign partition keys", zap.Error(err))
-				return err
-			}
-
-			partitionNames = append(partitionNames, hashedPartitionNames...)
-		}
-
-		plan.OutputFieldIds = outputFieldIDs
-
-		t.SearchRequest.Topk = queryInfo.GetTopk()
-		t.SearchRequest.MetricType = queryInfo.GetMetricType()
-		t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
-
-		estimateSize, err := t.estimateResultSize(nq, t.SearchRequest.Topk)
-		if err != nil {
-			log.Warn("failed to estimate result size", zap.Error(err))
-			return err
-		}
-		if estimateSize >= requeryThreshold {
-			t.requery = true
-			plan.OutputFieldIds = nil
-		}
-
-		t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
-		if err != nil {
-			return err
-		}
-
-		log.Debug("Proxy::searchTask::PreExecute",
-			zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
-			zap.Stringer("plan", plan)) // may be very large if large term passed.
-	}
-
-	// translate partition name to partition ids. Use regex-pattern to match partition name.
-	t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, partitionNames)
-	if err != nil {
-		log.Warn("failed to get partition ids", zap.Error(err))
+	err = initSearchRequest(ctx, t)
+	if err != nil {
+		log.Debug("init search request failed", zap.Error(err))
 		return err
 	}
 
@@ -421,17 +306,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 	}
 	t.SearchRequest.GuaranteeTimestamp = guaranteeTs
 
-	if deadline, ok := t.TraceCtx().Deadline(); ok {
-		t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
-	}
-
-	t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
-
-	// Set username of this search request for feature like task scheduling.
-	if username, _ := GetCurUserFromContext(ctx); username != "" {
-		t.SearchRequest.Username = username
-	}
-
 	log.Debug("search PreExecute done.",
 		zap.Uint64("guarantee_ts", guaranteeTs),
 		zap.Bool("use_default_consistency", useDefaultConsistency),
@@ -469,6 +469,61 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) RunAndReturn(run func(con
 	return _c
 }
 
+// HybridSearch provides a mock function with given fields: _a0, _a1
+func (_m *MockQueryNodeServer) HybridSearch(_a0 context.Context, _a1 *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
+	ret := _m.Called(_a0, _a1)
+
+	var r0 *querypb.HybridSearchResult
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
+		return rf(_a0, _a1)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
+		r0 = rf(_a0, _a1)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*querypb.HybridSearchResult)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
+		r1 = rf(_a0, _a1)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// MockQueryNodeServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
+type MockQueryNodeServer_HybridSearch_Call struct {
+	*mock.Call
+}
+
+// HybridSearch is a helper method to define mock.On call
+//   - _a0 context.Context
+//   - _a1 *querypb.HybridSearchRequest
+func (_e *MockQueryNodeServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_HybridSearch_Call {
+	return &MockQueryNodeServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)}
+}
+
+func (_c *MockQueryNodeServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *querypb.HybridSearchRequest)) *MockQueryNodeServer_HybridSearch_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
+	})
+	return _c
+}
+
+func (_c *MockQueryNodeServer_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockQueryNodeServer_HybridSearch_Call {
+	_c.Call.Return(_a0, _a1)
+	return _c
+}
+
+func (_c *MockQueryNodeServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockQueryNodeServer_HybridSearch_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // LoadPartitions provides a mock function with given fields: _a0, _a1
 func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
 	ret := _m.Called(_a0, _a1)
@ -62,6 +62,7 @@ type ShardDelegator interface {
	GetSegmentInfo(readable bool) (sealed []SnapshotItem, growing []SegmentEntry)
	SyncDistribution(ctx context.Context, entries ...SegmentEntry)
	Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
	HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)
	Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
	QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error
	GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
@ -184,6 +185,44 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
	return nodeReq
}

// Search performs search operation on shard.
func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) {
	log := sd.getLogger(ctx)
	if req.Req.IgnoreGrowing {
		growing = []SegmentEntry{}
	}

	sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
	log.Debug("search segments...",
		zap.Int("sealedNum", sealedNum),
		zap.Int("growingNum", len(growing)),
	)

	req, err := optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
	if err != nil {
		log.Warn("failed to optimize search params", zap.Error(err))
		return nil, err
	}

	tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
	if err != nil {
		log.Warn("Search organizeSubTask failed", zap.Error(err))
		return nil, err
	}

	results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
		return worker.SearchSegments(ctx, req)
	}, "Search", log)
	if err != nil {
		log.Warn("Delegator search failed", zap.Error(err))
		return nil, err
	}

	log.Debug("Delegator search done")

	return results, nil
}

// Search performs search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
	log := sd.getLogger(ctx)
@ -229,39 +268,113 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
		return funcutil.SliceContain(existPartitions, segment.PartitionID)
	})

-	if req.Req.IgnoreGrowing {
-		growing = []SegmentEntry{}
-	}
-
-	sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
-	log.Debug("search segments...",
-		zap.Int("sealedNum", sealedNum),
-		zap.Int("growingNum", len(growing)),
-	)
-
-	req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
-	if err != nil {
-		log.Warn("failed to optimize search params", zap.Error(err))
-		return nil, err
-	}
-
-	tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
-	if err != nil {
-		log.Warn("Search organizeSubTask failed", zap.Error(err))
-		return nil, err
-	}
-
-	results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
-		return worker.SearchSegments(ctx, req)
-	}, "Search", log)
-	if err != nil {
-		log.Warn("Delegator search failed", zap.Error(err))
-		return nil, err
-	}
-
-	log.Debug("Delegator search done")
-
-	return results, nil
	return sd.search(ctx, req, sealed, growing)
}

// HybridSearch performs hybrid search operation on shard.
func (sd *shardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	log := sd.getLogger(ctx)
	if err := sd.lifetime.Add(lifetime.IsWorking); err != nil {
		return nil, err
	}
	defer sd.lifetime.Done()

	if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
		log.Warn("delegator received hybrid search request that does not belong to it",
			zap.Strings("reqChannels", req.GetDmlChannels()),
		)
		return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
	}

	partitions := req.GetReq().GetPartitionIDs()
	if !sd.collection.ExistPartition(partitions...) {
		return nil, merr.WrapErrPartitionNotLoaded(partitions)
	}

	// wait tsafe
	waitTr := timerecord.NewTimeRecorder("wait tSafe")
	tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
	if err != nil {
		log.Warn("delegator hybrid search failed to wait tsafe", zap.Error(err))
		return nil, err
	}
	if req.GetReq().GetMvccTimestamp() == 0 {
		req.Req.MvccTimestamp = tSafe
	}
	metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
		fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel).
		Observe(float64(waitTr.ElapseSpan().Milliseconds()))

	sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
	if err != nil {
		log.Warn("delegator failed to hybrid search, current distribution is not serviceable")
		return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
	}
	defer sd.distribution.Unpin(version)
	existPartitions := sd.collection.GetPartitions()
	growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
		return funcutil.SliceContain(existPartitions, segment.PartitionID)
	})

	futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetReqs()))
	for index := range req.GetReq().GetReqs() {
		request := req.GetReq().Reqs[index]
		future := conc.Go(func() (*internalpb.SearchResults, error) {
			searchReq := &querypb.SearchRequest{
				Req:             request,
				DmlChannels:     req.GetDmlChannels(),
				TotalChannelNum: req.GetTotalChannelNum(),
				FromShardLeader: true,
			}
			searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()
			searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()
			if searchReq.GetReq().GetMvccTimestamp() == 0 {
				searchReq.GetReq().MvccTimestamp = tSafe
			}

			results, err := sd.search(ctx, searchReq, sealed, growing)
			if err != nil {
				return nil, err
			}

			return segments.ReduceSearchResults(ctx,
				results,
				searchReq.Req.GetNq(),
				searchReq.Req.GetTopk(),
				searchReq.Req.GetMetricType())
		})
		futures[index] = future
	}

	err = conc.AwaitAll(futures...)
	if err != nil {
		return nil, err
	}

	ret := &querypb.HybridSearchResult{
		Status:  merr.Success(),
		Results: make([]*internalpb.SearchResults, len(futures)),
	}

	channelsMvcc := make(map[string]uint64)
	for i, future := range futures {
		result := future.Value()
		if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
			log.Debug("delegator hybrid search failed",
				zap.String("reason", result.GetStatus().GetReason()))
			return nil, merr.Error(result.GetStatus())
		}

		ret.Results[i] = result
		for ch, ts := range result.GetChannelsMvcc() {
			channelsMvcc[ch] = ts
		}
	}
	ret.ChannelsMvcc = channelsMvcc

	log.Debug("Delegator hybrid search done")

	return ret, nil
}

func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error {
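The mvcc handling above hinges on one rule: the delegator waits for tSafe once per hybrid search and stamps that single timestamp into every sub-request whose MvccTimestamp is still zero, so all sub-searches of the same hybrid search read a consistent snapshot. Below is a minimal, self-contained sketch of that rule; subRequest and stampMvcc are illustrative simplifications, not the querypb types or anything from this commit.

package main

import "fmt"

type subRequest struct {
	GuaranteeTimestamp uint64
	MvccTimestamp      uint64
}

// stampMvcc mirrors the "if MvccTimestamp == 0 { MvccTimestamp = tSafe }" branch above:
// sub-requests that already carry an mvcc timestamp are left untouched.
func stampMvcc(reqs []*subRequest, tSafe uint64) {
	for _, r := range reqs {
		if r.MvccTimestamp == 0 {
			r.MvccTimestamp = tSafe
		}
	}
}

func main() {
	reqs := []*subRequest{
		{GuaranteeTimestamp: 100},                    // gets stamped
		{GuaranteeTimestamp: 100, MvccTimestamp: 90}, // keeps its own timestamp
	}
	stampMvcc(reqs, 120) // tSafe obtained once for the whole hybrid search
	fmt.Println(reqs[0].MvccTimestamp, reqs[1].MvccTimestamp) // 120 90
}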
@ -469,6 +469,251 @@ func (s *DelegatorSuite) TestSearch() {
	})
}

func (s *DelegatorSuite) TestHybridSearch() {
	s.delegator.Start()
	paramtable.SetNodeID(1)
	s.initSegments()
	s.Run("normal", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()
		workers := make(map[int64]*cluster.MockWorker)
		worker1 := &cluster.MockWorker{}
		worker2 := &cluster.MockWorker{}

		workers[1] = worker1
		workers[2] = worker2

		worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(1, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				if req.GetScope() == querypb.DataScope_Streaming {
					s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
					s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
				}
				if req.GetScope() == querypb.DataScope_Historical {
					s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
					s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
				}
			}).Return(&internalpb.SearchResults{}, nil)
		worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(2, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				s.Equal(querypb.DataScope_Historical, req.GetScope())
				s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
				s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
			}).Return(&internalpb.SearchResults{}, nil)

		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
			return workers[nodeID]
		}, nil)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		results, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				Reqs: []*internalpb.SearchRequest{
					{Base: commonpbutil.NewMsgBase()},
					{Base: commonpbutil.NewMsgBase()},
				},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.NoError(err)
		s.Equal(2, len(results.Results))
	})

	s.Run("partition_not_loaded", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				// partition -1 is not loaded, so this returns an error
				PartitionIDs: []int64{-1},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.True(errors.Is(err, merr.ErrPartitionNotLoaded))
	})

	s.Run("worker_return_error", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()
		workers := make(map[int64]*cluster.MockWorker)
		worker1 := &cluster.MockWorker{}
		worker2 := &cluster.MockWorker{}

		workers[1] = worker1
		workers[2] = worker2

		worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
		worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(2, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				s.Equal(querypb.DataScope_Historical, req.GetScope())
				s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
				s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
			}).Return(&internalpb.SearchResults{}, nil)

		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
			return workers[nodeID]
		}, nil)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				Reqs: []*internalpb.SearchRequest{
					{
						Base: commonpbutil.NewMsgBase(),
					},
				},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("worker_return_failure_code", func() {
		defer func() {
			s.workerManager.ExpectedCalls = nil
		}()
		workers := make(map[int64]*cluster.MockWorker)
		worker1 := &cluster.MockWorker{}
		worker2 := &cluster.MockWorker{}

		workers[1] = worker1
		workers[2] = worker2

		worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    "mocked error",
			},
		}, nil)
		worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
			Run(func(_ context.Context, req *querypb.SearchRequest) {
				s.EqualValues(2, req.Req.GetBase().GetTargetID())
				s.True(req.GetFromShardLeader())
				s.Equal(querypb.DataScope_Historical, req.GetScope())
				s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
				s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
			}).Return(&internalpb.SearchResults{}, nil)

		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
			return workers[nodeID]
		}, nil)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
				Reqs: []*internalpb.SearchRequest{
					{
						Base: commonpbutil.NewMsgBase(),
					},
				},
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("wrong_channel", func() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
			},
			DmlChannels: []string{"non_exist_channel"},
		})

		s.Error(err)
	})

	s.Run("wait_tsafe_timeout", func() {
		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base:               commonpbutil.NewMsgBase(),
				GuaranteeTimestamp: 10100,
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("tsafe_behind_max_lag", func() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base:               commonpbutil.NewMsgBase(),
				GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 10001,
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("distribution_not_serviceable", func() {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		sd, ok := s.delegator.(*shardDelegator)
		s.Require().True(ok)
		sd.distribution.AddOfflines(1001)

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})

	s.Run("cluster_not_serviceable", func() {
		s.delegator.Close()

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		_, err := s.delegator.HybridSearch(ctx, &querypb.HybridSearchRequest{
			Req: &internalpb.HybridSearchRequest{
				Base: commonpbutil.NewMsgBase(),
			},
			DmlChannels: []string{s.vchannelName},
		})

		s.Error(err)
	})
}

func (s *DelegatorSuite) TestQuery() {
	s.delegator.Start()
	paramtable.SetNodeID(1)
@ -253,6 +253,61 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int6
	return _c
}

// HybridSearch provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	ret := _m.Called(ctx, req)

	var r0 *querypb.HybridSearchResult
	var r1 error
	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)); ok {
		return rf(ctx, req)
	}
	if rf, ok := ret.Get(0).(func(context.Context, *querypb.HybridSearchRequest) *querypb.HybridSearchResult); ok {
		r0 = rf(ctx, req)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*querypb.HybridSearchResult)
		}
	}

	if rf, ok := ret.Get(1).(func(context.Context, *querypb.HybridSearchRequest) error); ok {
		r1 = rf(ctx, req)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// MockShardDelegator_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch'
type MockShardDelegator_HybridSearch_Call struct {
	*mock.Call
}

// HybridSearch is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.HybridSearchRequest
func (_e *MockShardDelegator_Expecter) HybridSearch(ctx interface{}, req interface{}) *MockShardDelegator_HybridSearch_Call {
	return &MockShardDelegator_HybridSearch_Call{Call: _e.mock.On("HybridSearch", ctx, req)}
}

func (_c *MockShardDelegator_HybridSearch_Call) Run(run func(ctx context.Context, req *querypb.HybridSearchRequest)) *MockShardDelegator_HybridSearch_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(context.Context), args[1].(*querypb.HybridSearchRequest))
	})
	return _c
}

func (_c *MockShardDelegator_HybridSearch_Call) Return(_a0 *querypb.HybridSearchResult, _a1 error) *MockShardDelegator_HybridSearch_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}

func (_c *MockShardDelegator_HybridSearch_Call) RunAndReturn(run func(context.Context, *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error)) *MockShardDelegator_HybridSearch_Call {
	_c.Call.Return(run)
	return _c
}

// LoadGrowing provides a mock function with given fields: ctx, infos, version
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
	ret := _m.Called(ctx, infos, version)
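The generated expecter above is intended to be driven the same way the worker mocks are driven in the delegator tests. The fragment below is a hypothetical illustration of that wiring, not part of the commit; the package name and import paths are assumptions about this tree's layout and may need adjusting.

package delegator

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

func TestMockShardDelegatorHybridSearch(t *testing.T) {
	// zero-value mock; EXPECT() comes from the generated expecter above
	delegator := &MockShardDelegator{}
	delegator.EXPECT().
		HybridSearch(mock.Anything, mock.AnythingOfType("*querypb.HybridSearchRequest")).
		Return(&querypb.HybridSearchResult{Status: merr.Success()}, nil)

	ret, err := delegator.HybridSearch(context.Background(), &querypb.HybridSearchRequest{})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !merr.Ok(ret.GetStatus()) {
		t.Fatalf("unexpected status: %v", ret.GetStatus())
	}
}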
@ -401,6 +401,63 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
	return resp, nil
}

func (node *QueryNode) hybridSearchChannel(ctx context.Context, req *querypb.HybridSearchRequest, channel string) (*querypb.HybridSearchResult, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
		zap.Int64("collectionID", req.Req.GetCollectionID()),
		zap.String("channel", channel),
	)
	traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()

	if err := node.lifetime.Add(merr.IsHealthy); err != nil {
		return nil, err
	}
	defer node.lifetime.Done()

	var err error
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.TotalLabel, metrics.Leader).Inc()
	defer func() {
		if err != nil {
			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.FailLabel, metrics.Leader).Inc()
		}
	}()

	log.Debug("start to search channel")
	searchCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// From Proxy
	tr := timerecord.NewTimeRecorder("hybridSearchDelegator")
	// get delegator
	sd, ok := node.delegators.Get(channel)
	if !ok {
		err := merr.WrapErrChannelNotFound(channel)
		log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
		return nil, err
	}
	// do hybrid search
	result, err := sd.HybridSearch(searchCtx, req)
	if err != nil {
		log.Warn("failed to hybrid search on delegator", zap.Error(err))
		return nil, err
	}

	tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , traceID = %s, vChannel = %s",
		traceID,
		channel,
	))

	// update metric to prometheus
	latency := tr.ElapseSpan()
	metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
	metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.SuccessLabel, metrics.Leader).Inc()
	for _, searchReq := range req.GetReq().GetReqs() {
		metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetNq()))
		metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(searchReq.GetTopk()))
	}
	return result, nil
}

func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("collectionID", req.Req.GetCollectionID()),
@ -821,6 +821,114 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
	return result, nil
}

// HybridSearch performs replica hybrid search tasks.
func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSearchRequest) (*querypb.HybridSearchResult, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("collectionID", req.GetReq().GetCollectionID()),
		zap.Strings("channels", req.GetDmlChannels()))

	log.Debug("Received HybridSearchRequest",
		zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
		zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()))

	tr := timerecord.NewTimeRecorderWithTrace(ctx, "HybridSearchRequest")

	if err := node.lifetime.Add(merr.IsHealthy); err != nil {
		return &querypb.HybridSearchResult{
			Base: &commonpb.MsgBase{
				SourceID: paramtable.GetNodeID(),
			},
			Status: merr.Status(err),
		}, nil
	}
	defer node.lifetime.Done()

	err := merr.CheckTargetID(req.GetReq().GetBase())
	if err != nil {
		log.Warn("target ID check failed", zap.Error(err))
		return &querypb.HybridSearchResult{
			Base: &commonpb.MsgBase{
				SourceID: paramtable.GetNodeID(),
			},
			Status: merr.Status(err),
		}, nil
	}

	resp := &querypb.HybridSearchResult{
		Base: &commonpb.MsgBase{
			SourceID: paramtable.GetNodeID(),
		},
		Status: merr.Success(),
	}
	collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
	if collection == nil {
		resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID()))
		return resp, nil
	}

	MultipleResults := make([]*querypb.HybridSearchResult, len(req.GetDmlChannels()))
	runningGp, runningCtx := errgroup.WithContext(ctx)

	for i, ch := range req.GetDmlChannels() {
		ch := ch
		req := &querypb.HybridSearchRequest{
			Req:             req.Req,
			DmlChannels:     []string{ch},
			TotalChannelNum: 1,
		}

		i := i
		runningGp.Go(func() error {
			ret, err := node.hybridSearchChannel(runningCtx, req, ch)
			if err != nil {
				return err
			}
			if err := merr.Error(ret.GetStatus()); err != nil {
				return err
			}
			MultipleResults[i] = ret
			return nil
		})
	}
	if err := runningGp.Wait(); err != nil {
		resp.Status = merr.Status(err)
		return resp, nil
	}

	tr.RecordSpan()
	channelsMvcc := make(map[string]uint64)
	for i, searchReq := range req.GetReq().GetReqs() {
		toReduceResults := make([]*internalpb.SearchResults, len(MultipleResults))
		for index, hs := range MultipleResults {
			toReduceResults[index] = hs.Results[i]
		}
		result, err := segments.ReduceSearchResults(ctx, toReduceResults, searchReq.GetNq(), searchReq.GetTopk(), searchReq.GetMetricType())
		if err != nil {
			log.Warn("failed to reduce search results", zap.Error(err))
			resp.Status = merr.Status(err)
			return resp, nil
		}
		for ch, ts := range result.GetChannelsMvcc() {
			channelsMvcc[ch] = ts
		}
		resp.Results = append(resp.Results, result)
	}
	resp.ChannelsMvcc = channelsMvcc

	reduceLatency := tr.RecordSpan()
	metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.HybridSearchLabel, metrics.ReduceShards).
		Observe(float64(reduceLatency.Milliseconds()))

	collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
	metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.HybridSearchLabel).
		Add(float64(proto.Size(req)))

	if resp.GetCostAggregation() != nil {
		resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
	}
	return resp, nil
}

// only used for delegator query segments from worker
func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
	resp := &internalpb.RetrieveResults{
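HybridSearch above fans out one goroutine per DML channel, collects one per-channel result that holds an entry for each sub-search, and then regroups the results by sub-search index before reducing. The sketch below is a self-contained illustration of that fan-out/transpose/reduce shape; channelResult, searchChannel, and mergeOne are stand-ins for querypb.HybridSearchResult, node.hybridSearchChannel, and segments.ReduceSearchResults, not the actual types.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

type subResult struct{ hits []int64 }

// channelResult holds one sub-result per sub-search, produced for a single channel.
type channelResult struct{ perSub []subResult }

// searchChannel is a placeholder for the per-channel hybrid search call.
func searchChannel(_ context.Context, _ string, numSub int) (channelResult, error) {
	return channelResult{perSub: make([]subResult, numSub)}, nil
}

// mergeOne is a placeholder for reducing the same sub-search across channels.
func mergeOne(parts []subResult) subResult {
	merged := subResult{}
	for _, p := range parts {
		merged.hits = append(merged.hits, p.hits...)
	}
	return merged
}

func main() {
	channels := []string{"ch-0", "ch-1"}
	numSub := 2

	// fan out: one goroutine per channel, results collected by index
	perChannel := make([]channelResult, len(channels))
	g, ctx := errgroup.WithContext(context.Background())
	for i, ch := range channels {
		i, ch := i, ch
		g.Go(func() error {
			r, err := searchChannel(ctx, ch, numSub)
			if err != nil {
				return err
			}
			perChannel[i] = r
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		panic(err)
	}

	// transpose: gather the i-th sub-result from every channel, then merge
	final := make([]subResult, numSub)
	for i := 0; i < numSub; i++ {
		parts := make([]subResult, len(perChannel))
		for c := range perChannel {
			parts[c] = perChannel[c].perSub[i]
		}
		final[i] = mergeOne(parts)
	}
	fmt.Println("merged sub-results:", len(final))
}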
@ -1323,6 +1323,47 @@ func (suite *ServiceSuite) TestSearchSegments_Failed() {
	suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode())
}

func (suite *ServiceSuite) TestHybridSearch_Concurrent() {
	ctx := context.Background()
	// pre
	suite.TestWatchDmChannelsInt64()
	suite.TestLoadSegments_Int64()

	concurrency := 16
	futures := make([]*conc.Future[*querypb.HybridSearchResult], 0, concurrency)
	for i := 0; i < concurrency; i++ {
		future := conc.Go(func() (*querypb.HybridSearchResult, error) {
			creq1, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
			suite.NoError(err)
			creq2, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType)
			suite.NoError(err)
			req := &querypb.HybridSearchRequest{
				Req: &internalpb.HybridSearchRequest{
					Base: &commonpb.MsgBase{
						MsgID:    rand.Int63(),
						TargetID: suite.node.session.ServerID,
					},
					CollectionID:  suite.collectionID,
					PartitionIDs:  suite.partitionIDs,
					MvccTimestamp: typeutil.MaxTimestamp,
					Reqs:          []*internalpb.SearchRequest{creq1, creq2},
				},
				DmlChannels: []string{suite.vchannel},
			}

			return suite.node.HybridSearch(ctx, req)
		})
		futures = append(futures, future)
	}

	err := conc.AwaitAll(futures...)
	suite.NoError(err)

	for i := range futures {
		suite.True(merr.Ok(futures[i].Value().GetStatus()))
	}
}

func (suite *ServiceSuite) TestSearchSegments_Normal() {
	ctx := context.Background()
	// pre
@ -82,6 +82,10 @@ func (m *GrpcQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequ
	return &internalpb.SearchResults{}, m.Err
}

func (m *GrpcQueryNodeClient) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
	return &querypb.HybridSearchResult{}, m.Err
}

func (m *GrpcQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
	return &internalpb.SearchResults{}, m.Err
}
@ -93,6 +93,10 @@ func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest
	return qn.QueryNode.Search(ctx, in)
}

func (qn *qnServerWrapper) HybridSearch(ctx context.Context, in *querypb.HybridSearchRequest, opts ...grpc.CallOption) (*querypb.HybridSearchResult, error) {
	return qn.QueryNode.HybridSearch(ctx, in)
}

func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) {
	return qn.QueryNode.SearchSegments(ctx, in)
}
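qnServerWrapper above adapts an in-process QueryNode so it satisfies the gRPC client signature by accepting and ignoring the grpc.CallOption arguments. The sketch below is a self-contained illustration of that adapter pattern; the Searcher and SearcherClient interfaces and all names in it are made up for the example and do not come from the commit.

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc"
)

// Searcher is the server-side interface: no grpc.CallOption in the signature.
type Searcher interface {
	HybridSearch(ctx context.Context, req string) (string, error)
}

// SearcherClient is the client-side interface: callers may pass call options.
type SearcherClient interface {
	HybridSearch(ctx context.Context, req string, opts ...grpc.CallOption) (string, error)
}

// searcherWrapper adapts a Searcher into a SearcherClient by dropping the options.
type searcherWrapper struct {
	Searcher
}

func (w *searcherWrapper) HybridSearch(ctx context.Context, req string, _ ...grpc.CallOption) (string, error) {
	return w.Searcher.HybridSearch(ctx, req)
}

// localSearcher is a trivial in-process implementation used for the demo.
type localSearcher struct{}

func (localSearcher) HybridSearch(_ context.Context, req string) (string, error) {
	return "result for " + req, nil
}

func main() {
	var c SearcherClient = &searcherWrapper{Searcher: localSearcher{}}
	out, _ := c.HybridSearch(context.Background(), "q1", grpc.WaitForReady(true))
	fmt.Println(out)
}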