diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index d3aacaa948..5fddba3161 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -212,9 +212,9 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ var results []*segcorepb.RetrieveResults if req.GetScope() == querypb.DataScope_Historical { - results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager) + results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs()) } else { - results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager) + results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs()) } if err != nil { return nil, err diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index 7c2932e251..cb14e05ef1 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -23,7 +23,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" - "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -32,7 +31,7 @@ import ( // retrieveOnSegments performs retrieve on listed segments // all segment ids are validated before calling this function -func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) { +func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, error) { var ( resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs)) errs = make([]error, len(segIDs)) @@ -59,7 +58,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy errs[i] = err return } - if err = segment.FillIndexedFieldsData(ctx, vcm, result); err != nil { + if err = segment.ValidateIndexedFieldsData(ctx, result); err != nil { errs[i] = err return } @@ -87,7 +86,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy } // retrieveHistorical will retrieve all the target segments in historical -func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { +func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { var err error var retrieveResults []*segcorepb.RetrieveResults var retrieveSegmentIDs []UniqueID @@ -97,12 +96,12 @@ func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePla return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err } - retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs, vcm) + retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs) return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err } // retrieveStreaming will retrieve all the target segments in streaming -func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { +func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { var err error var retrieveResults []*segcorepb.RetrieveResults var retrievePartIDs []UniqueID @@ -112,6 +111,6 @@ func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan if err != nil { return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err } - retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs, vcm) + retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs) return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err } diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index b357f4699f..3aec11297b 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -126,8 +126,7 @@ func (suite *RetrieveSuite) TestRetrieveSealed() { res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, suite.collectionID, []int64{suite.partitionID}, - []int64{suite.sealed.ID()}, - nil) + []int64{suite.sealed.ID()}) suite.NoError(err) suite.Len(res[0].Offset, 3) } @@ -139,8 +138,7 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() { res, _, _, err := RetrieveStreaming(context.TODO(), suite.manager, plan, suite.collectionID, []int64{suite.partitionID}, - []int64{suite.growing.ID()}, - nil) + []int64{suite.growing.ID()}) suite.NoError(err) suite.Len(res[0].Offset, 3) } @@ -152,8 +150,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, suite.collectionID, []int64{suite.partitionID}, - []int64{999}, - nil) + []int64{999}) suite.NoError(err) suite.Len(res, 0) } @@ -166,8 +163,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() { res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, suite.collectionID, []int64{suite.partitionID}, - []int64{suite.sealed.ID()}, - nil) + []int64{suite.sealed.ID()}) suite.ErrorIs(err, ErrSegmentReleased) suite.Len(res, 0) } diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 10e75c2ee8..00f1822758 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -28,6 +28,8 @@ import "C" import ( "context" "fmt" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/funcutil" "sort" "sync" "unsafe" @@ -46,7 +48,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/segcorepb" pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -473,10 +474,7 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) ( return dataPath, offsetInBinlog } -func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context, - vcm storage.ChunkManager, - result *segcorepb.RetrieveResults, -) error { +func (s *LocalSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error { log := log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), @@ -484,43 +482,21 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context, ) for _, fieldData := range result.FieldsData { - // If the field is not vector field, no need to download data from remote. if !typeutil.IsVectorType(fieldData.GetType()) { continue } - // If the vector field doesn't have indexed, vector data is in memory - // for brute force search, no need to download data from remote. if !s.ExistIndex(fieldData.FieldId) { continue } - // If the index has raw data, vector data could be obtained from index, - // no need to download data from remote. - if s.HasRawData(fieldData.FieldId) { - continue - } - - index := s.GetIndex(fieldData.FieldId) - if index == nil { - continue - } - - // TODO: optimize here. Now we'll read a whole file from storage every time we retrieve raw data by offset. - for i, offset := range result.Offset { - dataPath, dataOffset := s.GetFieldDataPath(index, offset) - endian := common.Endian - - // fill field data that fieldData[i] = dataPath[offsetInBinlog*rowBytes, (offsetInBinlog+1)*rowBytes] - if err := fillFieldData(ctx, vcm, dataPath, fieldData, i, dataOffset, endian); err != nil { - log.Warn("failed to fill field data", - zap.Int64("offset", offset), - zap.String("dataPath", dataPath), - zap.Int64("dataOffset", dataOffset), - zap.Int64("fieldID", fieldData.GetFieldId()), - zap.String("fieldType", fieldData.GetType().String()), - zap.Error(err), - ) + if !s.HasRawData(fieldData.FieldId) { + index := s.GetIndex(fieldData.FieldId) + indexType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, index.IndexInfo.GetIndexParams()) + if err != nil { return err } + err = fmt.Errorf("output fields for %s index is not allowed", indexType) + log.Warn("validate fields failed", zap.Error(err)) + return err } } diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 5b24aabc24..11808386aa 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -1,13 +1,16 @@ package segments import ( + "context" "testing" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -127,6 +130,56 @@ func (suite *SegmentSuite) TestHasRawData() { suite.True(has) } +func (suite *SegmentSuite) TestValidateIndexedFieldsData() { + result := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{5, 4, 3, 2, 9, 8, 7, 6}, + }}, + }, + Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6}, + FieldsData: []*schemapb.FieldData{ + genFieldData("int64 field", 100, schemapb.DataType_Int64, + []int64{5, 4, 3, 2, 9, 8, 7, 6}, 1), + genFieldData("float vector field", 101, schemapb.DataType_FloatVector, + []float32{5, 4, 3, 2, 9, 8, 7, 6}, 1), + }, + } + + // no index + err := suite.growing.ValidateIndexedFieldsData(context.Background(), result) + suite.NoError(err) + err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) + suite.NoError(err) + + // with index and has raw data + suite.sealed.AddIndex(101, &IndexedFieldInfo{ + IndexInfo: &querypb.FieldIndexInfo{ + FieldID: 101, + EnableIndex: true, + }, + }) + suite.True(suite.sealed.ExistIndex(101)) + err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) + suite.NoError(err) + + // index doesn't have index type + DeleteSegment(suite.sealed) + suite.True(suite.sealed.ExistIndex(101)) + err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) + suite.Error(err) + + // with index but doesn't have raw data + index := suite.sealed.GetIndex(101) + _, indexParams := genIndexParams(IndexHNSW, L2) + index.IndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams) + DeleteSegment(suite.sealed) + suite.True(suite.sealed.ExistIndex(101)) + err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) + suite.Error(err) +} + func TestSegment(t *testing.T) { suite.Run(t, new(SegmentSuite)) } diff --git a/tests/integration/getvector/get_vector_test.go b/tests/integration/getvector/get_vector_test.go index d20fe2fc63..37813ca18b 100644 --- a/tests/integration/getvector/get_vector_test.go +++ b/tests/integration/getvector/get_vector_test.go @@ -45,6 +45,9 @@ type TestGetVectorSuite struct { metricType string pkType schemapb.DataType vecType schemapb.DataType + + // expected + searchFailed bool } func (s *TestGetVectorSuite) run() { @@ -172,6 +175,11 @@ func (s *TestGetVectorSuite) run() { searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq) s.Require().NoError(err) + if s.searchFailed { + s.Require().NotEqual(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + s.T().Logf("reason:%s", searchResp.GetStatus().GetReason()) + return + } s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) result := searchResp.GetResults() @@ -253,6 +261,7 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() { s.metricType = distance.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = false s.run() } @@ -263,6 +272,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() { s.metricType = distance.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = false s.run() } @@ -273,6 +283,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() { s.metricType = distance.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = true s.run() } @@ -283,6 +294,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8() { s.metricType = distance.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = true s.run() } @@ -293,6 +305,7 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() { s.metricType = distance.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = false s.run() } @@ -303,6 +316,7 @@ func (s *TestGetVectorSuite) TestGetVector_IP() { s.metricType = distance.IP s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = false s.run() } @@ -313,6 +327,7 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() { s.metricType = distance.L2 s.pkType = schemapb.DataType_VarChar s.vecType = schemapb.DataType_FloatVector + s.searchFailed = false s.run() } @@ -323,6 +338,7 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() { s.metricType = distance.JACCARD s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_BinaryVector + s.searchFailed = false s.run() } @@ -334,6 +350,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() { s.metricType = distance.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector + s.searchFailed = false s.run() } @@ -344,6 +361,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() { // s.metricType = distance.L2 // s.pkType = schemapb.DataType_Int64 // s.vecType = schemapb.DataType_FloatVector +// s.searchFailed = false // s.run() //} diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index 708c444d12..e480276382 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -1418,9 +1418,10 @@ class TestQueryOperation(TestcaseBase): assert collection_w.has_index()[0] res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records') collection_w.load() + error = {ct.err_code: 1, ct.err_msg: 'not allowed'} collection_w.query(default_term_expr, output_fields=fields, - check_task=CheckTasks.check_query_results, - check_items={exp_res: res, "with_vec": True}) + check_task=CheckTasks.err_res, + check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_query_output_binary_vec_field_after_index(self):