Forbid getting quantized vectors from ChunkManager (#24334)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-05-24 23:03:27 +08:00 committed by GitHub
parent 1471da846d
commit 014387fd94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 53 deletions

View File

@ -212,9 +212,9 @@ func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequ
var results []*segcorepb.RetrieveResults var results []*segcorepb.RetrieveResults
if req.GetScope() == querypb.DataScope_Historical { if req.GetScope() == querypb.DataScope_Historical {
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager) results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs())
} else { } else {
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager) results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs())
} }
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -23,7 +23,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/timerecord"
@ -32,7 +31,7 @@ import (
// retrieveOnSegments performs retrieve on listed segments // retrieveOnSegments performs retrieve on listed segments
// all segment ids are validated before calling this function // all segment ids are validated before calling this function
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) { func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, error) {
var ( var (
resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs)) resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs))
errs = make([]error, len(segIDs)) errs = make([]error, len(segIDs))
@ -59,7 +58,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
errs[i] = err errs[i] = err
return return
} }
if err = segment.FillIndexedFieldsData(ctx, vcm, result); err != nil { if err = segment.ValidateIndexedFieldsData(ctx, result); err != nil {
errs[i] = err errs[i] = err
return return
} }
@ -87,7 +86,7 @@ func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentTy
} }
// retrieveHistorical will retrieve all the target segments in historical // retrieveHistorical will retrieve all the target segments in historical
func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error var err error
var retrieveResults []*segcorepb.RetrieveResults var retrieveResults []*segcorepb.RetrieveResults
var retrieveSegmentIDs []UniqueID var retrieveSegmentIDs []UniqueID
@ -97,12 +96,12 @@ func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePla
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
} }
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs, vcm) retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeSealed, plan, retrieveSegmentIDs)
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
} }
// retrieveStreaming will retrieve all the target segments in streaming // retrieveStreaming will retrieve all the target segments in streaming
func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error var err error
var retrieveResults []*segcorepb.RetrieveResults var retrieveResults []*segcorepb.RetrieveResults
var retrievePartIDs []UniqueID var retrievePartIDs []UniqueID
@ -112,6 +111,6 @@ func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan
if err != nil { if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
} }
retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs, vcm) retrieveResults, err = retrieveOnSegments(ctx, manager, SegmentTypeGrowing, plan, retrieveSegmentIDs)
return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err
} }

View File

@ -126,8 +126,7 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID, suite.collectionID,
[]int64{suite.partitionID}, []int64{suite.partitionID},
[]int64{suite.sealed.ID()}, []int64{suite.sealed.ID()})
nil)
suite.NoError(err) suite.NoError(err)
suite.Len(res[0].Offset, 3) suite.Len(res[0].Offset, 3)
} }
@ -139,8 +138,7 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
res, _, _, err := RetrieveStreaming(context.TODO(), suite.manager, plan, res, _, _, err := RetrieveStreaming(context.TODO(), suite.manager, plan,
suite.collectionID, suite.collectionID,
[]int64{suite.partitionID}, []int64{suite.partitionID},
[]int64{suite.growing.ID()}, []int64{suite.growing.ID()})
nil)
suite.NoError(err) suite.NoError(err)
suite.Len(res[0].Offset, 3) suite.Len(res[0].Offset, 3)
} }
@ -152,8 +150,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID, suite.collectionID,
[]int64{suite.partitionID}, []int64{suite.partitionID},
[]int64{999}, []int64{999})
nil)
suite.NoError(err) suite.NoError(err)
suite.Len(res, 0) suite.Len(res, 0)
} }
@ -166,8 +163,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() {
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID, suite.collectionID,
[]int64{suite.partitionID}, []int64{suite.partitionID},
[]int64{suite.sealed.ID()}, []int64{suite.sealed.ID()})
nil)
suite.ErrorIs(err, ErrSegmentReleased) suite.ErrorIs(err, ErrSegmentReleased)
suite.Len(res, 0) suite.Len(res, 0)
} }

View File

@ -28,6 +28,8 @@ import "C"
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"sort" "sort"
"sync" "sync"
"unsafe" "unsafe"
@ -46,7 +48,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
@ -473,10 +474,7 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (
return dataPath, offsetInBinlog return dataPath, offsetInBinlog
} }
func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context, func (s *LocalSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error {
vcm storage.ChunkManager,
result *segcorepb.RetrieveResults,
) error {
log := log.Ctx(ctx).With( log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()), zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()), zap.Int64("partitionID", s.Partition()),
@ -484,43 +482,21 @@ func (s *LocalSegment) FillIndexedFieldsData(ctx context.Context,
) )
for _, fieldData := range result.FieldsData { for _, fieldData := range result.FieldsData {
// If the field is not vector field, no need to download data from remote.
if !typeutil.IsVectorType(fieldData.GetType()) { if !typeutil.IsVectorType(fieldData.GetType()) {
continue continue
} }
// If the vector field doesn't have indexed, vector data is in memory
// for brute force search, no need to download data from remote.
if !s.ExistIndex(fieldData.FieldId) { if !s.ExistIndex(fieldData.FieldId) {
continue continue
} }
// If the index has raw data, vector data could be obtained from index, if !s.HasRawData(fieldData.FieldId) {
// no need to download data from remote. index := s.GetIndex(fieldData.FieldId)
if s.HasRawData(fieldData.FieldId) { indexType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, index.IndexInfo.GetIndexParams())
continue if err != nil {
}
index := s.GetIndex(fieldData.FieldId)
if index == nil {
continue
}
// TODO: optimize here. Now we'll read a whole file from storage every time we retrieve raw data by offset.
for i, offset := range result.Offset {
dataPath, dataOffset := s.GetFieldDataPath(index, offset)
endian := common.Endian
// fill field data that fieldData[i] = dataPath[offsetInBinlog*rowBytes, (offsetInBinlog+1)*rowBytes]
if err := fillFieldData(ctx, vcm, dataPath, fieldData, i, dataOffset, endian); err != nil {
log.Warn("failed to fill field data",
zap.Int64("offset", offset),
zap.String("dataPath", dataPath),
zap.Int64("dataOffset", dataOffset),
zap.Int64("fieldID", fieldData.GetFieldId()),
zap.String("fieldType", fieldData.GetType().String()),
zap.Error(err),
)
return err return err
} }
err = fmt.Errorf("output fields for %s index is not allowed", indexType)
log.Warn("validate fields failed", zap.Error(err))
return err
} }
} }

View File

@ -1,13 +1,16 @@
package segments package segments
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
storage "github.com/milvus-io/milvus/internal/storage" storage "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
) )
@ -127,6 +130,56 @@ func (suite *SegmentSuite) TestHasRawData() {
suite.True(has) suite.True(has)
} }
// TestValidateIndexedFieldsData exercises ValidateIndexedFieldsData on the
// suite's growing and sealed segments, using a retrieve result that carries one
// int64 field (ID 100) and one float vector field (ID 101).
//
// The steps mutate suite.sealed in place and build on each other, so their
// order matters: validation is expected to succeed before this test adds an
// index and right after AddIndex, and to fail after each DeleteSegment call.
func (suite *SegmentSuite) TestValidateIndexedFieldsData() {
// Retrieve result shared by every step below; IDs, offsets and both field
// payloads are built from the same eight values.
result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{5, 4, 3, 2, 9, 8, 7, 6},
}},
},
Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6},
FieldsData: []*schemapb.FieldData{
genFieldData("int64 field", 100, schemapb.DataType_Int64,
[]int64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("float vector field", 101, schemapb.DataType_FloatVector,
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
},
}
// Step 1: this test has not added an index for field 101 yet; validation is
// expected to pass on both the growing and the sealed segment.
err := suite.growing.ValidateIndexedFieldsData(context.Background(), result)
suite.NoError(err)
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.NoError(err)
// Step 2: register an index entry for vector field 101 on the sealed segment.
// The entry sets no IndexParams. Validation is still expected to pass.
// NOTE(review): presumably HasRawData(101) reports true at this point, which
// is why the missing index type is not yet a problem — confirm.
suite.sealed.AddIndex(101, &IndexedFieldInfo{
IndexInfo: &querypb.FieldIndexInfo{
FieldID: 101,
EnableIndex: true,
},
})
suite.True(suite.sealed.ExistIndex(101))
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.NoError(err)
// Step 3: after DeleteSegment the index entry is still reported by ExistIndex,
// but validation is expected to fail.
// NOTE(review): presumably DeleteSegment makes HasRawData report false, so the
// index-type lookup runs and fails on the empty IndexParams — confirm.
DeleteSegment(suite.sealed)
suite.True(suite.sealed.ExistIndex(101))
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.Error(err)
// Step 4: fill in HNSW/L2 index params on the existing index entry (mutated in
// place through the pointer returned by GetIndex) and validate again; an error
// is still expected, now with an index type available.
// NOTE(review): DeleteSegment is invoked a second time on the same segment —
// confirm that a repeated call is safe and intended.
index := suite.sealed.GetIndex(101)
_, indexParams := genIndexParams(IndexHNSW, L2)
index.IndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams)
DeleteSegment(suite.sealed)
suite.True(suite.sealed.ExistIndex(101))
err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result)
suite.Error(err)
}
func TestSegment(t *testing.T) { func TestSegment(t *testing.T) {
suite.Run(t, new(SegmentSuite)) suite.Run(t, new(SegmentSuite))
} }

View File

@ -45,6 +45,9 @@ type TestGetVectorSuite struct {
metricType string metricType string
pkType schemapb.DataType pkType schemapb.DataType
vecType schemapb.DataType vecType schemapb.DataType
// expected
searchFailed bool
} }
func (s *TestGetVectorSuite) run() { func (s *TestGetVectorSuite) run() {
@ -172,6 +175,11 @@ func (s *TestGetVectorSuite) run() {
searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq) searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq)
s.Require().NoError(err) s.Require().NoError(err)
if s.searchFailed {
s.Require().NotEqual(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
s.T().Logf("reason:%s", searchResp.GetStatus().GetReason())
return
}
s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
result := searchResp.GetResults() result := searchResp.GetResults()
@ -253,6 +261,7 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run() s.run()
} }
@ -263,6 +272,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run() s.run()
} }
@ -273,6 +283,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = true
s.run() s.run()
} }
@ -283,6 +294,7 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = true
s.run() s.run()
} }
@ -293,6 +305,7 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run() s.run()
} }
@ -303,6 +316,7 @@ func (s *TestGetVectorSuite) TestGetVector_IP() {
s.metricType = distance.IP s.metricType = distance.IP
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run() s.run()
} }
@ -313,6 +327,7 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_VarChar s.pkType = schemapb.DataType_VarChar
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run() s.run()
} }
@ -323,6 +338,7 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() {
s.metricType = distance.JACCARD s.metricType = distance.JACCARD
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_BinaryVector s.vecType = schemapb.DataType_BinaryVector
s.searchFailed = false
s.run() s.run()
} }
@ -334,6 +350,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
s.metricType = distance.L2 s.metricType = distance.L2
s.pkType = schemapb.DataType_Int64 s.pkType = schemapb.DataType_Int64
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.searchFailed = false
s.run() s.run()
} }
@ -344,6 +361,7 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
// s.metricType = distance.L2 // s.metricType = distance.L2
// s.pkType = schemapb.DataType_Int64 // s.pkType = schemapb.DataType_Int64
// s.vecType = schemapb.DataType_FloatVector // s.vecType = schemapb.DataType_FloatVector
// s.searchFailed = false
// s.run() // s.run()
//} //}

View File

@ -1418,9 +1418,10 @@ class TestQueryOperation(TestcaseBase):
assert collection_w.has_index()[0] assert collection_w.has_index()[0]
res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records') res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
collection_w.load() collection_w.load()
error = {ct.err_code: 1, ct.err_msg: 'not allowed'}
collection_w.query(default_term_expr, output_fields=fields, collection_w.query(default_term_expr, output_fields=fields,
check_task=CheckTasks.check_query_results, check_task=CheckTasks.err_res,
check_items={exp_res: res, "with_vec": True}) check_items=error)
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.tags(CaseLabel.L1)
def test_query_output_binary_vec_field_after_index(self): def test_query_output_binary_vec_field_after_index(self):