diff --git a/cmd/components/root_coord.go b/cmd/components/root_coord.go index 0201558d09..e26a5c50fd 100644 --- a/cmd/components/root_coord.go +++ b/cmd/components/root_coord.go @@ -58,8 +58,8 @@ func (rc *RootCoord) Run() error { // Stop terminates service func (rc *RootCoord) Stop() error { - if err := rc.svr.Stop(); err != nil { - return err + if rc.svr != nil { + return rc.svr.Stop() } return nil } diff --git a/internal/datacoord/coordinator_broker.go b/internal/datacoord/coordinator_broker.go index c0d2eecb21..91252d2ec4 100644 --- a/internal/datacoord/coordinator_broker.go +++ b/internal/datacoord/coordinator_broker.go @@ -19,12 +19,13 @@ import ( "context" "time" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "go.uber.org/zap" ) @@ -145,12 +146,9 @@ func (b *CoordinatorBroker) HasCollection(ctx context.Context, collectionID int6 if resp == nil { return false, errNilResponse } - if resp.Status.ErrorCode == commonpb.ErrorCode_Success { - return true, nil - } - statusErr := common.NewStatusError(resp.Status.ErrorCode, resp.Status.Reason) - if common.IsCollectionNotExistError(statusErr) { + err = merr.Error(resp.GetStatus()) + if errors.Is(err, merr.ErrCollectionNotFound) { return false, nil } - return false, statusErr + return err == nil, err } diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index 4ba823dc5d..2b6701f4d8 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -258,10 +258,8 @@ func (s *Server) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRe }, nil } ret := &indexpb.GetIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - State: commonpb.IndexState_Finished, + Status: merr.Status(nil), + State: commonpb.IndexState_Finished, } indexInfo := &indexpb.IndexInfo{ @@ -301,9 +299,7 @@ func (s *Server) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegme } ret := &indexpb.GetSegmentIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), States: make([]*indexpb.SegmentIndexState, 0), } indexID2CreateTs := s.meta.GetIndexIDByName(req.GetCollectionID(), req.GetIndexName()) @@ -510,9 +506,7 @@ func (s *Server) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetInde log.Info("GetIndexBuildProgress success", zap.Int64("collectionID", req.GetCollectionID()), zap.String("indexName", req.GetIndexName())) return &indexpb.GetIndexBuildProgressResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IndexedRows: indexInfo.IndexedRows, TotalRows: indexInfo.TotalRows, PendingIndexRows: indexInfo.PendingIndexRows, @@ -580,9 +574,7 @@ func (s *Server) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRe } log.Info("DescribeIndex success", zap.String("indexName", req.GetIndexName())) return &indexpb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IndexInfos: indexInfos, }, nil } @@ -640,9 +632,7 @@ func (s *Server) GetIndexStatistics(ctx context.Context, req 
*indexpb.GetIndexSt log.Debug("GetIndexStatisticsResponse success", zap.String("indexName", req.GetIndexName())) return &indexpb.GetIndexStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IndexInfos: indexInfos, }, nil } @@ -668,9 +658,7 @@ func (s *Server) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) ( return errResp, nil } - ret := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } + ret := merr.Status(nil) indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) if len(indexes) == 0 { @@ -722,9 +710,7 @@ func (s *Server) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoReq }, nil } ret := &indexpb.GetIndexInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), SegmentInfo: map[int64]*indexpb.SegmentInfo{}, } diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index 5b8a4ce757..e02fafd795 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -18,11 +18,11 @@ package datacoord import ( "context" - "fmt" "sync/atomic" "time" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/tsoutil" clientv3 "go.etcd.io/etcd/client/v3" @@ -395,10 +395,9 @@ func (m *mockRootCoordService) HasCollection(ctx context.Context, req *milvuspb. func (m *mockRootCoordService) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { // return not exist if req.CollectionID == -1 { - err := common.NewCollectionNotExistError(fmt.Sprintf("can't find collection: %d", req.CollectionID)) + err := merr.WrapErrCollectionNotFound(req.GetCollectionID()) return &milvuspb.DescribeCollectionResponse{ - // TODO: use commonpb.ErrorCode_CollectionNotExists. SDK use commonpb.ErrorCode_UnexpectedError now. - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: err.Error()}, + Status: merr.Status(err), }, nil } return &milvuspb.DescribeCollectionResponse{ diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index b16a68472f..9d1b881a33 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -2455,7 +2455,7 @@ func TestShouldDropChannel(t *testing.T) { //myRoot.code = commonpb.ErrorCode_CollectionNotExists myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}, + Status: merr.Status(merr.WrapErrCollectionNotFound(-1)), CollectionID: -1, }, nil).Once() assert.True(t, svr.handler.CheckShouldDropChannel("ch99", -1)) @@ -2482,7 +2482,7 @@ func TestShouldDropChannel(t *testing.T) { t.Run("collection name in kv, collection not exist", func(t *testing.T) { myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). 
Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}, + Status: merr.Status(merr.WrapErrCollectionNotFound(-1)), CollectionID: -1, }, nil).Once() assert.True(t, svr.handler.CheckShouldDropChannel("ch1", -1)) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index db024146ce..b8692f8bd6 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -54,10 +54,8 @@ func (s *Server) isClosed() bool { // GetTimeTickChannel legacy API, returns time tick channel name func (s *Server) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Value: Params.CommonCfg.DataCoordTimeTick.GetValue(), + Status: merr.Status(nil), + Value: Params.CommonCfg.DataCoordTimeTick.GetValue(), }, nil } @@ -203,18 +201,13 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI CollectionID: r.CollectionID, PartitionID: r.PartitionID, ExpireTime: allocation.ExpireTime, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), } assigns = append(assigns, result) } } return &datapb.AssignSegmentIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), SegIDAssignments: assigns, }, nil } @@ -340,10 +333,8 @@ func (s *Server) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart // GetSegmentInfoChannel legacy API, returns segment info statistics channel func (s *Server) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Value: Params.CommonCfg.DataCoordSegmentInfo.GetValue(), + Status: merr.Status(nil), + Value: Params.CommonCfg.DataCoordSegmentInfo.GetValue(), }, nil } @@ -591,9 +582,7 @@ func (s *Server) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStat }, nil } return &datapb.SetSegmentStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), }, nil } @@ -619,10 +608,7 @@ func (s *Server) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta Role: "datacoord", StateCode: code, }, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), } return resp, nil } @@ -974,10 +960,7 @@ func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Configuations: configList, }, nil } @@ -1400,9 +1383,7 @@ func (s *Server) UpdateSegmentStatistics(ctx context.Context, req *datapb.Update return resp, nil } s.updateSegmentStatistics(req.GetStats()) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } // UpdateChannelCheckpoint updates channel checkpoint in dataCoord. @@ -1424,9 +1405,7 @@ func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update return resp, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } // ReportDataNodeTtMsgs sends datanode timetick messages to dataCoord.
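Note: the change across these hunks is mechanical. merr.Status(nil) replaces a hand-built success commonpb.Status, and merr.Status(err) replaces the ErrorCode_UnexpectedError plus Reason: err.Error() boilerplate. Below is a minimal sketch of the migrated handler shape, assuming merr.Status(nil) yields a Success status and merr.Status(err) derives ErrorCode and Reason from the error; handleRequest and its work parameter are hypothetical stand-ins, not part of this PR.

package example

import (
	"context"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

// handleRequest is a hypothetical handler showing the before/after shape of
// the status handling in this PR.
func handleRequest(ctx context.Context, work func(context.Context) error) (*commonpb.Status, error) {
	if err := work(ctx); err != nil {
		// was: return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: err.Error()}, nil
		return merr.Status(err), nil
	}
	// was: return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
	return merr.Status(nil), nil
}

On the consumer side, the coordinator_broker.go hunk above shows the inverse direction: merr.Error(resp.GetStatus()) turns a returned status back into an error that errors.Is can match against sentinels such as merr.ErrCollectionNotFound.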
@@ -1590,9 +1569,7 @@ func (s *Server) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe Reason: err.Error(), }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } // UnsetIsImportingState unsets the isImporting states of the given segments. @@ -1615,9 +1592,7 @@ func (s *Server) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsI Reason: reportErr.Error(), }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } // MarkSegmentsDropped marks the given segments as `Dropped`. @@ -1638,9 +1613,7 @@ func (s *Server) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmen ErrorCode: commonpb.ErrorCode_UnexpectedError, }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { @@ -1674,16 +1647,12 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt Properties: properties, } s.meta.AddCollection(collInfo) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } clonedColl.Properties = properties s.meta.AddCollection(clonedColl) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { diff --git a/internal/datanode/metrics_info.go b/internal/datanode/metrics_info.go index 9726b899c9..d4f43e9d9e 100644 --- a/internal/datanode/metrics_info.go +++ b/internal/datanode/metrics_info.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" @@ -131,10 +132,7 @@ func (node *DataNode) getSystemInfoMetrics(ctx context.Context, req *milvuspb.Ge } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/datanode/services.go b/internal/datanode/services.go index 329ded7b9a..84a1096726 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -417,9 +417,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) }() importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), TaskId: req.GetImportTask().TaskId, DatanodeId: paramtable.GetNodeID(), State: commonpb.ImportState_ImportStarted, @@ -519,9 +517,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) return returnFailFunc("failed to import files", err) } - resp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } + resp := merr.Status(nil) return resp, nil } @@ -654,9 +650,7 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor } ds.flushingSegCache.Remove(req.GetSegmentId()) return 
&datapb.AddImportSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), ChannelPos: posID, }, nil } @@ -705,9 +699,7 @@ func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil // ignore the returned error, since even if the report fails the segments can still be cleaned // retry 10 times; if the rootcoord is down, the report function will take 20+ seconds importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), TaskId: req.GetImportTask().TaskId, DatanodeId: paramtable.GetNodeID(), State: commonpb.ImportState_ImportStarted, diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 386784ca3c..cab2894a0d 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -1058,10 +1058,7 @@ func (s *Server) GetProxyMetrics(ctx context.Context, request *milvuspb.GetMetri func (s *Server) GetVersion(ctx context.Context, request *milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error) { buildTags := os.Getenv(metricsinfo.GitBuildTagsEnvKey) return &milvuspb.GetVersionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Version: buildTags, }, nil } diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index 8c117ef63d..f3fcdd77a0 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -56,6 +56,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/lifetime" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -288,9 +289,7 @@ func (i *IndexNode) GetComponentStates(ctx context.Context) (*milvuspb.Component ret := &milvuspb.ComponentStates{ State: stateInfo, SubcomponentStates: nil, // todo add subcomponents states - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), } log.RatedInfo(10, "IndexNode Component states", @@ -305,9 +304,7 @@ func (i *IndexNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRes log.RatedInfo(10, "get IndexNode time tick channel ...") return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), }, nil } @@ -315,9 +312,7 @@ func (i *IndexNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRes func (i *IndexNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { log.RatedInfo(10, "get IndexNode statistics channel ...") return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), }, nil } @@ -352,10 +347,7 @@ func (i *IndexNode) ShowConfigurations(ctx context.Context, req *internalpb.Show } return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Configuations: configList, }, nil } diff --git a/internal/indexnode/indexnode_mock.go b/internal/indexnode/indexnode_mock.go index f1eac9f016..76bd8bef8d 100644 --- a/internal/indexnode/indexnode_mock.go +++ b/internal/indexnode/indexnode_mock.go @@ -27,6 +27,7 @@ import ( 
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -84,22 +85,16 @@ func NewIndexNodeMock() *Mock { StateCode: commonpb.StateCode_Healthy, }, SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), }, nil }, CallGetStatisticsChannel: func(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), }, nil }, CallCreateJob: func(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil }, CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { indexInfos := make([]*indexpb.IndexTaskInfo, 0) @@ -111,23 +106,17 @@ func NewIndexNodeMock() *Mock { }) } return &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), ClusterID: in.ClusterID, IndexInfos: indexInfos, }, nil }, CallDropJobs: func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil }, CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), TotalJobNum: 1, EnqueueJobNum: 0, InProgressJobNum: 1, @@ -148,9 +137,7 @@ func NewIndexNodeMock() *Mock { }, CallShowConfigurations: func(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), }, nil }, } @@ -252,10 +239,7 @@ func getMockSystemInfoMetrics( resp, _ := metricsinfo.MarshalComponentInfos(nodeInfos) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/indexnode/indexnode_service.go b/internal/indexnode/indexnode_service.go index 3f4560612d..88f367a9f0 100644 --- a/internal/indexnode/indexnode_service.go +++ b/internal/indexnode/indexnode_service.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -103,10 +104,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.BuildID, req.ClusterID)), serializedSize: 0, } - ret := &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_Success, - Reason: "", - } + ret := merr.Status(nil) if err := i.sched.IndexBuildQueue.Enqueue(task); err != nil { log.Ctx(ctx).Warn("IndexNode failed to schedule", zap.Int64("IndexBuildID", req.BuildID), zap.String("ClusterID", req.ClusterID), zap.Error(err)) ret.ErrorCode = commonpb.ErrorCode_UnexpectedError @@ -146,10 +144,7 @@ func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest } }) ret := &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), ClusterID: req.ClusterID, IndexInfos: make([]*indexpb.IndexTaskInfo, 0, len(req.BuildIDs)), } @@ -196,10 +191,7 @@ func (i *IndexNode) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) } log.Ctx(ctx).Info("drop index build jobs success", zap.String("ClusterID", req.ClusterID), zap.Int64s("IndexBuildIDs", req.BuildIDs)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Status(nil), nil } func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { @@ -227,10 +219,7 @@ func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReq } log.Ctx(ctx).Info("Get Index Job Stats", zap.Int("Unissued", unissued), zap.Int("Active", active), zap.Int("Slot", slots)) return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), TotalJobNum: int64(active) + int64(unissued), InProgressJobNum: int64(active), EnqueueJobNum: int64(unissued), diff --git a/internal/indexnode/metrics_info.go b/internal/indexnode/metrics_info.go index 100e0e815e..e917080559 100644 --- a/internal/indexnode/metrics_info.go +++ b/internal/indexnode/metrics_info.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -73,10 +74,7 @@ func getSystemInfoMetrics( } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/metastore/kv/rootcoord/kv_catalog.go b/internal/metastore/kv/rootcoord/kv_catalog.go index 2d5795b922..ea2b028ec5 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog.go +++ b/internal/metastore/kv/rootcoord/kv_catalog.go @@ -19,6 +19,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" "go.uber.org/zap" ) @@ -189,7 +190,7 @@ func (kc *Catalog) loadCollectionFromDb(ctx context.Context, dbID int64, collect collKey := BuildCollectionKey(dbID, collectionID) collVal, err := kc.Snapshot.Load(collKey, ts) if err != nil { - return nil, common.NewCollectionNotExistError(fmt.Sprintf("collection not found: %d, error: %s", collectionID, err.Error())) + return nil, merr.WrapErrCollectionNotFound(collectionID, err.Error()) } collMeta := &pb.CollectionInfo{} @@ -592,7 
+593,7 @@ func (kc *Catalog) GetCollectionByName(ctx context.Context, dbID int64, collecti } } - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection %d:%s, at timestamp = %d", dbID, collectionName, ts)) + return nil, merr.WrapErrCollectionNotFoundWithDB(dbID, collectionName, fmt.Sprintf("timestamp = %d", ts)) } func (kc *Catalog) ListCollections(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Collection, error) { diff --git a/internal/proxy/default_limit_reducer.go b/internal/proxy/default_limit_reducer.go index dc04e72cb7..d74440567b 100644 --- a/internal/proxy/default_limit_reducer.go +++ b/internal/proxy/default_limit_reducer.go @@ -8,6 +8,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -40,9 +41,7 @@ func (r *defaultLimitReducer) afterReduce(result *milvuspb.QueryResults) error { result.CollectionName = collectionName if len(result.FieldsData) > 0 { - result.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } + result.Status = merr.Status(nil) } else { result.Status = &commonpb.Status{ ErrorCode: commonpb.ErrorCode_EmptyCollection, diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index f33cc753f0..da71472c85 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -70,9 +70,7 @@ func (node *Proxy) UpdateStateCode(code commonpb.StateCode) { // GetComponentStates gets the state of Proxy. func (node *Proxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { stats := &milvuspb.ComponentStates{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), } code, ok := node.stateCode.Load().(commonpb.StateCode) if !ok { @@ -100,11 +98,8 @@ func (node *Proxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentS // GetStatisticsChannel gets statistics channel of Proxy. 
func (node *Proxy) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: "", + Status: merr.Status(nil), + Value: "", }, nil } @@ -148,10 +143,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p } log.Info("complete to invalidate collection meta cache") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Status(nil), nil } func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { @@ -183,10 +175,7 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD log.Warn(rpcFailedToEnqueue(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Info(rpcEnqueued(method)) @@ -194,10 +183,7 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Info(rpcDone(method)) @@ -233,20 +219,14 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab if err := node.sched.ddQueue.Enqueue(dct); err != nil { log.Warn(rpcFailedToEnqueue(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Info(rpcEnqueued(method)) if err := dct.WaitToFinish(); err != nil { log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Info(rpcDone(method)) @@ -284,10 +264,7 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData if err := node.sched.ddQueue.Enqueue(dct); err != nil { log.Warn(rpcFailedToEnqueue(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - resp.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - } + resp.Status = merr.Status(err) return resp, nil } @@ -295,10 +272,7 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData if err := dct.WaitToFinish(); err != nil { log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - resp.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - } + resp.Status = merr.Status(err) return resp, nil } @@ -348,10 +322,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, 
request *milvuspb.Creat zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -368,10 +339,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat zap.Uint64("EndTs", cct.EndTs())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -417,10 +385,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug("DropCollection enqueued", @@ -434,10 +399,7 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol zap.Uint64("EndTs", dct.EndTs())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug("DropCollection done", @@ -485,10 +447,7 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() return &milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -505,10 +464,7 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -556,10 +512,7 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug("LoadCollection enqueued", @@ -573,10 +526,7 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol zap.Uint64("EndTS", lct.EndTs())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug("LoadCollection done", @@ -623,10 +573,7 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + 
return merr.Status(err), nil } log.Debug( @@ -643,10 +590,7 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -696,10 +640,7 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -717,10 +658,7 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des metrics.FailLabel).Inc() return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -728,7 +666,8 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des zap.Uint64("BeginTS", dct.BeginTs()), zap.Uint64("EndTS", dct.EndTs()), zap.String("db", request.DbName), - zap.String("collection", request.CollectionName)) + zap.String("collection", request.CollectionName), + ) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() @@ -780,10 +719,7 @@ func (node *Proxy) GetStatistics(ctx context.Context, request *milvuspb.GetStati metrics.AbandonLabel).Inc() return &milvuspb.GetStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -805,10 +741,7 @@ func (node *Proxy) GetStatistics(ctx context.Context, request *milvuspb.GetStati metrics.FailLabel).Inc() return &milvuspb.GetStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -860,10 +793,7 @@ func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvusp metrics.AbandonLabel).Inc() return &milvuspb.GetCollectionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -883,10 +813,7 @@ func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvusp metrics.FailLabel).Inc() return &milvuspb.GetCollectionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -939,10 +866,7 @@ func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCo metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() return &milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -958,10 +882,7 @@ func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCo metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - 
ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1007,10 +928,7 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1027,10 +945,7 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC zap.Uint64("EndTs", act.EndTs())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1078,10 +993,7 @@ func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.Create metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1098,10 +1010,7 @@ func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.Create metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1150,10 +1059,7 @@ func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPart metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1170,10 +1076,7 @@ func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPart metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1227,11 +1130,8 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit metrics.AbandonLabel).Inc() return &milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - Value: false, + Status: merr.Status(err), + Value: false, }, nil } @@ -1251,11 +1151,8 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit metrics.FailLabel).Inc() return &milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - Value: false, + Status: merr.Status(err), + Value: false, }, nil } @@ -1307,10 +1204,7 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1328,10 +1222,7 @@ func (node *Proxy) LoadPartitions(ctx 
context.Context, request *milvuspb.LoadPar metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1382,10 +1273,7 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1403,10 +1291,7 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -1459,10 +1344,7 @@ func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb metrics.AbandonLabel).Inc() return &milvuspb.GetPartitionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1482,10 +1364,7 @@ func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb metrics.FailLabel).Inc() return &milvuspb.GetPartitionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1542,10 +1421,7 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar metrics.AbandonLabel).Inc() return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1571,10 +1447,7 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar metrics.FailLabel).Inc() return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1664,9 +1537,7 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &milvuspb.GetLoadingProgressResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), Progress: loadProgress, RefreshProgress: refreshProgress, }, nil @@ -1711,9 +1582,7 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt } successResponse := &milvuspb.GetLoadStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), } defer func() { log.Debug( @@ -1812,10 +1681,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + 
return merr.Status(err), nil } log.Info( @@ -1833,10 +1699,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Info( @@ -1892,10 +1755,7 @@ func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.Describe metrics.AbandonLabel).Inc() return &milvuspb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1979,10 +1839,7 @@ func (node *Proxy) GetIndexStatistics(ctx context.Context, request *milvuspb.Get metrics.AbandonLabel).Inc() return &milvuspb.GetIndexStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2056,10 +1913,7 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -2077,10 +1931,7 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -2137,10 +1988,7 @@ func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb. metrics.AbandonLabel).Inc() return &milvuspb.GetIndexBuildProgressResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2159,10 +2007,7 @@ func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb. 
metrics.FailLabel).Inc() return &milvuspb.GetIndexBuildProgressResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2220,10 +2065,7 @@ func (node *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndex metrics.AbandonLabel).Inc() return &milvuspb.GetIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2242,10 +2084,7 @@ func (node *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndex metrics.FailLabel).Inc() return &milvuspb.GetIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2321,10 +2160,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) } return &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), ErrIndex: errIndex, } } @@ -2438,10 +2274,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) metrics.AbandonLabel).Inc() return &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2452,10 +2285,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2509,9 +2339,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) Condition: NewTaskCondition(ctx), req: request, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IDs: &schemapb.IDs{ IdField: nil, }, @@ -2549,10 +2377,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() return &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2673,10 +2498,7 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) metrics.AbandonLabel).Inc() return &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } tr.CtxRecord(ctx, "search request enqueue") @@ -2694,10 +2516,7 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) metrics.FailLabel).Inc() return &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -2875,10 +2694,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* metrics.AbandonLabel).Inc() return &milvuspb.QueryResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } tr.CtxRecord(ctx, "query 
request enqueue") @@ -2894,10 +2710,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* metrics.FailLabel).Inc() return &milvuspb.QueryResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } span := tr.CtxRecord(ctx, "wait query result") @@ -2954,10 +2767,7 @@ func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAlia metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -2973,10 +2783,7 @@ func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAlia zap.Uint64("EndTs", cat.EndTs())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -3040,10 +2847,7 @@ func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasReq zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -3060,10 +2864,7 @@ func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasReq metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -3110,10 +2911,7 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -3130,10 +2928,7 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Debug( @@ -3537,10 +3332,7 @@ func (node *Proxy) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReque zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), Response: "", }, nil } @@ -3608,10 +3400,7 @@ func (node *Proxy) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetrics zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -3859,10 +3648,7 @@ func (node *Proxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushA resp, err = node.dataCoord.GetFlushAllState(ctx, req) if err != nil { - resp.Status = &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - } + resp.Status = merr.Status(err) log.Warn("GetFlushAllState failed", zap.String("err", err.Error())) return resp, nil } @@ -3901,10 +3687,7 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi zap.String("partition name", req.GetPartitionName()), zap.Strings("files", req.GetFiles())) resp := &milvuspb.ImportResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), } if !node.checkHealthy() { resp.Status = unhealthyStatus() @@ -4033,10 +3816,7 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy } log.Debug("complete to invalidate credential cache") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Status(nil), nil } // UpdateCredentialCache updates the credential cache of the specified username. @@ -4062,10 +3842,7 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U } log.Debug("complete to update credential cache") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Status(nil), nil } func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) { @@ -4124,10 +3901,7 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre if err != nil { // for errors like context timeout etc. log.Error("create credential fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, err } @@ -4206,10 +3980,7 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre if err != nil { // for errors like context timeout etc. log.Error("update credential fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, err } @@ -4237,10 +4008,7 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre if err != nil { // for errors like context timeout etc. 
log.Error("delete credential fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, err } @@ -4264,16 +4032,11 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser resp, err := node.rootCoord.ListCredUsers(ctx, rootCoordReq) if err != nil { return &milvuspb.ListCredUsersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } return &milvuspb.ListCredUsersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), Usernames: resp.Usernames, }, nil } @@ -4303,10 +4066,7 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque result, err := node.rootCoord.CreateRole(ctx, req) if err != nil { log.Warn("fail to create role", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, nil } @@ -4340,10 +4100,7 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) log.Warn("fail to drop role", zap.String("role_name", req.RoleName), zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, nil } @@ -4374,10 +4131,7 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse result, err := node.rootCoord.OperateUserRole(ctx, req) if err != nil { log.Warn("fail to operate user role", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, nil } @@ -4396,10 +4150,7 @@ func (node *Proxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReque if req.Role != nil { if err := ValidateRoleName(req.Role.Name); err != nil { return &milvuspb.SelectRoleResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } } @@ -4408,10 +4159,7 @@ func (node *Proxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReque if err != nil { log.Warn("fail to select role", zap.Error(err)) return &milvuspb.SelectRoleResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } return result, nil @@ -4432,10 +4180,7 @@ func (node *Proxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserReque if err := ValidateUsername(req.User.Name); err != nil { log.Warn("invalid username", zap.Error(err)) return &milvuspb.SelectUserResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } } @@ -4444,10 +4189,7 @@ func (node *Proxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserReque if err != nil { log.Warn("fail to select user", zap.Error(err)) return &milvuspb.SelectUserResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } return result, nil @@ -4514,10 +4256,7 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr result, err := node.rootCoord.OperatePrivilege(ctx, req) if err != nil { log.Warn("fail to 
operate privilege", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } return result, nil } @@ -4573,10 +4312,7 @@ func (node *Proxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantReq if err != nil { log.Warn("fail to select grant", zap.Error(err)) return &milvuspb.SelectGrantResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } return result, nil @@ -4610,9 +4346,7 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr } log.Debug("RefreshPrivilegeInfoCache success") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } // SetRates limits the rates of requests. @@ -4689,9 +4423,7 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt err := group.Wait() if err != nil || len(errReasons) != 0 { return &milvuspb.CheckHealthResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IsHealthy: false, Reasons: errReasons, }, nil @@ -4699,10 +4431,7 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt states, reasons := node.multiRateLimiter.GetQuotaStates() return &milvuspb.CheckHealthResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), QuotaStates: states, Reasons: reasons, IsHealthy: true, @@ -5038,10 +4767,7 @@ func (node *Proxy) ListResourceGroups(ctx context.Context, request *milvuspb.Lis metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() return &milvuspb.ListResourceGroupsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -5057,10 +4783,7 @@ func (node *Proxy) ListResourceGroups(ctx context.Context, request *milvuspb.Lis metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.ListResourceGroupsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -5086,10 +4809,7 @@ func (node *Proxy) DescribeResourceGroup(ctx context.Context, request *milvuspb. 
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.DescribeResourceGroupResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), } } @@ -5178,10 +4898,7 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest if err != nil { log.Info("connect failed, failed to list databases", zap.Error(err)) return &milvuspb.ConnectResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -5208,10 +4925,7 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest if err != nil { log.Info("connect failed, failed to allocate timestamp", zap.Error(err)) return &milvuspb.ConnectResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -5227,7 +4941,7 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest GetConnectionManager().register(ctx, int64(ts), request.GetClientInfo()) return &milvuspb.ConnectResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), ServerInfo: serverInfo, Identifier: int64(ts), }, nil @@ -5241,7 +4955,7 @@ func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientI clients := GetConnectionManager().list() return &proxypb.ListClientInfosResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), ClientInfos: clients, }, nil } @@ -5256,17 +4970,14 @@ func (node *Proxy) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimest if err != nil { log.Info("AllocTimestamp failed", zap.Error(err)) return &milvuspb.AllocTimestampResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } log.Info("AllocTimestamp request success", zap.Uint64("timestamp", ts)) return &milvuspb.AllocTimestampResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), Timestamp: ts, }, nil } diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index fbd6550634..060664d639 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -18,6 +18,7 @@ package proxy import ( "context" + "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -167,7 +168,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo // cancel work load which assign to the target node lb.balancer.CancelWorkload(targetNode, workload.nq) - return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error()) + return errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel) } err = workload.exec(ctx, targetNode, client, workload.channel) @@ -178,11 +179,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo excludeNodes.Insert(targetNode) lb.balancer.CancelWorkload(targetNode, workload.nq) - if err == context.Canceled || err == context.DeadlineExceeded { - return merr.WrapErrShardDelegatorSQTimeout(workload.channel, err.Error()) - } - - return merr.WrapErrShardDelegatorSQFailed(workload.channel, err.Error()) + return errors.Wrapf(err, "failed to search/query delegator %d for 
channel %s", targetNode, workload.channel) } lb.balancer.CancelWorkload(targetNode, workload.nq) diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 5ddb438461..58b156dff9 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -360,7 +360,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { }, retryTimes: 2, }) - s.ErrorIs(err, merr.ErrShardDelegatorSQTimeout) + s.True(merr.IsCanceledOrTimeout(err)) } func (s *LBPolicySuite) TestExecute() { diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 7ba88e316c..6e387b7e87 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -579,8 +579,9 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection if err != nil { return nil, err } - if coll.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, common.NewStatusError(coll.GetStatus().GetErrorCode(), coll.GetStatus().GetReason()) + err = merr.Error(coll.GetStatus()) + if err != nil { + return nil, err } resp := &milvuspb.DescribeCollectionResponse{ Status: coll.Status, diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 847de26d5a..d37c5b38be 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -155,12 +156,9 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i }, nil } - err := fmt.Errorf("can't find collection: " + in.CollectionName) + err := merr.WrapErrCollectionNotFound(in.CollectionName) return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNotExists, - Reason: "describe collection failed: " + err.Error(), - }, + Status: merr.Status(err), Schema: nil, }, nil } diff --git a/internal/proxy/metrics_info.go b/internal/proxy/metrics_info.go index c22178f6cf..f4e7021559 100644 --- a/internal/proxy/metrics_info.go +++ b/internal/proxy/metrics_info.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" @@ -107,9 +108,7 @@ func getProxyMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest, n } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.ProxyRole, paramtable.GetNodeID()), }, nil @@ -429,10 +428,7 @@ func getSystemInfoMetrics( } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.ProxyRole, paramtable.GetNodeID()), }, nil diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 1270cd2392..de243d236d 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -503,9 +503,7 @@ func (dct 
*describeCollectionTask) PreExecute(ctx context.Context) error { func (dct *describeCollectionTask) Execute(ctx context.Context) error { var err error dct.result = &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), Schema: &schemapb.CollectionSchema{ Name: "", Description: "", @@ -526,6 +524,13 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error { if result.Status.ErrorCode != commonpb.ErrorCode_Success { dct.result.Status = result.Status + + // compatibility with PyMilvus existing implementation + err := merr.Error(dct.result.GetStatus()) + if errors.Is(err, merr.ErrCollectionNotFound) { + dct.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + dct.result.Status.Reason = "can't find collection " + dct.result.Status.Reason + } } else { dct.result.Schema.Name = result.Schema.Name dct.result.Schema.Description = result.Schema.Description @@ -1306,10 +1311,7 @@ func (ft *flushTask) Execute(ctx context.Context) error { coll2SealTimes[collName] = resp.GetTimeOfSeal() } ft.result = &milvuspb.FlushResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), DbName: ft.GetDbName(), CollSegIDs: coll2Segments, FlushCollSegIDs: flushColl2Segments, @@ -2218,7 +2220,7 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error { zap.Error(err)) // if collection has been dropped, skip it - if common.IsCollectionNotExistError(err) { + if errors.Is(err, merr.ErrCollectionNotFound) { continue } return nil, err diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 67c188493f..ba804a79da 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -20,6 +20,7 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -154,9 +155,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { dt.deleteMsg.Base.SourceID = paramtable.GetNodeID() dt.result = &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IDs: &schemapb.IDs{ IdField: nil, }, diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 4dd328ce03..3512c79c83 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -915,10 +915,7 @@ func (gist *getIndexStateTask) Execute(ctx context.Context) error { } gist.result = &milvuspb.GetIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), State: state.GetState(), FailReason: state.GetFailReason(), } diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 455518d382..b3ce3ba6b9 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -97,9 +98,7 @@ func (it *insertTask) PreExecute(ctx 
context.Context) error { defer sp.End() it.result = &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IDs: &schemapb.IDs{ IdField: nil, }, diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 702e65d3a0..e1c642e8a4 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -24,7 +24,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -415,7 +414,7 @@ func (t *queryTask) Execute(ctx context.Context) error { }) if err != nil { log.Warn("fail to execute query", zap.Error(err)) - return merr.WrapErrShardDelegatorQueryFailed(err.Error()) + return errors.Wrap(err, "failed to query") } log.Debug("Query Execute done.") diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index d52e644376..1265f0fd81 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -420,7 +420,7 @@ func (t *searchTask) Execute(ctx context.Context) error { }) if err != nil { log.Warn("search execute failed", zap.Error(err)) - return merr.WrapErrShardDelegatorSearchFailed(err.Error()) + return errors.Wrap(err, "failed to search") } log.Debug("Search Execute done.", @@ -770,9 +770,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb zap.String("metricType", metricType)) ret := &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), Results: &schemapb.SearchResultData{ NumQueries: nq, TopK: topk, diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go index d536bf4365..efe7a42a5a 100644 --- a/internal/proxy/task_statistic.go +++ b/internal/proxy/task_statistic.go @@ -216,7 +216,7 @@ func (g *getStatisticsTask) PostExecute(ctx context.Context) error { return err } g.result = &milvuspb.GetStatisticsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), Stats: result, } @@ -248,7 +248,7 @@ func (g *getStatisticsTask) getStatisticsFromDataCoord(ctx context.Context) erro g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]() } g.resultBuf.Insert(&internalpb.GetStatisticsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), Stats: result.Stats, }) return nil @@ -268,7 +268,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro }) if err != nil { - return merr.WrapErrShardDelegatorStatisticFailed(err.Error()) + return errors.Wrap(err, "failed to get statistics") } return nil @@ -466,7 +466,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return errors.New(result.Status.Reason) // } // g.toReduceResults = append(g.toReduceResults, &internalpb.GetStatisticsResponse{ -// Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, +// Status: merr.Status(nil), // Stats: result.Stats, // }) // log.Debug("get partition statistics from DataCoord execute done", zap.Int64("msgID", g.ID())) @@ -481,7 +481,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return err // } // g.result = &milvuspb.GetPartitionStatisticsResponse{ -// Status:
&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, +// Status: merr.Status(nil), // Stats: g.innerResult, // } // return nil @@ -538,7 +538,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return errors.New(result.Status.Reason) // } // g.toReduceResults = append(g.toReduceResults, &internalpb.GetStatisticsResponse{ -// Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, +// Status: merr.Status(nil), // Stats: result.Stats, // }) // } else { // some partitions have been loaded, get some partition statistics from datacoord @@ -561,7 +561,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return errors.New(result.Status.Reason) // } // g.toReduceResults = append(g.toReduceResults, &internalpb.GetStatisticsResponse{ -// Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, +// Status: merr.Status(nil), // Stats: result.Stats, // }) // } @@ -577,7 +577,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return err // } // g.result = &milvuspb.GetCollectionStatisticsResponse{ -// Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, +// Status: merr.Status(nil), // Stats: g.innerResult, // } // return nil @@ -660,11 +660,8 @@ func (g *getCollectionStatisticsTask) Execute(ctx context.Context) error { return errors.New(result.Status.Reason) } g.result = &milvuspb.GetCollectionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Stats: result.Stats, + Status: merr.Status(nil), + Stats: result.Stats, } return nil } @@ -753,11 +750,8 @@ func (g *getPartitionStatisticsTask) Execute(ctx context.Context) error { return errors.New(result.Status.Reason) } g.result = &milvuspb.GetPartitionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Stats: result.Stats, + Status: merr.Status(nil), + Stats: result.Stats, } return nil } diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index aa047cc726..355b8ef79f 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -1273,7 +1273,7 @@ func TestDropPartitionTask(t *testing.T) { PartitionIDs: []int64{}, }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), }, nil) mockCache := NewMockCache(t) @@ -2554,7 +2554,7 @@ func Test_dropCollectionTask_Execute(t *testing.T) { case "c1": return errors.New("error mock DropCollection") case "c2": - return common.NewStatusError(commonpb.ErrorCode_CollectionNotExists, "collection not exist") + return merr.WrapErrCollectionNotFound("mock") default: return nil } @@ -2594,7 +2594,7 @@ func Test_loadCollectionTask_Execute(t *testing.T) { PartitionIDs: []int64{}, }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), }, nil) dbName := funcutil.GenRandomStr() @@ -2702,7 +2702,7 @@ func Test_loadPartitionTask_Execute(t *testing.T) { PartitionIDs: []int64{}, }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), }, nil) dbName := funcutil.GenRandomStr() @@ -2802,7 +2802,7 @@ func 
TestCreateResourceGroupTask(t *testing.T) { rc.Start() defer rc.Stop() qc := getQueryCoord() - qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) qc.Start() defer qc.Stop() ctx := context.Background() @@ -2842,7 +2842,7 @@ func TestDropResourceGroupTask(t *testing.T) { rc.Start() defer rc.Stop() qc := getQueryCoord() - qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) qc.Start() defer qc.Stop() ctx := context.Background() @@ -2882,7 +2882,7 @@ func TestTransferNodeTask(t *testing.T) { rc.Start() defer rc.Stop() qc := getQueryCoord() - qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) qc.Start() defer qc.Stop() ctx := context.Background() @@ -2922,7 +2922,7 @@ func TestTransferNodeTask(t *testing.T) { func TestTransferReplicaTask(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoord() - qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) qc.Start() defer qc.Stop() ctx := context.Background() @@ -2966,7 +2966,7 @@ func TestListResourceGroupsTask(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoord() qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), ResourceGroups: []string{meta.DefaultResourceGroupName, "rg"}, }, nil) qc.Start() @@ -3009,7 +3009,7 @@ func TestDescribeResourceGroupTask(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoord() qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), ResourceGroup: &querypb.ResourceGroupInfo{ Name: "rg", Capacity: 2, @@ -3105,7 +3105,7 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) { qc.ExpectedCalls = nil qc.EXPECT().Stop().Return(nil) qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), ResourceGroup: &querypb.ResourceGroupInfo{ Name: "rg", Capacity: 2, diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 26ac27afe0..64aec5660b 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -275,9 +276,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { log := log.Ctx(ctx).With(zap.String("collectionName", collectionName)) it.result = 
&milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), IDs: &schemapb.IDs{ IdField: nil, }, diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 07906b7577..d95db4cb48 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -742,26 +742,27 @@ func validateName(entity string, nameType string) error { entity = strings.TrimSpace(entity) if entity == "" { - return fmt.Errorf("%s should not be empty", nameType) + return merr.WrapErrParameterInvalid("not empty", entity, nameType+" should not be empty") } - invalidMsg := fmt.Sprintf("invalid %s: %s. ", nameType, entity) if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() { - msg := invalidMsg + fmt.Sprintf("the length of %s must be less than ", nameType) + Params.ProxyCfg.MaxNameLength.GetValue() + " characters." - return errors.New(msg) + return merr.WrapErrParameterInvalidRange(0, + Params.ProxyCfg.MaxNameLength.GetAsInt(), + len(entity), + fmt.Sprintf("the length of %s must not be greater than the limit", nameType)) } firstChar := entity[0] if firstChar != '_' && !isAlpha(firstChar) { - msg := invalidMsg + fmt.Sprintf("the first character of %s must be an underscore or letter.", nameType) - return errors.New(msg) + return merr.WrapErrParameterInvalid('_', + firstChar, + fmt.Sprintf("the first character of %s must be an underscore or letter", nameType)) } for i := 1; i < len(entity); i++ { c := entity[i] if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) { - msg := invalidMsg + fmt.Sprintf("%s can only contain numbers, letters, dollars and underscores.", nameType) - return errors.New(msg) + return merr.WrapErrParameterInvalidMsg("%s can only contain numbers, letters, dollars and underscores, found %c at %d", nameType, c, i) } } return nil @@ -1035,7 +1036,7 @@ func fillFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgstr for _, data := range insertMsg.FieldsData { fieldName := data.GetFieldName() if dataNameSet.Contain(fieldName) { - return merr.WrapErrParameterDuplicateFieldData(fieldName, "The FieldDatas parameter being passed contains duplicate data for a field.") + return merr.WrapErrParameterInvalidMsg("The FieldDatas parameter being passed contains duplicate data for field %s", fieldName) } dataNameSet.Insert(fieldName) } diff --git a/internal/querycoordv2/meta/coordinator_broker.go b/internal/querycoordv2/meta/coordinator_broker.go index 74141ca3b0..ea2324eb11 100644 --- a/internal/querycoordv2/meta/coordinator_broker.go +++ b/internal/querycoordv2/meta/coordinator_broker.go @@ -31,7 +31,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -79,14 +78,9 @@ func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collec return nil, err } - statusErr := common.NewStatusError(resp.Status.ErrorCode, resp.Status.Reason) - if common.IsCollectionNotExistError(statusErr) { - return nil, merr.WrapErrCollectionNotFound(collectionID) - } - - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - err = errors.New(resp.GetStatus().GetReason()) - log.Warn("failed to get collection schema", zap.Int64("collectionID", collectionID), zap.Error(err)) + err = merr.Error(resp.GetStatus()) + if err != nil { +
log.Warn("failed to get collection schema", zap.Error(err)) return nil, err } return resp.GetSchema(), nil @@ -108,14 +102,9 @@ func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID return nil, err } - statusErr := common.NewStatusError(resp.Status.ErrorCode, resp.Status.Reason) - if common.IsCollectionNotExistError(statusErr) { - return nil, merr.WrapErrCollectionNotFound(collectionID) - } - - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - err = errors.New(resp.GetStatus().GetReason()) - log.Warn("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err)) + err = merr.Error(resp.GetStatus()) + if err != nil { + log.Warn("failed to get partitions", zap.Error(err)) return nil, err } diff --git a/internal/querycoordv2/meta/coordinator_broker_test.go b/internal/querycoordv2/meta/coordinator_broker_test.go index 66994965a6..3a42b72942 100644 --- a/internal/querycoordv2/meta/coordinator_broker_test.go +++ b/internal/querycoordv2/meta/coordinator_broker_test.go @@ -138,9 +138,7 @@ func TestCoordinatorBroker_GetPartitions(t *testing.T) { t.Run("collection not exist", func(t *testing.T) { rc := mocks.NewRootCoord(t) rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNotExists, - }, + Status: merr.Status(merr.WrapErrCollectionNotFound("mock")), }, nil) ctx := context.Background() diff --git a/internal/querycoordv2/mocks/querynode.go b/internal/querycoordv2/mocks/querynode.go index a326523816..013b267db1 100644 --- a/internal/querycoordv2/mocks/querynode.go +++ b/internal/querycoordv2/mocks/querynode.go @@ -28,12 +28,12 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -82,9 +82,7 @@ func (node *MockQueryNode) Start() error { err = node.server.Serve(lis) }() - successStatus := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } + successStatus := merr.Status(nil) node.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{ Status: successStatus, NodeID: node.ID, diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index 697b712d25..d8609c18eb 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -52,6 +52,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -510,30 +511,22 @@ func (s *Server) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta } return &milvuspb.ComponentStates{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - State: serviceComponentInfo, + Status: merr.Status(nil), + State: serviceComponentInfo, //SubcomponentStates: subComponentInfos, }, nil } func (s *Server) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), }, nil } func (s *Server) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: Params.CommonCfg.QueryCoordTimeTick.GetValue(), + Status: merr.Status(nil), + Value: Params.CommonCfg.QueryCoordTimeTick.GetValue(), }, nil } diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 7044b9520b..c81dce13e1 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -727,10 +727,7 @@ func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Configuations: configList, }, nil } diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index 2a6edac867..7c62ee9b9c 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -251,7 +251,7 @@ func (sd *shardDelegator) applyDelete(ctx context.Context, nodeID int64, worker log.Debug("delegator plan to applyDelete via worker") err := retry.Do(ctx, func() error { if sd.Stopped() { - return retry.Unrecoverable(merr.WrapErrChannelUnsubscribing(sd.vchannelName)) + return retry.Unrecoverable(merr.WrapErrChannelNotAvailable(sd.vchannelName, "channel is unsubscribing")) } err := worker.Delete(ctx, &querypb.DeleteRequest{ diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index b943d01453..0dab6bbcfa 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -440,7 +440,7 @@ func 
segmentStatsResponse(segStats []segments.SegmentStats) *internalpb.GetStati resultMap["row_count"] = strconv.FormatInt(totalRowNum, 10) ret := &internalpb.GetStatisticsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), Stats: funcutil.Map2KeyValuePair(resultMap), } return ret @@ -479,7 +479,7 @@ func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*inte } ret := &internalpb.GetStatisticsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), Stats: funcutil.Map2KeyValuePair(stringMap), } return ret, nil diff --git a/internal/querynodev2/metrics_info.go b/internal/querynodev2/metrics_info.go index 7f7d192a81..8f7b6f2a54 100644 --- a/internal/querynodev2/metrics_info.go +++ b/internal/querynodev2/metrics_info.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" @@ -210,10 +211,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/querynodev2/pipeline/manager.go b/internal/querynodev2/pipeline/manager.go index f864559c6d..1d486e9fc3 100644 --- a/internal/querynodev2/pipeline/manager.go +++ b/internal/querynodev2/pipeline/manager.go @@ -80,7 +80,7 @@ func (m *manager) Add(collectionID UniqueID, channel string) (Pipeline, error) { //get shard delegator for add growing in pipeline delegator, ok := m.delegators.Get(channel) if !ok { - return nil, merr.WrapErrShardDelegatorNotFound(channel) + return nil, merr.WrapErrChannelNotFound(channel, "delegator not found") } newPipeLine, err := NewPipeLine(collectionID, channel, m.dataManager, m.tSafeManager, m.dispatcher, delegator) diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 863a108969..9703ab0e03 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -25,13 +25,13 @@ import ( "github.com/samber/lo" "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -222,9 +222,7 @@ func DecodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb func EncodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int64, topk int64, metricType string) (searchResults *internalpb.SearchResults, err error) { searchResults = &internalpb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), NumQueries: nq, TopK: topk, MetricType: metricType, diff --git 
a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 62162ce182..fa5bf3e472 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -58,9 +58,7 @@ import ( // GetComponentStates returns information about whether the node is healthy func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { stats := &milvuspb.ComponentStates{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), } code := node.lifetime.GetState() @@ -82,11 +80,8 @@ // TimeTickChannel contains many time tick messages, which will be sent by query nodes func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: paramtable.Get().CommonCfg.QueryCoordTimeTick.GetValue(), + Status: merr.Status(nil), + Value: paramtable.Get().CommonCfg.QueryCoordTimeTick.GetValue(), }, nil } @@ -94,10 +89,7 @@ // Statistics channel contains statistics infos of query nodes, such as segment infos, memory infos func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), }, nil } @@ -132,9 +124,7 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis }, nil } failRet := &internalpb.GetStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), } var toReduceResults []*internalpb.GetStatisticsResponse @@ -256,8 +246,8 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm // to avoid concurrent watch/unwatch if node.unsubscribingChannels.Contain(channel.GetChannelName()) { - err := merr.WrapErrChannelUnsubscribing(channel.GetChannelName()) - log.Warn("abort watch unsubscribing channel", zap.Error(err)) + err := merr.WrapErrChannelReduplicate(channel.GetChannelName(), "the same channel is still being unsubscribed") + log.Warn("failed to watch channel, the same channel is still being unsubscribed", zap.Error(err)) return merr.Status(err), nil } @@ -680,10 +670,8 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen } return &querypb.GetSegmentInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Infos: segmentInfos, + Status: merr.Status(nil), + Infos: segmentInfos, }, nil } @@ -798,9 +786,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( } failRet := &internalpb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Status(nil), } collection := node.manager.Collection.Get(req.GetReq().GetCollectionID()) if collection == nil { @@ -1075,10 +1061,7 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S } return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Configuations: configList, }, nil } @@ -1239,7 +1222,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get }) return &querypb.GetDataDistributionResponse{ -
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), NodeID: paramtable.GetNodeID(), Segments: segmentVersionInfos, Channels: channelVersionInfos, @@ -1346,10 +1329,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi }, true) } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Status(nil), nil } // Delete is used to forward delete message between delegator and workers. @@ -1406,8 +1386,5 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) ( } } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Status(nil), nil } diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 5763cc6aba..372d01076f 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -377,7 +377,7 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { suite.node.unsubscribingChannels.Insert(suite.vchannel) status, err := suite.node.WatchDmChannels(ctx, req) suite.NoError(err) - suite.Equal(status.GetReason(), merr.WrapErrChannelUnsubscribing(suite.vchannel).Error()) + suite.ErrorIs(merr.Error(status), merr.ErrChannelReduplicate) suite.node.unsubscribingChannels.Remove(suite.vchannel) // init msgstream failed diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go index cc961880f1..5b1f658195 100644 --- a/internal/querynodev2/tasks/query_task.go +++ b/internal/querynodev2/tasks/query_task.go @@ -11,6 +11,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/collector" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -130,7 +131,7 @@ func (t *QueryTask) Execute() error { Base: &commonpb.MsgBase{ SourceID: paramtable.GetNodeID(), }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), Ids: reducedResult.Ids, FieldsData: reducedResult.FieldsData, CostAggregation: &internalpb.CostAggregation{ diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go index 8ace754e55..96a37fe2de 100644 --- a/internal/querynodev2/tasks/task.go +++ b/internal/querynodev2/tasks/task.go @@ -20,6 +20,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -166,7 +167,7 @@ func (t *SearchTask) Execute() error { Base: &commonpb.MsgBase{ SourceID: paramtable.GetNodeID(), }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(nil), MetricType: req.GetReq().GetMetricType(), NumQueries: t.originNqs[i], TopK: t.originTopks[i], diff --git a/internal/rootcoord/drop_collection_task.go b/internal/rootcoord/drop_collection_task.go index 9770ae785d..5440186d38 100644 --- a/internal/rootcoord/drop_collection_task.go +++ b/internal/rootcoord/drop_collection_task.go @@ -22,11 +22,12 @@ import ( "go.uber.org/zap" + "github.com/cockroachdb/errors" 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -55,7 +56,7 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { // dropping collection with `ts1` but a collection exists in catalog with newer ts which is bigger than `ts1`. // fortunately, if ddls are promised to execute in sequence, then everything is OK. The `ts1` will always be latest. collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp) - if common.IsCollectionNotExistError(err) { + if errors.Is(err, merr.ErrCollectionNotFound) { // make dropping collection idempotent. log.Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName())) return nil diff --git a/internal/rootcoord/drop_collection_task_test.go b/internal/rootcoord/drop_collection_task_test.go index 1e5162c4c2..543c59b58e 100644 --- a/internal/rootcoord/drop_collection_task_test.go +++ b/internal/rootcoord/drop_collection_task_test.go @@ -29,8 +29,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" ) func Test_dropCollectionTask_Prepare(t *testing.T) { @@ -98,7 +98,7 @@ func Test_dropCollectionTask_Execute(t *testing.T) { mock.Anything, ).Return(nil, func(ctx context.Context, dbName string, name string, ts Timestamp) error { if collectionName == name { - return common.NewCollectionNotExistError("collection not exist") + return merr.WrapErrCollectionNotFound(collectionName) } return errors.New("error mock GetCollectionByName") }) diff --git a/internal/rootcoord/import_manager.go b/internal/rootcoord/import_manager.go index e80681a071..c48d4cd25a 100644 --- a/internal/rootcoord/import_manager.go +++ b/internal/rootcoord/import_manager.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/importutil" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -436,10 +437,8 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque } resp := &milvuspb.ImportResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Tasks: make([]int64, 0), + Status: merr.Status(nil), + Tasks: make([]int64, 0), } log.Info("receive import job", @@ -735,9 +734,7 @@ func (m *importManager) setCollectionPartitionName(dbName string, colID, partID } func (m *importManager) copyTaskInfo(input *datapb.ImportTaskInfo, output *milvuspb.GetImportStateResponse) { - output.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } + output.Status = merr.Status(nil) output.Id = input.GetId() output.CollectionId = input.GetCollectionId() diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index c4d7e06835..aa90d14a0c 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus/pkg/util" 
"github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -493,13 +494,13 @@ func filterUnavailable(coll *model.Collection) *model.Collection { func (mt *MetaTable) getLatestCollectionByIDInternal(ctx context.Context, collectionID UniqueID, allowAvailable bool) (*model.Collection, error) { coll, ok := mt.collID2Meta[collectionID] if !ok || coll == nil { - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection: %d", collectionID)) + return nil, merr.WrapErrCollectionNotFound(collectionID) } if allowAvailable { return coll.Clone(), nil } if !coll.Available() { - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection: %d", collectionID)) + return nil, merr.WrapErrCollectionNotFound(collectionID) } return filterUnavailable(coll), nil } @@ -527,7 +528,7 @@ func (mt *MetaTable) getCollectionByIDInternal(ctx context.Context, dbName strin if coll == nil { // use coll.Name to match error message of regression. TODO: remove this after error code is ready. - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection: %d", collectionID)) + return nil, merr.WrapErrCollectionNotFound(collectionID) } if allowUnavailable { @@ -536,7 +537,7 @@ func (mt *MetaTable) getCollectionByIDInternal(ctx context.Context, dbName strin if !coll.Available() { // use coll.Name to match error message of regression. TODO: remove this after error code is ready. - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection %s:%s", dbName, coll.Name)) + return nil, merr.WrapErrCollectionNotFound(dbName, coll.Name) } return filterUnavailable(coll), nil @@ -566,7 +567,7 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str } if isMaxTs(ts) { - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection %s:%s", dbName, collectionName)) + return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName) } db, err := mt.getDatabaseByNameInternal(ctx, dbName, typeutil.MaxTimestamp) @@ -582,7 +583,7 @@ func (mt *MetaTable) getCollectionByNameInternal(ctx context.Context, dbName str } if coll == nil || !coll.Available() { - return nil, common.NewCollectionNotExistError(fmt.Sprintf("can't find collection %s:%s", dbName, collectionName)) + return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName) } return filterUnavailable(coll), nil } @@ -742,7 +743,7 @@ func (mt *MetaTable) RenameCollection(ctx context.Context, dbName string, oldNam log.Warn("check new collection fail") return fmt.Errorf("duplicated new collection name %s:%s with other collection name or alias", newDBName, newName) } - if err != nil && !common.IsCollectionNotExistErrorV2(err) { + if err != nil && !errors.Is(err, merr.ErrCollectionNotFound) { log.Warn("check new collection name fail") return err } diff --git a/internal/rootcoord/meta_table_test.go b/internal/rootcoord/meta_table_test.go index b128d998bc..aaba6dad47 100644 --- a/internal/rootcoord/meta_table_test.go +++ b/internal/rootcoord/meta_table_test.go @@ -34,8 +34,8 @@ import ( pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/internalpb" mocktso "github.com/milvus-io/milvus/internal/tso/mocks" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util" + 
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -463,7 +463,7 @@ func TestMetaTable_getCollectionByIDInternal(t *testing.T) { ctx := context.Background() _, err := meta.getCollectionByIDInternal(ctx, util.DefaultDBName, 100, 101, false) assert.Error(t, err) - assert.True(t, common.IsCollectionNotExistError(err)) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) coll, err := meta.getCollectionByIDInternal(ctx, util.DefaultDBName, 100, 101, true) assert.NoError(t, err) assert.False(t, coll.Available()) @@ -602,7 +602,7 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { ctx := context.Background() _, err := meta.GetCollectionByName(ctx, util.DefaultDBName, "name", 101) assert.Error(t, err) - assert.True(t, common.IsCollectionNotExistError(err)) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) }) t.Run("normal case, filter unavailable partitions", func(t *testing.T) { @@ -642,7 +642,7 @@ func TestMetaTable_GetCollectionByName(t *testing.T) { meta := &MetaTable{names: newNameDb(), aliases: newNameDb()} _, err := meta.GetCollectionByName(ctx, "", "not_exist", typeutil.MaxTimestamp) assert.Error(t, err) - assert.True(t, common.IsCollectionNotExistError(err)) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) }) } @@ -715,7 +715,7 @@ func TestMetaTable_getLatestCollectionByIDInternal(t *testing.T) { mt := &MetaTable{collID2Meta: nil} _, err := mt.getLatestCollectionByIDInternal(ctx, 100, false) assert.Error(t, err) - assert.True(t, common.IsCollectionNotExistError(err)) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) }) t.Run("nil case", func(t *testing.T) { @@ -725,7 +725,7 @@ func TestMetaTable_getLatestCollectionByIDInternal(t *testing.T) { }} _, err := mt.getLatestCollectionByIDInternal(ctx, 100, false) assert.Error(t, err) - assert.True(t, common.IsCollectionNotExistError(err)) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) }) t.Run("unavailable", func(t *testing.T) { @@ -735,7 +735,7 @@ func TestMetaTable_getLatestCollectionByIDInternal(t *testing.T) { }} _, err := mt.getLatestCollectionByIDInternal(ctx, 100, false) assert.Error(t, err) - assert.True(t, common.IsCollectionNotExistError(err)) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) coll, err := mt.getLatestCollectionByIDInternal(ctx, 100, true) assert.NoError(t, err) assert.False(t, coll.Available()) @@ -1258,7 +1258,7 @@ func TestMetaTable_RenameCollection(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, common.NewCollectionNotExistError("error")) + ).Return(nil, merr.WrapErrCollectionNotFound("error")) meta := &MetaTable{ dbName2Meta: map[string]*model.Database{ @@ -1286,7 +1286,7 @@ func TestMetaTable_RenameCollection(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, common.NewCollectionNotExistError("error")) + ).Return(nil, merr.WrapErrCollectionNotFound("error")) meta := &MetaTable{ dbName2Meta: map[string]*model.Database{ util.DefaultDBName: model.NewDefaultDatabase(), @@ -1323,7 +1323,7 @@ func TestMetaTable_RenameCollection(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, - ).Return(nil, common.NewCollectionNotExistError("error")) + ).Return(nil, merr.WrapErrCollectionNotFound("error")) meta := &MetaTable{ dbName2Meta: map[string]*model.Database{ util.DefaultDBName: model.NewDefaultDatabase(), diff --git a/internal/rootcoord/metrics_info.go b/internal/rootcoord/metrics_info.go index c47b1bbba6..d32e887c57 100644 --- 
a/internal/rootcoord/metrics_info.go +++ b/internal/rootcoord/metrics_info.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -78,10 +79,7 @@ func (c *Core) getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetric } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.RootCoordRole, c.session.ServerID), }, nil diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 53fabac3de..afc7e4af7d 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1111,19 +1111,15 @@ func (c *Core) describeCollectionImpl(ctx context.Context, in *milvuspb.Describe log.Info("failed to enqueue request to describe collection", zap.Error(err)) metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.FailLabel).Inc() return &milvuspb.DescribeCollectionResponse{ - // TODO: use commonpb.ErrorCode_CollectionNotExists. SDK use commonpb.ErrorCode_UnexpectedError now. Status: merr.Status(err), - // Status: common.StatusFromError(err), }, nil } if err := t.WaitToFinish(); err != nil { - log.Info("failed to describe collection", zap.Error(err)) + log.Warn("failed to describe collection", zap.Error(err)) metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.FailLabel).Inc() return &milvuspb.DescribeCollectionResponse{ - // TODO: use commonpb.ErrorCode_CollectionNotExists. SDK use commonpb.ErrorCode_UnexpectedError now. 
Status: merr.Status(err), - // Status: common.StatusFromError(err), }, nil } @@ -1581,10 +1577,7 @@ func (c *Core) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfi } return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Status(nil), Configuations: configList, }, nil } @@ -1988,9 +1981,7 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( } } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Status(nil), nil } // ExpireCredCache will call invalidate credential cache diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go index 13a91baa71..b1c3987669 100644 --- a/internal/rootcoord/util.go +++ b/internal/rootcoord/util.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -114,10 +115,7 @@ func failStatus(code commonpb.ErrorCode, reason string) *commonpb.Status { } func succStatus() *commonpb.Status { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - } + return merr.Status(nil) } type TimeTravelRequest interface { diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 5f5d7353e4..7dc8a12685 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -457,5 +457,5 @@ func IsCrossClusterRoutingErr(err error) bool { func IsServerIDMismatchErr(err error) bool { // GRPC utilizes `status.Status` to encapsulate errors, // hence it is not viable to employ the `errors.Is` for assessment. 
- return strings.Contains(err.Error(), merr.ErrServerIDMismatch.Error()) + return strings.Contains(err.Error(), merr.ErrNodeNotMatch.Error()) } diff --git a/internal/util/mock/grpc_rootcoord_client.go b/internal/util/mock/grpc_rootcoord_client.go index 7f6d3b2731..b5bad586a5 100644 --- a/internal/util/mock/grpc_rootcoord_client.go +++ b/internal/util/mock/grpc_rootcoord_client.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/merr" ) var _ rootcoordpb.RootCoordClient = &GrpcRootCoordClient{} @@ -48,7 +49,7 @@ func (m *GrpcRootCoordClient) ListDatabases(ctx context.Context, in *milvuspb.Li } func (m *GrpcRootCoordClient) RenameCollection(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil + return merr.Status(nil), nil } func (m *GrpcRootCoordClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { diff --git a/pkg/common/error.go b/pkg/common/error.go index 892bd5490a..81d815160e 100644 --- a/pkg/common/error.go +++ b/pkg/common/error.go @@ -18,10 +18,8 @@ package common import ( "fmt" - "strings" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) var ( @@ -76,68 +74,3 @@ type KeyNotExistError struct { func (k *KeyNotExistError) Error() string { return fmt.Sprintf("there is no value on key = %s", k.key) } - -type statusError struct { - commonpb.Status -} - -func (e *statusError) Error() string { - return fmt.Sprintf("code: %s, reason: %s", e.GetErrorCode().String(), e.GetReason()) -} - -func NewStatusError(code commonpb.ErrorCode, reason string) *statusError { - return &statusError{Status: commonpb.Status{ErrorCode: code, Reason: reason}} -} - -func IsStatusError(e error) bool { - _, ok := e.(*statusError) - return ok -} - -var ( - // static variable, save temporary memory. - collectionNotExistCodes = []commonpb.ErrorCode{ - commonpb.ErrorCode_UnexpectedError, // TODO: remove this after SDK remove this dependency. 
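
The `IsServerIDMismatchErr` change above keeps the substring comparison even after the sentinel is renamed to `merr.ErrNodeNotMatch`, because gRPC serializes errors to strings and the wrap chain that `errors.Is` walks does not survive the wire. A small stand-alone illustration of that limitation, using a stand-in sentinel and only the standard library:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// errNodeNotMatch stands in for merr.ErrNodeNotMatch.
var errNodeNotMatch = errors.New("node not match")

func main() {
	// In-process: the sentinel is wrapped, and errors.Is can unwrap to it.
	local := fmt.Errorf("expectedNodeID=1, actualNodeID=2: %w", errNodeNotMatch)
	fmt.Println(errors.Is(local, errNodeNotMatch)) // true

	// Across gRPC the error is rebuilt from its string form, so the wrap
	// chain is gone; only the text remains.
	remote := errors.New(local.Error())
	fmt.Println(errors.Is(remote, errNodeNotMatch))                         // false
	fmt.Println(strings.Contains(remote.Error(), errNodeNotMatch.Error())) // true
}
```
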
- commonpb.ErrorCode_CollectionNotExists, - } -) - -func NewCollectionNotExistError(msg string) *statusError { - return NewStatusError(commonpb.ErrorCode_CollectionNotExists, msg) -} - -func IsCollectionNotExistError(e error) bool { - statusError, ok := e.(*statusError) - if !ok { - return false - } - // cycle import: common -> funcutil -> types -> sessionutil -> common - // return funcutil.SliceContain(collectionNotExistCodes, statusError.GetErrorCode()) - if statusError.Status.ErrorCode == commonpb.ErrorCode_CollectionNotExists { - return true - } - - if (statusError.Status.ErrorCode == commonpb.ErrorCode_UnexpectedError) && strings.Contains(statusError.Status.Reason, "can't find collection") { - return true - } - return false -} - -func IsCollectionNotExistErrorV2(e error) bool { - statusError, ok := e.(*statusError) - if !ok { - return false - } - return statusError.GetErrorCode() == commonpb.ErrorCode_CollectionNotExists -} - -func StatusFromError(e error) *commonpb.Status { - if e == nil { - return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - } - statusError, ok := e.(*statusError) - if !ok { - return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: e.Error()} - } - return &commonpb.Status{ErrorCode: statusError.GetErrorCode(), Reason: statusError.GetReason()} -} diff --git a/pkg/common/error_test.go b/pkg/common/error_test.go index d9b7e239c6..d35b3fc64e 100644 --- a/pkg/common/error_test.go +++ b/pkg/common/error_test.go @@ -17,11 +17,9 @@ package common import ( - "strings" "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" ) @@ -37,45 +35,3 @@ func TestNotExistError(t *testing.T) { assert.Equal(t, false, IsKeyNotExistError(err)) assert.Equal(t, true, IsKeyNotExistError(NewKeyNotExistError("foo"))) } - -func TestStatusError_Error(t *testing.T) { - err := NewCollectionNotExistError("collection not exist") - assert.True(t, IsStatusError(err)) - assert.True(t, strings.Contains(err.Error(), "collection not exist")) -} - -func TestIsStatusError(t *testing.T) { - err := NewCollectionNotExistError("collection not exist") - assert.True(t, IsStatusError(err)) - assert.False(t, IsStatusError(errors.New("not status error"))) - assert.False(t, IsStatusError(nil)) -} - -func Test_IsCollectionNotExistError(t *testing.T) { - assert.False(t, IsCollectionNotExistError(nil)) - assert.False(t, IsCollectionNotExistError(errors.New("not status error"))) - for _, code := range collectionNotExistCodes { - err := NewStatusError(code, "can't find collection") - assert.True(t, IsCollectionNotExistError(err)) - } - assert.True(t, IsCollectionNotExistError(NewCollectionNotExistError("collection not exist"))) - assert.False(t, IsCollectionNotExistError(NewStatusError(commonpb.ErrorCode_BuildIndexError, ""))) -} - -func TestIsCollectionNotExistErrorV2(t *testing.T) { - assert.False(t, IsCollectionNotExistErrorV2(nil)) - assert.False(t, IsCollectionNotExistErrorV2(errors.New("not status error"))) - assert.True(t, IsCollectionNotExistErrorV2(NewCollectionNotExistError("collection not exist"))) - assert.False(t, IsCollectionNotExistErrorV2(NewStatusError(commonpb.ErrorCode_BuildIndexError, ""))) -} - -func TestStatusFromError(t *testing.T) { - var status *commonpb.Status - status = StatusFromError(nil) - assert.Equal(t, commonpb.ErrorCode_Success, status.GetErrorCode()) - status = StatusFromError(errors.New("not status error")) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, 
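
The deleted `statusError` helpers above had to recognize a missing collection two ways: by the dedicated error code, and by sniffing "can't find collection" out of an `UnexpectedError` reason. The sentinel-error approach that replaces them reduces the same check to a single `errors.Is`. A sketch of the contrast, with simplified types and illustrative code values rather than the real packages:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/cockroachdb/errors"
)

// Old style: classify by error code plus reason sniffing.
type statusError struct {
	code   int32 // 1 = UnexpectedError, 4 = CollectionNotExists (illustrative values)
	reason string
}

func (e *statusError) Error() string { return e.reason }

func isCollectionNotExistOld(err error) bool {
	se, ok := err.(*statusError)
	if !ok {
		return false
	}
	return se.code == 4 ||
		(se.code == 1 && strings.Contains(se.reason, "can't find collection"))
}

// New style: one sentinel, one errors.Is.
var errCollectionNotFound = errors.New("collection not found") // stand-in for merr.ErrCollectionNotFound

func main() {
	old := &statusError{code: 1, reason: "can't find collection: foo"}
	fmt.Println(isCollectionNotExistOld(old)) // true, but only via string sniffing

	wrapped := errors.Wrapf(errCollectionNotFound, "collection=%s", "foo")
	fmt.Println(errors.Is(wrapped, errCollectionNotFound)) // true, regardless of message
}
```
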
status.GetErrorCode()) - assert.Equal(t, "not status error", status.GetReason()) - status = StatusFromError(NewCollectionNotExistError("collection not exist")) - assert.Equal(t, commonpb.ErrorCode_CollectionNotExists, status.GetErrorCode()) -} diff --git a/pkg/util/interceptor/server_id_interceptor.go b/pkg/util/interceptor/server_id_interceptor.go index aceae2cd82..636082f039 100644 --- a/pkg/util/interceptor/server_id_interceptor.go +++ b/pkg/util/interceptor/server_id_interceptor.go @@ -49,7 +49,7 @@ func ServerIDValidationUnaryServerInterceptor(fn GetServerIDFunc) grpc.UnaryServ } actualServerID := fn() if serverID != actualServerID { - return nil, merr.WrapErrServerIDMismatch(serverID, actualServerID) + return nil, merr.WrapErrNodeNotMatch(serverID, actualServerID) } return handler(ctx, req) } @@ -73,7 +73,7 @@ func ServerIDValidationStreamServerInterceptor(fn GetServerIDFunc) grpc.StreamSe } actualServerID := fn() if serverID != actualServerID { - return merr.WrapErrServerIDMismatch(serverID, actualServerID) + return merr.WrapErrNodeNotMatch(serverID, actualServerID) } return handler(srv, ss) } diff --git a/pkg/util/interceptor/server_id_interceptor_test.go b/pkg/util/interceptor/server_id_interceptor_test.go index e0e67c8029..8e813b54c3 100644 --- a/pkg/util/interceptor/server_id_interceptor_test.go +++ b/pkg/util/interceptor/server_id_interceptor_test.go @@ -101,7 +101,7 @@ func TestServerIDInterceptor(t *testing.T) { md = metadata.Pairs(ServerIDKey, "1234") ctx = metadata.NewIncomingContext(context.Background(), md) _, err = interceptor(ctx, req, serverInfo, handler) - assert.ErrorIs(t, err, merr.ErrServerIDMismatch) + assert.ErrorIs(t, err, merr.ErrNodeNotMatch) // with same ServerID md = metadata.Pairs(ServerIDKey, fmt.Sprint(paramtable.GetNodeID())) @@ -137,7 +137,7 @@ func TestServerIDInterceptor(t *testing.T) { md = metadata.Pairs(ServerIDKey, "1234") ctx = metadata.NewIncomingContext(context.Background(), md) err = interceptor(nil, newMockSS(ctx), nil, handler) - assert.ErrorIs(t, err, merr.ErrServerIDMismatch) + assert.ErrorIs(t, err, merr.ErrNodeNotMatch) // with same ServerID md = metadata.Pairs(ServerIDKey, fmt.Sprint(paramtable.GetNodeID())) diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 5bb21a9332..92e19ddc3a 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -40,7 +40,6 @@ var ( ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus ErrCrossClusterRouting = newMilvusError("cross cluster routing", 6, false) ErrServiceDiskLimitExceeded = newMilvusError("disk limit exceeded", 7, false) - ErrServerIDMismatch = newMilvusError("server ID mismatch", 8, false) // Collection related ErrCollectionNotFound = newMilvusError("collection not found", 100, false) @@ -60,12 +59,11 @@ var ( ErrReplicaNotFound = newMilvusError("replica not found", 400, false) ErrReplicaNotAvailable = newMilvusError("replica not available", 401, false) - // Channel related - ErrChannelNotFound = newMilvusError("channel not found", 500, false) - ErrChannelLack = newMilvusError("channel lacks", 501, false) - ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false) - ErrChannelNotAvailable = newMilvusError("channel not available", 503, false) - ErrChannelUnsubscribing = newMilvusError("chanel is unsubscribing", 504, true) + // Channel & Delegator related + ErrChannelNotFound = newMilvusError("channel not found", 500, false) + ErrChannelLack = newMilvusError("channel lacks", 501, false) + 
ErrChannelReduplicate = newMilvusError("channel reduplicates", 502, false) + ErrChannelNotAvailable = newMilvusError("channel not available", 503, false) // Segment related ErrSegmentNotFound = newMilvusError("segment not found", 600, false) @@ -102,15 +100,6 @@ var ( ErrTopicNotFound = newMilvusError("topic not found", 1300, false) ErrTopicNotEmpty = newMilvusError("topic not empty", 1301, false) - // shard delegator related - ErrShardDelegatorNotFound = newMilvusError("shard delegator not found", 1500, false) - ErrShardDelegatorAccessFailed = newMilvusError("fail to access shard delegator", 1501, true) - ErrShardDelegatorSearchFailed = newMilvusError("fail to search on all shard leaders", 1502, true) - ErrShardDelegatorQueryFailed = newMilvusError("fail to query on all shard leaders", 1503, true) - ErrShardDelegatorStatisticFailed = newMilvusError("get statistics on all shard leaders", 1504, true) - ErrShardDelegatorSQTimeout = newMilvusError("search/query on shard leader timeout", 1505, true) - ErrShardDelegatorSQFailed = newMilvusError("fail to search/query shard leader", 1506, true) - // field related ErrFieldNotFound = newMilvusError("field not found", 1700, false) diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index f10354045c..4710ff7eb2 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -77,7 +77,7 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrServiceInternal("never throw out"), ErrServiceInternal) s.ErrorIs(WrapErrCrossClusterRouting("ins-0", "ins-1"), ErrCrossClusterRouting) s.ErrorIs(WrapErrServiceDiskLimitExceeded(110, 100, "DLE"), ErrServiceDiskLimitExceeded) - s.ErrorIs(WrapErrServerIDMismatch(0, 1, "SIM"), ErrServerIDMismatch) + s.ErrorIs(WrapErrNodeNotMatch(0, 1, "SIM"), ErrNodeNotMatch) // Collection related s.ErrorIs(WrapErrCollectionNotFound("test_collection", "failed to get collection"), ErrCollectionNotFound) @@ -100,7 +100,6 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrChannelNotFound("test_Channel", "failed to get Channel"), ErrChannelNotFound) s.ErrorIs(WrapErrChannelLack("test_Channel", "failed to get Channel"), ErrChannelLack) s.ErrorIs(WrapErrChannelReduplicate("test_Channel", "failed to get Channel"), ErrChannelReduplicate) - s.ErrorIs(WrapErrChannelUnsubscribing("test_channel"), ErrChannelUnsubscribing) // Segment related s.ErrorIs(WrapErrSegmentNotFound(1, "failed to get Segment"), ErrSegmentNotFound) @@ -131,11 +130,6 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrTopicNotFound("unknown", "failed to get topic"), ErrTopicNotFound) s.ErrorIs(WrapErrTopicNotEmpty("unknown", "topic is not empty"), ErrTopicNotEmpty) - // shard delegator related - s.ErrorIs(WrapErrShardDelegatorNotFound("unknown", "fail to get shard delegator"), ErrShardDelegatorNotFound) - s.ErrorIs(WrapErrShardDelegatorSQFailed("fake"), ErrShardDelegatorSQFailed) - s.ErrorIs(WrapErrShardDelegatorSQTimeout("fake"), ErrShardDelegatorSQTimeout) - // field related s.ErrorIs(WrapErrFieldNotFound("meta", "failed to get field"), ErrFieldNotFound) } diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 748a157f49..3995299c49 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -60,6 +60,10 @@ func IsRetriable(err error) bool { return Code(err)&retriableFlag != 0 } +func IsCanceledOrTimeout(err error) bool { + return errors.IsAny(err, context.Canceled, context.DeadlineExceeded) +} + // Status returns a status according to the given err, // returns Success status if err is nil func Status(err error) 
*commonpb.Status { @@ -196,14 +200,6 @@ func WrapErrServiceDiskLimitExceeded(predict, limit float32, msg ...string) erro return err } -func WrapErrServerIDMismatch(expectedID, actualID int64, msg ...string) error { - err := errors.Wrapf(ErrServerIDMismatch, "expected=%s, actual=%s", expectedID, actualID) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - func WrapErrDatabaseNotFound(database any, msg ...string) error { err := wrapWithField(ErrDatabaseNotfound, "database", database) if len(msg) > 0 { @@ -237,6 +233,14 @@ func WrapErrCollectionNotFound(collection any, msg ...string) error { return err } +func WrapErrCollectionNotFoundWithDB(db any, collection any, msg ...string) error { + err := errors.Wrapf(ErrCollectionNotFound, "collection %v:%v", db, collection) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + func WrapErrCollectionNotLoaded(collection any, msg ...string) error { err := wrapWithField(ErrCollectionNotLoaded, "collection", collection) if len(msg) > 0 { @@ -345,14 +349,6 @@ func WrapErrChannelNotAvailable(name string, msg ...string) error { return err } -func WrapErrChannelUnsubscribing(name string, msg ...string) error { - err := wrapWithField(ErrChannelUnsubscribing, "channel", name) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - // Segment related func WrapErrSegmentNotFound(id int64, msg ...string) error { err := wrapWithField(ErrSegmentNotFound, "segment", id) @@ -463,18 +459,15 @@ func WrapErrParameterInvalid[T any](expected, actual T, msg ...string) error { } func WrapErrParameterInvalidRange[T any](lower, upper, actual T, msg ...string) error { - err := errors.Wrapf(ErrParameterInvalid, "expected in (%v, %v), actual=%v", lower, upper, actual) + err := errors.Wrapf(ErrParameterInvalid, "expected in [%v, %v], actual=%v", lower, upper, actual) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } return err } -func WrapErrParameterDuplicateFieldData(fieldName string, msg ...string) error { - err := errors.Wrapf(ErrParameterInvalid, "field name=%v", fieldName) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } +func WrapErrParameterInvalidMsg(fmt string, args ...any) error { + err := errors.Wrapf(ErrParameterInvalid, fmt, args...) 
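
`WrapErrCollectionNotFoundWithDB` follows the same shape as the other helpers in this file: wrap the sentinel with structured context, then join any optional trailing messages with "; ". A usage sketch under those assumptions, with a stand-in sentinel and `cockroachdb/errors` as in the imports above:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/cockroachdb/errors"
)

var errCollectionNotFound = errors.New("collection not found") // stand-in sentinel

// wrapErrCollectionNotFoundWithDB mimics the helper added above: db and
// collection become structured context, optional messages are joined.
func wrapErrCollectionNotFoundWithDB(db any, collection any, msg ...string) error {
	err := errors.Wrapf(errCollectionNotFound, "collection %v:%v", db, collection)
	if len(msg) > 0 {
		err = errors.Wrap(err, strings.Join(msg, "; "))
	}
	return err
}

func main() {
	err := wrapErrCollectionNotFoundWithDB("default", "orders", "failed to describe collection")
	fmt.Println(err)                                   // message chain, most recent context first
	fmt.Println(errors.Is(err, errCollectionNotFound)) // true: sentinel preserved through wrapping
}
```

The same hunk also changes the `WrapErrParameterInvalidRange` message from `(%v, %v)` to `[%v, %v]`, presumably to match inclusive bound checks at the call sites.
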
return err } @@ -504,63 +497,6 @@ func WrapErrTopicNotEmpty(name string, msg ...string) error { return err } -// shard delegator related -func WrapErrShardDelegatorNotFound(channel string, msg ...string) error { - err := errors.Wrapf(ErrShardDelegatorNotFound, "channel=%s", channel) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - -func WrapErrShardDelegatorAccessFailed(channel string, msg ...string) error { - err := errors.Wrapf(ErrShardDelegatorAccessFailed, "channel=%s", channel) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - -func WrapErrShardDelegatorSQTimeout(channel string, msg ...string) error { - err := errors.Wrapf(ErrShardDelegatorSQTimeout, "channel=%s", channel) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - -func WrapErrShardDelegatorSQFailed(channel string, msg ...string) error { - err := errors.Wrapf(ErrShardDelegatorSQFailed, "channel=%s", channel) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - -func WrapErrShardDelegatorSearchFailed(msg ...string) error { - err := error(ErrShardDelegatorSearchFailed) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - -func WrapErrShardDelegatorQueryFailed(msg ...string) error { - err := error(ErrShardDelegatorQueryFailed) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - -func WrapErrShardDelegatorStatisticFailed(msg ...string) error { - err := error(ErrShardDelegatorStatisticFailed) - if len(msg) > 0 { - err = errors.Wrap(err, strings.Join(msg, "; ")) - } - return err -} - // field related func WrapErrFieldNotFound[T any](field T, msg ...string) error { err := errors.Wrapf(ErrFieldNotFound, "field=%v", field) diff --git a/tests/python_client/testcases/test_alias.py b/tests/python_client/testcases/test_alias.py index 15b3df5674..effb6433dd 100644 --- a/tests/python_client/testcases/test_alias.py +++ b/tests/python_client/testcases/test_alias.py @@ -43,7 +43,7 @@ class TestAliasParamsInvalid(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name, schema=default_schema, check_task=CheckTasks.check_collection_property, check_items={exp_name: c_name, exp_schema: default_schema}) - error = {ct.err_code: 1, ct.err_msg: f"Invalid collection alias: {alias_name}"} + error = {ct.err_code: 1, ct.err_msg: "Invalid collection alias"} self.utility_wrap.create_alias(collection_w.name, alias_name, check_task=CheckTasks.err_res, check_items=error) @@ -79,7 +79,8 @@ class TestAliasOperation(TestcaseBase): check_items={exp_name: alias_name, exp_schema: default_schema}) # assert collection is equal to alias according to partitions - assert [p.name for p in collection_w.partitions] == [p.name for p in collection_alias.partitions] + assert [p.name for p in collection_w.partitions] == [ + p.name for p in collection_alias.partitions] @pytest.mark.tags(CaseLabel.L1) def test_alias_alter_operation_default(self): @@ -105,8 +106,8 @@ class TestAliasOperation(TestcaseBase): partition_name = cf.gen_unique_str("partition") # create partition with different names and check the partition exists self.init_partition_wrap(collection_1, partition_name) - assert collection_1.has_partition(partition_name)[0] - + assert collection_1.has_partition(partition_name)[0] + alias_a_name = cf.gen_unique_str(prefix) self.utility_wrap.create_alias(collection_1.name, alias_a_name) collection_alias_a, _ = 
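
The new `merr.IsCanceledOrTimeout` helper, added in `pkg/util/merr/utils.go` above, classifies context cancellation and deadline expiry in one call. The real helper uses `cockroachdb/errors.IsAny`; its behavior can be sketched with the standard library alone:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// isCanceledOrTimeout approximates merr.IsCanceledOrTimeout using only the
// standard library; errors.IsAny from cockroachdb/errors performs the same
// check against both sentinels at once.
func isCanceledOrTimeout(err error) bool {
	return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	<-ctx.Done() // wait for the deadline to pass

	fmt.Println(isCanceledOrTimeout(ctx.Err()))                   // true: deadline exceeded
	fmt.Println(isCanceledOrTimeout(errors.New("other failure"))) // false
}
```
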
self.collection_wrap.init_collection(name=alias_a_name, @@ -114,8 +115,9 @@ class TestAliasOperation(TestcaseBase): check_items={exp_name: alias_a_name, exp_schema: default_schema}) # assert collection is equal to alias according to partitions - assert [p.name for p in collection_1.partitions] == [p.name for p in collection_alias_a.partitions] - + assert [p.name for p in collection_1.partitions] == [ + p.name for p in collection_alias_a.partitions] + # create collection_2 with 5 partitions and its alias alias_b c_2_name = cf.gen_unique_str("collection") collection_2 = self.init_collection_wrap(name=c_2_name, schema=default_schema, @@ -135,15 +137,19 @@ class TestAliasOperation(TestcaseBase): check_items={exp_name: alias_b_name, exp_schema: default_schema}) # assert collection is equal to alias according to partitions - assert [p.name for p in collection_2.partitions] == [p.name for p in collection_alias_b.partitions] - + assert [p.name for p in collection_2.partitions] == [ + p.name for p in collection_alias_b.partitions] + # collection_1 alter alias to alias_b self.utility_wrap.alter_alias(collection_1.name, alias_b_name) # collection_1 has two alias name, alias_a and alias_b, but collection_2 has no alias any more - assert [p.name for p in collection_1.partitions] == [p.name for p in collection_alias_b.partitions] - assert [p.name for p in collection_1.partitions] == [p.name for p in collection_alias_a.partitions] - assert [p.name for p in collection_2.partitions] != [p.name for p in collection_alias_b.partitions] + assert [p.name for p in collection_1.partitions] == [ + p.name for p in collection_alias_b.partitions] + assert [p.name for p in collection_1.partitions] == [ + p.name for p in collection_alias_a.partitions] + assert [p.name for p in collection_2.partitions] != [ + p.name for p in collection_alias_b.partitions] @pytest.mark.tags(CaseLabel.L1) def test_alias_drop_operation_default(self): @@ -176,7 +182,8 @@ class TestAliasOperation(TestcaseBase): check_items={exp_name: alias_name, exp_schema: default_schema}) # assert collection is equal to alias according to partitions - assert [p.name for p in collection_w.partitions] == [p.name for p in collection_alias.partitions] + assert [p.name for p in collection_w.partitions] == [ + p.name for p in collection_alias.partitions] self.utility_wrap.drop_alias(alias_name) error = {ct.err_code: 0, ct.err_msg: f"Collection '{alias_name}' not exist, or you can pass in schema to create one"} @@ -216,7 +223,7 @@ class TestAliasOperation(TestcaseBase): check_task=CheckTasks.check_collection_property, check_items={exp_name: alias_name, exp_schema: default_schema}) - + # create partition by alias partition_name = cf.gen_unique_str("partition") try: @@ -225,11 +232,11 @@ class TestAliasOperation(TestcaseBase): log.info(f"alias create partition failed with exception {e}") create_partition_flag = False collection_w.create_partition(partition_name) - + # assert partition pytest.assume(create_partition_flag is True and [p.name for p in collection_alias.partitions] == [p.name for p in collection_w.partitions]) - + # insert data by alias df = cf.gen_default_dataframe_data(ct.default_nb) try: @@ -238,26 +245,29 @@ class TestAliasOperation(TestcaseBase): log.info(f"alias insert data failed with exception {e}") insert_data_flag = False collection_w.insert(data=df) - + # assert insert data pytest.assume(insert_data_flag is True and collection_w.num_entities == ct.default_nb and collection_alias.num_entities == ct.default_nb) # create index by alias - 
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} try: - collection_alias.create_index(field_name="float_vector", index_params=default_index) + collection_alias.create_index( + field_name="float_vector", index_params=default_index) except Exception as e: log.info(f"alias create index failed with exception {e}") create_index_flag = False - collection_w.create_index(field_name="float_vector", index_params=default_index) - + collection_w.create_index( + field_name="float_vector", index_params=default_index) + # assert create index pytest.assume(create_index_flag is True and collection_alias.has_index() is True and collection_w.has_index()[0] is True) - + # load by alias try: collection_alias.load() @@ -271,8 +281,9 @@ class TestAliasOperation(TestcaseBase): # search by alias topK = 5 search_params = {"metric_type": "L2", "params": {"nprobe": 10}} - - query = [[random.random() for _ in range(ct.default_dim)] for _ in range(1)] + + query = [[random.random() for _ in range(ct.default_dim)] + for _ in range(1)] alias_res = None try: alias_res = collection_alias.search( @@ -282,13 +293,14 @@ class TestAliasOperation(TestcaseBase): except Exception as e: log.info(f"alias search failed with exception {e}") search_flag = False - + collection_res, _ = collection_w.search( query, "float_vector", search_params, topK, "int64 >= 0", output_fields=["int64"] ) # assert search - pytest.assume(search_flag is True and alias_res[0].ids == collection_res[0].ids) + pytest.assume( + search_flag is True and alias_res[0].ids == collection_res[0].ids) # release by alias try: @@ -314,7 +326,7 @@ class TestAliasOperation(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name, schema=default_schema, check_task=CheckTasks.check_collection_property, check_items={exp_name: c_name, exp_schema: default_schema}) - + alias_name = cf.gen_unique_str(prefix) self.utility_wrap.create_alias(collection_w.name, alias_name) # collection_w.create_alias(alias_name) @@ -340,7 +352,7 @@ class TestAliasOperation(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name, schema=default_schema, check_task=CheckTasks.check_collection_property, check_items={exp_name: c_name, exp_schema: default_schema}) - + alias_name = cf.gen_unique_str(prefix) self.utility_wrap.create_alias(collection_w.name, alias_name) # collection_w.create_alias(alias_name) @@ -349,7 +361,8 @@ class TestAliasOperation(TestcaseBase): check_items={exp_name: alias_name, exp_schema: default_schema}) assert self.utility_wrap.has_collection(c_name)[0] - error = {ct.err_code: 1, ct.err_msg: f"cannot drop the collection via alias = {alias_name}"} + error = {ct.err_code: 1, + ct.err_msg: f"cannot drop the collection via alias = {alias_name}"} self.utility_wrap.drop_collection(alias_name, check_task=CheckTasks.err_res, check_items=error) @@ -372,7 +385,7 @@ class TestAliasOperation(TestcaseBase): check_items={exp_name: c_name, exp_schema: default_schema}) partition_name = cf.gen_unique_str("partition") self.init_partition_wrap(collection_w, partition_name) - + alias_name = cf.gen_unique_str(prefix) self.utility_wrap.create_alias(collection_w.name, alias_name) # collection_w.create_alias(alias_name) @@ -411,7 +424,8 @@ class TestAliasOperationInvalid(TestcaseBase): collection_2 = self.init_collection_wrap(name=c_2_name, schema=default_schema, check_task=CheckTasks.check_collection_property, check_items={exp_name: c_2_name, 
exp_schema: default_schema}) - error = {ct.err_code: 1, ct.err_msg: "Create alias failed: duplicate collection alias"} + error = {ct.err_code: 1, + ct.err_msg: "Create alias failed: duplicate collection alias"} self.utility_wrap.create_alias(collection_2.name, alias_a_name, check_task=CheckTasks.err_res, check_items=error) @@ -439,7 +453,8 @@ class TestAliasOperationInvalid(TestcaseBase): # collection_w.create_alias(alias_name) alias_not_exist_name = cf.gen_unique_str(prefix) - error = {ct.err_code: 1, ct.err_msg: "Alter alias failed: alias does not exist"} + error = {ct.err_code: 1, + ct.err_msg: "Alter alias failed: alias does not exist"} self.utility_wrap.alter_alias(collection_w.name, alias_not_exist_name, check_task=CheckTasks.err_res, check_items=error) @@ -466,7 +481,8 @@ class TestAliasOperationInvalid(TestcaseBase): # collection_w.create_alias(alias_name) alias_not_exist_name = cf.gen_unique_str(prefix) - error = {ct.err_code: 1, ct.err_msg: "Drop alias failed: alias does not exist"} + error = {ct.err_code: 1, + ct.err_msg: "Drop alias failed: alias does not exist"} # self.utility_wrap.drop_alias(alias_not_exist_name, # check_task=CheckTasks.err_res, # check_items=error) @@ -510,7 +526,7 @@ class TestAliasOperationInvalid(TestcaseBase): # collection_w.drop_alias(alias_name, # check_task=CheckTasks.err_res, # check_items=error) - + @pytest.mark.tags(CaseLabel.L1) def test_alias_create_dup_name_collection(self): """ @@ -556,6 +572,6 @@ class TestAliasOperationInvalid(TestcaseBase): check_task=CheckTasks.check_collection_property, check_items={exp_name: alias_name, exp_schema: default_schema}) - + with pytest.raises(Exception): collection_alias.drop() diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index 71d8c00e7e..3b6a2313b6 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -1068,7 +1068,7 @@ class TestCollectionOperation(TestcaseBase): check_items={exp_name: c_name, exp_schema: default_schema}) self.collection_wrap.drop() assert not self.utility_wrap.has_collection(c_name)[0] - error = {ct.err_code: 1, ct.err_msg: f'HasPartition failed: can\'t find collection: {c_name}'} + error = {ct.err_code: 1, ct.err_msg: f'HasPartition failed: collection not found: {c_name}'} collection_w.has_partition("p", check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @@ -1895,7 +1895,7 @@ class TestDropCollection(TestcaseBase): c_name = cf.gen_unique_str() self.init_collection_wrap(name=c_name) c_name_2 = cf.gen_unique_str() - # error = {ct.err_code: 0, ct.err_msg: 'DescribeCollection failed: can\'t find collection: %s' % c_name_2} + # error = {ct.err_code: 0, ct.err_msg: 'DescribeCollection failed: collection not found: %s' % c_name_2} # self.utility_wrap.drop_collection(c_name_2, check_task=CheckTasks.err_res, check_items=error) # @longjiquan: dropping collection should be idempotent. 
self.utility_wrap.drop_collection(c_name_2) @@ -3360,7 +3360,7 @@ class TestLoadPartition(TestcaseBase): "is_empty": True, "num_entities": 0} ) collection_w.drop() - error = {ct.err_code: 0, ct.err_msg: "can\'t find collection"} + error = {ct.err_code: 0, ct.err_msg: "collection not found"} partition_w.load(check_task=CheckTasks.err_res, check_items=error) partition_w.release(check_task=CheckTasks.err_res, check_items=error) diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index 2f4457549a..fc3d585a18 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -177,8 +177,7 @@ class TestIndexParams(TestcaseBase): index_name=index_name, check_task=CheckTasks.err_res, check_items={ct.err_code: 1, - ct.err_msg: "CreateIndex failed: index already exist, " - "but parameters are inconsistent"}) + ct.err_msg: "invalid parameter"}) @pytest.mark.tags(CaseLabel.L1) # @pytest.mark.xfail(reason="issue 19181") diff --git a/tests/python_client/testcases/test_insert.py b/tests/python_client/testcases/test_insert.py index a7083863a8..588cbe5ac0 100644 --- a/tests/python_client/testcases/test_insert.py +++ b/tests/python_client/testcases/test_insert.py @@ -23,8 +23,10 @@ exp_primary = "primary" default_float_name = ct.default_float_field_name default_schema = cf.gen_default_collection_schema() default_binary_schema = cf.gen_default_binary_collection_schema() -default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}} -default_binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}} +default_index_params = {"index_type": "IVF_SQ8", + "metric_type": "L2", "params": {"nlist": 64}} +default_binary_index_params = { + "index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}} default_search_exp = "int64 >= 0" @@ -56,7 +58,8 @@ class TestInsertParams(TestcaseBase): df = cf.gen_default_dataframe_data(ct.default_nb) mutation_res, _ = collection_w.insert(data=df) assert mutation_res.insert_count == ct.default_nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L0) @@ -83,8 +86,10 @@ class TestInsertParams(TestcaseBase): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - error = {ct.err_code: 1, ct.err_msg: "The type of data should be list or pandas.DataFrame"} - collection_w.insert(data=get_non_data_type, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The type of data should be list or pandas.DataFrame"} + collection_w.insert(data=get_non_data_type, + check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("data", [[], pd.DataFrame()]) @@ -98,7 +103,8 @@ class TestInsertParams(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['int64', 'float', 'varchar', 'float_vector'], got %s" % data} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_dataframe_only_columns(self): @@ -109,10 +115,13 @@ class 
TestInsertParams(TestcaseBase): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - columns = [ct.default_int64_field_name, ct.default_float_vec_field_name] + columns = [ct.default_int64_field_name, + ct.default_float_vec_field_name] df = pd.DataFrame(columns=columns) - error = {ct.err_code: 0, ct.err_msg: "Cannot infer schema from empty dataframe"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 0, + ct.err_msg: "Cannot infer schema from empty dataframe"} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_empty_field_name_dataframe(self): @@ -125,8 +134,10 @@ class TestInsertParams(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) df = cf.gen_default_dataframe_data(10) df.rename(columns={ct.default_int64_field_name: ' '}, inplace=True) - error = {ct.err_code: 1, ct.err_msg: "The name of field don't match, expected: int64, got "} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The name of field don't match, expected: int64, got "} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_invalid_field_name_dataframe(self, get_invalid_field_name): @@ -138,9 +149,12 @@ class TestInsertParams(TestcaseBase): c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) df = cf.gen_default_dataframe_data(10) - df.rename(columns={ct.default_int64_field_name: get_invalid_field_name}, inplace=True) - error = {ct.err_code: 1, ct.err_msg: "The name of field don't match, expected: int64, got %s" % get_invalid_field_name} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + df.rename( + columns={ct.default_int64_field_name: get_invalid_field_name}, inplace=True) + error = {ct.err_code: 1, ct.err_msg: "The name of field don't match, expected: int64, got %s" % + get_invalid_field_name} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) def test_insert_dataframe_index(self): """ @@ -185,11 +199,13 @@ class TestInsertParams(TestcaseBase): expected: assert num_entities """ c_name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=c_name, schema=default_binary_schema) + collection_w = self.init_collection_wrap( + name=c_name, schema=default_binary_schema) df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) mutation_res, _ = collection_w.insert(data=df) assert mutation_res.insert_count == ct.default_nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L0) @@ -200,7 +216,8 @@ class TestInsertParams(TestcaseBase): expected: assert num_entities """ c_name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=c_name, schema=default_binary_schema) + collection_w = self.init_collection_wrap( + name=c_name, schema=default_binary_schema) data, _ = cf.gen_default_binary_list_data(ct.default_nb) mutation_res, _ = collection_w.insert(data=data) assert mutation_res.insert_count == ct.default_nb @@ -235,7 +252,8 @@ class TestInsertParams(TestcaseBase): df = cf.gen_default_dataframe_data(ct.default_nb, dim=dim) error = {ct.err_code: 1, 
ct.err_msg: f'Collection field dim is {ct.default_dim}, but entities field dim is {dim}'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_binary_dim_not_match(self): @@ -245,12 +263,14 @@ class TestInsertParams(TestcaseBase): expected: raise exception """ c_name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=c_name, schema=default_binary_schema) + collection_w = self.init_collection_wrap( + name=c_name, schema=default_binary_schema) dim = 120 df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb, dim=dim) error = {ct.err_code: 1, ct.err_msg: f'Collection field dim is {ct.default_dim}, but entities field dim is {dim}'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_field_name_not_match(self): @@ -263,8 +283,10 @@ class TestInsertParams(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) df = cf.gen_default_dataframe_data(10) df.rename(columns={ct.default_float_field_name: "int"}, inplace=True) - error = {ct.err_code: 1, ct.err_msg: "The name of field don't match, expected: float, got int"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The name of field don't match, expected: float, got int"} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_field_value_not_match(self): @@ -277,10 +299,12 @@ class TestInsertParams(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) nb = 10 df = cf.gen_default_dataframe_data(nb) - new_float_value = pd.Series(data=[float(i) for i in range(nb)], dtype="float64") + new_float_value = pd.Series( + data=[float(i) for i in range(nb)], dtype="float64") df[df.columns[1]] = new_float_value - error = {ct.err_code: 1, ct.err_msg: 'The data fields number is not match with schema.'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 5} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_value_less(self): @@ -297,7 +321,8 @@ class TestInsertParams(TestcaseBase): float_vec_values = cf.gen_vectors(nb, ct.default_dim) data = [int_values, float_values, float_vec_values] error = {ct.err_code: 1, ct.err_msg: 'Arrays must all be same length.'} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_vector_value_less(self): @@ -314,7 +339,8 @@ class TestInsertParams(TestcaseBase): float_vec_values = cf.gen_vectors(nb - 1, ct.default_dim) data = [int_values, float_values, float_vec_values] error = {ct.err_code: 1, ct.err_msg: 'Arrays must all be same length.'} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_fields_more(self): @@ -331,7 +357,8 @@ class TestInsertParams(TestcaseBase): error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: 
['int64', 'float', 'varchar', 'float_vector'], " "got ['int64', 'float', 'varchar', 'new', 'float_vector']"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_fields_less(self): @@ -347,7 +374,8 @@ class TestInsertParams(TestcaseBase): error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['int64', 'float', 'varchar', 'float_vector'], " "got ['int64', 'float', 'varchar']"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_list_order_inconsistent_schema(self): @@ -363,8 +391,9 @@ class TestInsertParams(TestcaseBase): float_values = [np.float32(i) for i in range(nb)] float_vec_values = cf.gen_vectors(nb, ct.default_dim) data = [float_values, int_values, float_vec_values] - error = {ct.err_code: 1, ct.err_msg: 'The data fields number is not match with schema.'} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 5} + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_dataframe_order_inconsistent_schema(self): @@ -377,15 +406,17 @@ class TestInsertParams(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) nb = 10 int_values = pd.Series(data=[i for i in range(nb)]) - float_values = pd.Series(data=[float(i) for i in range(nb)], dtype="float32") + float_values = pd.Series(data=[float(i) + for i in range(nb)], dtype="float32") float_vec_values = cf.gen_vectors(nb, ct.default_dim) df = pd.DataFrame({ ct.default_float_field_name: float_values, ct.default_float_vec_field_name: float_vec_values, ct.default_int64_field_name: int_values }) - error = {ct.err_code: 1, ct.err_msg: 'The data fields number is not match with schema.'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 5, ct.err_msg: 'Missing param in entities'} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_inconsistent_data(self): @@ -398,8 +429,10 @@ class TestInsertParams(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) data = cf.gen_default_list_data(nb=100) data[0][1] = 1.0 - error = {ct.err_code: 0, ct.err_msg: "The data in the same column must be of the same type"} - collection_w.insert(data, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 0, + ct.err_msg: "The data in the same column must be of the same type"} + collection_w.insert( + data, check_task=CheckTasks.err_res, check_items=error) class TestInsertOperation(TestcaseBase): @@ -435,7 +468,8 @@ class TestInsertOperation(TestcaseBase): assert ct.default_alias not in res_list data = cf.gen_default_list_data(10) error = {ct.err_code: 0, ct.err_msg: 'should create connect first'} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_default_partition(self): @@ -444,10 +478,12 @@ class TestInsertOperation(TestcaseBase): method: create partition and insert info collection expected: the collection insert count equals to nb """ - 
collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) partition_w1 = self.init_partition_wrap(collection_w) data = cf.gen_default_list_data(nb=ct.default_nb) - mutation_res, _ = collection_w.insert(data=data, partition_name=partition_w1.name) + mutation_res, _ = collection_w.insert( + data=data, partition_name=partition_w1.name) assert mutation_res.insert_count == ct.default_nb def test_insert_partition_not_existed(self): @@ -456,9 +492,11 @@ class TestInsertOperation(TestcaseBase): method: create collection and insert entities in it, with the not existed partition_name param expected: error raised """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(nb=ct.default_nb) - error = {ct.err_code: 1, ct.err_msg: "partitionID of partitionName:p can not be existed"} + error = {ct.err_code: 1, + ct.err_msg: "partitionID of partitionName:p can not be existed"} mutation_res, _ = collection_w.insert(data=df, partition_name="p", check_task=CheckTasks.err_res, check_items=error) @@ -469,12 +507,15 @@ class TestInsertOperation(TestcaseBase): method: create collection and insert entities in it repeatedly, with the partition_name param expected: the collection row count equals to nq """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) partition_w1 = self.init_partition_wrap(collection_w) partition_w2 = self.init_partition_wrap(collection_w) df = cf.gen_default_dataframe_data(nb=ct.default_nb) - mutation_res, _ = collection_w.insert(data=df, partition_name=partition_w1.name) - new_res, _ = collection_w.insert(data=df, partition_name=partition_w2.name) + mutation_res, _ = collection_w.insert( + data=df, partition_name=partition_w1.name) + new_res, _ = collection_w.insert( + data=df, partition_name=partition_w2.name) assert mutation_res.insert_count == ct.default_nb assert new_res.insert_count == ct.default_nb @@ -485,11 +526,13 @@ class TestInsertOperation(TestcaseBase): method: create collection and insert entities in it, with the partition_name param expected: the collection insert count equals to nq """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) partition_name = cf.gen_unique_str(prefix) partition_w1 = self.init_partition_wrap(collection_w, partition_name) df = cf.gen_default_dataframe_data(ct.default_nb) - mutation_res, _ = collection_w.insert(data=df, partition_name=partition_w1.name) + mutation_res, _ = collection_w.insert( + data=df, partition_name=partition_w1.name) assert mutation_res.insert_count == ct.default_nb @pytest.mark.tags(CaseLabel.L2) @@ -499,10 +542,13 @@ class TestInsertOperation(TestcaseBase): method: update entity field type expected: error raised """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_collection_schema_all_datatype - error = {ct.err_code: 1, ct.err_msg: "The type of data should be list or pandas.DataFrame"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The type of data should be list or pandas.DataFrame"} + collection_w.insert( + data=df, 
check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_exceed_varchar_limit(self): @@ -521,9 +567,12 @@ class TestInsertOperation(TestcaseBase): name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name, schema) vectors = cf.gen_vectors(2, ct.default_dim) - data = [vectors, ["limit_1___________", "limit_2___________"], ['1', '2']] - error = {ct.err_code: 1, ct.err_msg: "invalid input, length of string exceeds max length"} - collection_w.insert(data, check_task=CheckTasks.err_res, check_items=error) + data = [vectors, ["limit_1___________", + "limit_2___________"], ['1', '2']] + error = {ct.err_code: 1, + ct.err_msg: "invalid input, length of string exceeds max length"} + collection_w.insert( + data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_lack_vector_field(self): @@ -532,10 +581,12 @@ class TestInsertOperation(TestcaseBase): method: remove entity values of vector field expected: error raised """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_collection_schema([cf.gen_int64_field(is_primary=True)]) error = {ct.err_code: 1, ct.err_msg: "Data type is not support."} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_no_vector_field_dtype(self): @@ -544,13 +595,16 @@ class TestInsertOperation(TestcaseBase): method: vector field dtype is not existed expected: error raised """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) - vec_field, _ = self.field_schema_wrap.init_field_schema(name=ct.default_int64_field_name, dtype=DataType.NONE) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) + vec_field, _ = self.field_schema_wrap.init_field_schema( + name=ct.default_int64_field_name, dtype=DataType.NONE) field_one = cf.gen_int64_field(is_primary=True) field_two = cf.gen_int64_field() df = [field_one, field_two, vec_field] error = {ct.err_code: 1, ct.err_msg: "Field dtype must be of DataType."} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_no_vector_field_name(self): @@ -559,13 +613,15 @@ class TestInsertOperation(TestcaseBase): method: vector field name is error expected: error raised """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) vec_field = cf.gen_float_vec_field(name=ct.get_invalid_strs) field_one = cf.gen_int64_field(is_primary=True) field_two = cf.gen_int64_field() df = [field_one, field_two, vec_field] error = {ct.err_code: 1, ct.err_msg: "data should be a list of list"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_drop_collection(self): @@ -591,14 +647,17 @@ class TestInsertOperation(TestcaseBase): method: 1. insert 2. 
create index expected: verify num entities and index """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(ct.default_nb) collection_w.insert(data=df) assert collection_w.num_entities == ct.default_nb - collection_w.create_index(ct.default_float_vec_field_name, default_index_params) + collection_w.create_index( + ct.default_float_vec_field_name, default_index_params) assert collection_w.has_index()[0] index, _ = collection_w.index() - assert index == Index(collection_w.collection, ct.default_float_vec_field_name, default_index_params) + assert index == Index( + collection_w.collection, ct.default_float_vec_field_name, default_index_params) assert collection_w.indexes[0] == index @pytest.mark.tags(CaseLabel.L1) @@ -608,11 +667,14 @@ class TestInsertOperation(TestcaseBase): method: 1. create index 2. insert data expected: verify index and num entities """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) - collection_w.create_index(ct.default_float_vec_field_name, default_index_params) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) + collection_w.create_index( + ct.default_float_vec_field_name, default_index_params) assert collection_w.has_index()[0] index, _ = collection_w.index() - assert index == Index(collection_w.collection, ct.default_float_vec_field_name, default_index_params) + assert index == Index( + collection_w.collection, ct.default_float_vec_field_name, default_index_params) assert collection_w.indexes[0] == index df = cf.gen_default_dataframe_data(ct.default_nb) collection_w.insert(data=df) @@ -626,11 +688,14 @@ class TestInsertOperation(TestcaseBase): expected: 1.index ok 2.num entities correct """ schema = cf.gen_default_binary_collection_schema() - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) - collection_w.create_index(ct.default_binary_vec_field_name, default_binary_index_params) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix), schema=schema) + collection_w.create_index( + ct.default_binary_vec_field_name, default_binary_index_params) assert collection_w.has_index()[0] index, _ = collection_w.index() - assert index == Index(collection_w.collection, ct.default_binary_vec_field_name, default_binary_index_params) + assert index == Index( + collection_w.collection, ct.default_binary_vec_field_name, default_binary_index_params) assert collection_w.indexes[0] == index df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) collection_w.insert(data=df) @@ -645,17 +710,20 @@ class TestInsertOperation(TestcaseBase): expected: index correct """ schema = cf.gen_default_collection_schema(auto_id=True) - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_default_dataframe_data() df.drop(ct.default_int64_field_name, axis=1, inplace=True) mutation_res, _ = collection_w.insert(data=df) assert cf._check_primary_keys(mutation_res.primary_keys, ct.default_nb) assert collection_w.num_entities == ct.default_nb # create index - collection_w.create_index(ct.default_float_vec_field_name, default_index_params) + collection_w.create_index( + ct.default_float_vec_field_name, default_index_params) assert collection_w.has_index()[0] index, _ = collection_w.index() - assert index == 
Index(collection_w.collection, ct.default_float_vec_field_name, default_index_params) + assert index == Index( + collection_w.collection, ct.default_float_vec_field_name, default_index_params) assert collection_w.indexes[0] == index @pytest.mark.tags(CaseLabel.L2) @@ -666,7 +734,8 @@ class TestInsertOperation(TestcaseBase): expected: verify primary_keys and num_entities """ c_name = cf.gen_unique_str(prefix) - schema = cf.gen_default_collection_schema(primary_field=pk_field, auto_id=True) + schema = cf.gen_default_collection_schema( + primary_field=pk_field, auto_id=True) collection_w = self.init_collection_wrap(name=c_name, schema=schema) df = cf.gen_default_dataframe_data() df.drop(pk_field, axis=1, inplace=True) @@ -682,7 +751,8 @@ class TestInsertOperation(TestcaseBase): expected: verify primary_keys unique """ c_name = cf.gen_unique_str(prefix) - schema = cf.gen_default_collection_schema(primary_field=pk_field, auto_id=True) + schema = cf.gen_default_collection_schema( + primary_field=pk_field, auto_id=True) nb = 10 collection_w = self.init_collection_wrap(name=c_name, schema=schema) df = cf.gen_default_dataframe_data(nb) @@ -703,7 +773,8 @@ class TestInsertOperation(TestcaseBase): expected: assert num entities """ c_name = cf.gen_unique_str(prefix) - schema = cf.gen_default_collection_schema(primary_field=pk_field, auto_id=True) + schema = cf.gen_default_collection_schema( + primary_field=pk_field, auto_id=True) collection_w = self.init_collection_wrap(name=c_name, schema=schema) data = cf.gen_default_list_data() if pk_field == ct.default_int64_field_name: @@ -723,11 +794,14 @@ class TestInsertOperation(TestcaseBase): expected: 1.verify num entities 2.verify ids """ c_name = cf.gen_unique_str(prefix) - schema = cf.gen_default_collection_schema(primary_field=pk_field, auto_id=True) + schema = cf.gen_default_collection_schema( + primary_field=pk_field, auto_id=True) collection_w = self.init_collection_wrap(name=c_name, schema=schema) df = cf.gen_default_dataframe_data(nb=100) - error = {ct.err_code: 1, ct.err_msg: "Please don't provide data for auto_id primary field: int64"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "Please don't provide data for auto_id primary field: int64"} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) assert collection_w.is_empty @pytest.mark.tags(CaseLabel.L2) @@ -738,12 +812,14 @@ class TestInsertOperation(TestcaseBase): expected: 1.verify num entities 2.verify ids """ c_name = cf.gen_unique_str(prefix) - schema = cf.gen_default_collection_schema(primary_field=pk_field, auto_id=True) + schema = cf.gen_default_collection_schema( + primary_field=pk_field, auto_id=True) collection_w = self.init_collection_wrap(name=c_name, schema=schema) data = cf.gen_default_list_data(nb=100) error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['float', 'varchar', 'float_vector'], got ['', '', '', '']"} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) assert collection_w.is_empty @pytest.mark.tags(CaseLabel.L1) @@ -786,7 +862,8 @@ class TestInsertOperation(TestcaseBase): method: multi threads insert expected: verify num entities """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = 
cf.gen_default_dataframe_data(ct.default_nb) thread_num = 4 threads = [] @@ -830,7 +907,8 @@ class TestInsertOperation(TestcaseBase): df = cf.gen_default_dataframe_data(step, dim) mutation_res, _ = collection_w.insert(data=df) assert mutation_res.insert_count == step - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == nb @@ -917,7 +995,8 @@ class TestInsertOperation(TestcaseBase): cf.gen_string_field(default_value="abc")] schema = cf.gen_collection_schema(fields, auto_id=auto_id) collection_w = self.init_collection_wrap(schema=schema) - data = [[i for i in range(ct.default_nb)], cf.gen_vectors(ct.default_nb, ct.default_dim)] + data = [[i for i in range(ct.default_nb)], cf.gen_vectors( + ct.default_nb, ct.default_dim)] data1 = [[i for i in range(ct.default_nb)], cf.gen_vectors(ct.default_nb, ct.default_dim), [np.float32(i) for i in range(ct.default_nb)]] if auto_id: @@ -964,13 +1043,15 @@ class TestInsertAsync(TestcaseBase): method: insert with async=True expected: verify num entities """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() future, _ = collection_w.insert(data=df, _async=True) future.done() mutation_res = future.result() assert mutation_res.insert_count == ct.default_nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L1) @@ -980,11 +1061,13 @@ class TestInsertAsync(TestcaseBase): method: async = false expected: verify num entities """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() mutation_res, _ = collection_w.insert(data=df, _async=False) assert mutation_res.insert_count == ct.default_nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L1) @@ -994,12 +1077,15 @@ class TestInsertAsync(TestcaseBase): method: insert with callback func expected: verify num entities """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() - future, _ = collection_w.insert(data=df, _async=True, _callback=assert_mutation_result) + future, _ = collection_w.insert( + data=df, _async=True, _callback=assert_mutation_result) future.done() mutation_res = future.result() - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L2) @@ -1010,13 +1096,15 @@ class TestInsertAsync(TestcaseBase): expected: verify num entities """ nb = 50000 - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(nb) future, _ = 
collection_w.insert(data=df, _async=True) future.done() mutation_res = future.result() assert mutation_res.insert_count == nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( + ) assert collection_w.num_entities == nb @pytest.mark.tags(CaseLabel.L2) @@ -1027,9 +1115,11 @@ class TestInsertAsync(TestcaseBase): expected: raise exception """ nb = 100000 - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(nb) - future, _ = collection_w.insert(data=df, _async=True, _callback=None, timeout=0.2) + future, _ = collection_w.insert( + data=df, _async=True, _callback=None, timeout=0.2) with pytest.raises(MilvusException): future.result() @@ -1040,11 +1130,15 @@ class TestInsertAsync(TestcaseBase): method: insert async with invalid data expected: raise exception """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) - columns = [ct.default_int64_field_name, ct.default_float_vec_field_name] + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) + columns = [ct.default_int64_field_name, + ct.default_float_vec_field_name] df = pd.DataFrame(columns=columns) - error = {ct.err_code: 0, ct.err_msg: "Cannot infer schema from empty dataframe"} - collection_w.insert(data=df, _async=True, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 0, + ct.err_msg: "Cannot infer schema from empty dataframe"} + collection_w.insert(data=df, _async=True, + check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_async_invalid_partition(self): @@ -1053,10 +1147,12 @@ class TestInsertAsync(TestcaseBase): method: insert async with invalid partition expected: raise exception """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() err_msg = "partition=p: partition not found" - future, _ = collection_w.insert(data=df, partition_name="p", _async=True) + future, _ = collection_w.insert( + data=df, partition_name="p", _async=True) future.done() with pytest.raises(MilvusException, match=err_msg): future.result() @@ -1068,10 +1164,12 @@ class TestInsertAsync(TestcaseBase): method: set only vector field and insert into collection expected: raise exception """ - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_collection_schema([cf.gen_int64_field(is_primary=True)]) error = {ct.err_code: 1, ct.err_msg: "fleldSchema lack of vector field."} - future, _ = collection_w.insert(data=df, _async=True, check_task=CheckTasks.err_res, check_items=error) + future, _ = collection_w.insert( + data=df, _async=True, check_task=CheckTasks.err_res, check_items=error) def assert_mutation_result(mutation_res): @@ -1088,11 +1186,13 @@ class TestInsertBinary(TestcaseBase): expected: the collection row count equals to nb """ c_name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=c_name, schema=default_binary_schema) + collection_w = self.init_collection_wrap( + name=c_name, schema=default_binary_schema) df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) partition_name = cf.gen_unique_str(prefix) partition_w1 = 
self.init_partition_wrap(collection_w, partition_name) - mutation_res, _ = collection_w.insert(data=df, partition_name=partition_w1.name) + mutation_res, _ = collection_w.insert( + data=df, partition_name=partition_w1.name) assert mutation_res.insert_count == ct.default_nb @pytest.mark.tags(CaseLabel.L1) @@ -1103,7 +1203,8 @@ class TestInsertBinary(TestcaseBase): expected: the collection row count equals to nb """ c_name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=c_name, schema=default_binary_schema) + collection_w = self.init_collection_wrap( + name=c_name, schema=default_binary_schema) df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) nums = 2 for i in range(nums): @@ -1118,11 +1219,13 @@ class TestInsertBinary(TestcaseBase): expected: no error raised """ c_name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=c_name, schema=default_binary_schema) + collection_w = self.init_collection_wrap( + name=c_name, schema=default_binary_schema) df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) mutation_res, _ = collection_w.insert(data=df) assert mutation_res.insert_count == ct.default_nb - default_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": "BIN_IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) @@ -1145,8 +1248,10 @@ class TestInsertInvalid(TestcaseBase): int_field = cf.gen_float_field(is_primary=True) vec_field = cf.gen_float_vec_field(name='vec') df = [int_field, vec_field] - error = {ct.err_code: 1, ct.err_msg: "Primary key type must be DataType.INT64."} - mutation_res, _ = collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "Primary key type must be DataType.INT64."} + mutation_res, _ = collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_invalid_partition_name(self): @@ -1176,7 +1281,8 @@ class TestInsertInvalid(TestcaseBase): vec_field = ct.get_invalid_vectors df = [field_one, field_two, vec_field] error = {ct.err_code: 1, ct.err_msg: "Data type is not support."} - mutation_res, _ = collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + mutation_res, _ = collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_invalid_with_pk_varchar_auto_id_true(self): @@ -1187,9 +1293,11 @@ class TestInsertInvalid(TestcaseBase): """ string_field = cf.gen_string_field(is_primary=True, max_length=6) embedding_field = cf.gen_float_vec_field() - schema = cf.gen_collection_schema([string_field, embedding_field], auto_id=True) + schema = cf.gen_collection_schema( + [string_field, embedding_field], auto_id=True) collection_w = self.init_collection_wrap(schema=schema) - data = [[[random.random() for _ in range(ct.default_dim)] for _ in range(2)]] + data = [[[random.random() for _ in range(ct.default_dim)] + for _ in range(2)]] res = collection_w.insert(data=data)[0] assert res.insert_count == 2 @@ -1201,12 +1309,14 @@ class TestInsertInvalid(TestcaseBase): method: insert int8 out of range expected: raise exception """ - collection_w = self.init_collection_general(prefix, is_all_data_type=True)[0] + collection_w = self.init_collection_general( + prefix, is_all_data_type=True)[0] data = cf.gen_dataframe_all_data_type(nb=1) 
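

Reviewer note: the parametrized overflow cases in this cluster all follow the same pattern — build a one-row frame covering every supported dtype, push one integer column out of range, and expect the client to refuse the insert. A minimal standalone sketch of that check follows, assuming a Milvus server reachable at localhost:19530; the collection and field names are illustrative (not from this suite), and the exact exception type and message may vary by pymilvus version:

import random
import pandas as pd
from pymilvus import (connections, Collection, CollectionSchema,
                      FieldSchema, DataType, MilvusException)

connections.connect(host="localhost", port="19530")  # assumed endpoint

schema = CollectionSchema([
    FieldSchema("pk", DataType.INT64, is_primary=True),
    FieldSchema("int8_val", DataType.INT8),
    FieldSchema("vec", DataType.FLOAT_VECTOR, dim=8),
])
collection = Collection("int8_range_demo", schema)  # created if absent

df = pd.DataFrame({
    "pk": pd.Series([1], dtype="int64"),
    # 128 does not fit INT8, and the column dtype is int64 anyway, so the
    # client is expected to reject the insert before it reaches the server
    "int8_val": pd.Series([128], dtype="int64"),
    "vec": [[random.random() for _ in range(8)]],
})
try:
    collection.insert(df)
except MilvusException as e:
    print("insert rejected:", e)

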
data[ct.default_int8_field_name] = [invalid_int8] error = {ct.err_code: 1, 'err_msg': "The data type of field int8 doesn't match, " "expected: INT8, got INT64"} - collection_w.insert(data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("invalid_int16", [-32769, 32768]) @@ -1216,12 +1326,14 @@ class TestInsertInvalid(TestcaseBase): method: insert int16 out of range expected: raise exception """ - collection_w = self.init_collection_general(prefix, is_all_data_type=True)[0] + collection_w = self.init_collection_general( + prefix, is_all_data_type=True)[0] data = cf.gen_dataframe_all_data_type(nb=1) data[ct.default_int16_field_name] = [invalid_int16] error = {ct.err_code: 1, 'err_msg': "The data type of field int16 doesn't match, " "expected: INT16, got INT64"} - collection_w.insert(data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("invalid_int32", [-2147483649, 2147483648]) @@ -1231,12 +1343,14 @@ class TestInsertInvalid(TestcaseBase): method: insert int32 out of range expected: raise exception """ - collection_w = self.init_collection_general(prefix, is_all_data_type=True)[0] + collection_w = self.init_collection_general( + prefix, is_all_data_type=True)[0] data = cf.gen_dataframe_all_data_type(nb=1) data[ct.default_int32_field_name] = [invalid_int32] error = {ct.err_code: 1, 'err_msg': "The data type of field int16 doesn't match, " "expected: INT32, got INT64"} - collection_w.insert(data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("no error code provided now") @@ -1252,7 +1366,8 @@ class TestInsertInvalid(TestcaseBase): data = cf.gen_default_dataframe_data(nb) error = {ct.err_code: 1, ct.err_msg: "<_MultiThreadedRendezvous of RPC that terminated with:" "status = StatusCode.RESOURCE_EXHAUSTED"} - collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("not support default_value now") @@ -1268,7 +1383,8 @@ class TestInsertInvalid(TestcaseBase): schema = cf.gen_collection_schema(fields) collection_w = self.init_collection_wrap(schema=schema) vectors = cf.gen_vectors(ct.default_nb, ct.default_dim) - data = [{"int64": 1, "float_vector": vectors[1], "varchar": default_value, "float": np.float32(1.0)}] + data = [{"int64": 1, "float_vector": vectors[1], + "varchar": default_value, "float": np.float32(1.0)}] collection_w.insert(data, check_task=CheckTasks.err_res, check_items={ct.err_code: 1, ct.err_msg: "Field varchar don't match in entities[0]"}) @@ -1315,7 +1431,8 @@ class TestInsertInvalidBinary(TestcaseBase): dtype=DataType.BINARY_VECTOR) df = [field_one, field_two, vec_field] error = {ct.err_code: 1, ct.err_msg: "data should be a list of list"} - mutation_res, _ = collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + mutation_res, _ = collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_with_invalid_binary_partition_name(self): @@ -1328,7 +1445,8 @@ class TestInsertInvalidBinary(TestcaseBase): collection_w = 
self.init_collection_wrap(name=collection_name) partition_name = ct.get_invalid_strs df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) - error = {ct.err_code: 1, 'err_msg': "The types of schema and data do not match."} + error = {ct.err_code: 1, + 'err_msg': "The types of schema and data do not match."} mutation_res, _ = collection_w.insert(data=df, partition_name=partition_name, check_task=CheckTasks.err_res, check_items=error) @@ -1358,7 +1476,8 @@ class TestInsertString(TestcaseBase): @pytest.mark.tags(CaseLabel.L0) @pytest.mark.parametrize("string_fields", [[cf.gen_string_field(name="string_field1")], - [cf.gen_string_field(name="string_field2")], + [cf.gen_string_field( + name="string_field2")], [cf.gen_string_field(name="string_field3")]]) def test_insert_multi_string_fields(self, string_fields): """ @@ -1369,7 +1488,8 @@ class TestInsertString(TestcaseBase): """ schema = cf.gen_schema_multi_string_fields(string_fields) - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_dataframe_multi_string_fields(string_fields=string_fields) collection_w.insert(df) assert collection_w.num_entities == ct.default_nb @@ -1386,10 +1506,13 @@ class TestInsertString(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) nb = 10 df = cf.gen_default_dataframe_data(nb) - new_float_value = pd.Series(data=[float(i) for i in range(nb)], dtype="float64") + new_float_value = pd.Series( + data=[float(i) for i in range(nb)], dtype="float64") df[df.columns[2]] = new_float_value - error = {ct.err_code: 1, ct.err_msg: "The data type of field varchar doesn't match, expected: VARCHAR, got DOUBLE"} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The data type of field varchar doesn't match, expected: VARCHAR, got DOUBLE"} + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L0) def test_insert_string_field_name_invalid(self): @@ -1401,9 +1524,11 @@ class TestInsertString(TestcaseBase): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - df = [cf.gen_int64_field(), cf.gen_string_field(name=ct.get_invalid_strs), cf.gen_float_vec_field()] + df = [cf.gen_int64_field(), cf.gen_string_field( + name=ct.get_invalid_strs), cf.gen_float_vec_field()] error = {ct.err_code: 1, ct.err_msg: 'data should be a list of list'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L0) def test_insert_string_field_length_exceed(self): @@ -1422,7 +1547,8 @@ class TestInsertString(TestcaseBase): vec_field = cf.gen_float_vec_field() df = [field_one, field_two, field_three, vec_field] error = {ct.err_code: 1, ct.err_msg: 'data should be a list of list'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_string_field_dtype_invalid(self): @@ -1434,12 +1560,14 @@ class TestInsertString(TestcaseBase): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - string_field = self.field_schema_wrap.init_field_schema(name="string", dtype=DataType.STRING)[0] + string_field = 
self.field_schema_wrap.init_field_schema( + name="string", dtype=DataType.STRING)[0] int_field = cf.gen_int64_field(is_primary=True) vec_field = cf.gen_float_vec_field() df = [string_field, int_field, vec_field] error = {ct.err_code: 1, ct.err_msg: 'data should be a list of list'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_string_field_auto_id_is_true(self): @@ -1456,7 +1584,8 @@ class TestInsertString(TestcaseBase): string_field = cf.gen_string_field(is_primary=True, auto_id=True) df = [int_field, string_field, vec_field] error = {ct.err_code: 1, ct.err_msg: 'data should be a list of list'} - collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) + collection_w.insert( + data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) def test_insert_string_field_space(self): @@ -1470,7 +1599,7 @@ class TestInsertString(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) nb = 1000 data = cf.gen_default_list_data(nb) - data[2] = [" "for _ in range(nb)] + data[2] = [" "for _ in range(nb)] collection_w.insert(data) assert collection_w.num_entities == nb @@ -1486,7 +1615,7 @@ class TestInsertString(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) nb = 1000 data = cf.gen_default_list_data(nb) - data[2] = [""for _ in range(nb)] + data[2] = [""for _ in range(nb)] collection_w.insert(data) assert collection_w.num_entities == nb @@ -1503,7 +1632,7 @@ class TestInsertString(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name, schema=schema) nb = 1000 data = cf.gen_default_list_data(nb) - data[2] = [""for _ in range(nb)] + data[2] = [""for _ in range(nb)] collection_w.insert(data) assert collection_w.num_entities == nb @@ -1536,11 +1665,13 @@ class TestUpsertValid(TestcaseBase): """ upsert_nb = 1000 collection_w = self.init_collection_general(pre_upsert, True)[0] - upsert_data, float_values = cf.gen_default_data_for_upsert(upsert_nb, start=start) + upsert_data, float_values = cf.gen_default_data_for_upsert( + upsert_nb, start=start) collection_w.upsert(data=upsert_data) exp = f"int64 >= {start} && int64 <= {upsert_nb + start}" res = collection_w.query(exp, output_fields=[default_float_name])[0] - assert [res[i][default_float_name] for i in range(upsert_nb)] == float_values.to_list() + assert [res[i][default_float_name] + for i in range(upsert_nb)] == float_values.to_list() @pytest.mark.tags(CaseLabel.L2) def test_upsert_with_primary_key_string(self): @@ -1552,10 +1683,13 @@ class TestUpsertValid(TestcaseBase): expected: raise no exception """ c_name = cf.gen_unique_str(pre_upsert) - fields = [cf.gen_string_field(), cf.gen_float_vec_field(dim=ct.default_dim)] - schema = cf.gen_collection_schema(fields=fields, primary_field=ct.default_string_field_name) + fields = [cf.gen_string_field(), cf.gen_float_vec_field( + dim=ct.default_dim)] + schema = cf.gen_collection_schema( + fields=fields, primary_field=ct.default_string_field_name) collection_w = self.init_collection_wrap(name=c_name, schema=schema) - vectors = [[random.random() for _ in range(ct.default_dim)] for _ in range(2)] + vectors = [[random.random() for _ in range(ct.default_dim)] + for _ in range(2)] collection_w.insert([["a", "b"], vectors]) collection_w.upsert([[" a", "b "], vectors]) assert collection_w.num_entities == 4 @@ -1571,12 +1705,14 @@ class 
TestUpsertValid(TestcaseBase): """ nb = 500 c_name = cf.gen_unique_str(pre_upsert) - collection_w = self.init_collection_general(c_name, True, is_binary=True)[0] + collection_w = self.init_collection_general( + c_name, True, is_binary=True)[0] binary_vectors = cf.gen_binary_vectors(nb, ct.default_dim)[1] data = [[i for i in range(nb)], [np.float32(i) for i in range(nb)], [str(i) for i in range(nb)], binary_vectors] collection_w.upsert(data) - res = collection_w.query("int64 >= 0", [ct.default_binary_vec_field_name])[0] + res = collection_w.query( + "int64 >= 0", [ct.default_binary_vec_field_name])[0] assert binary_vectors[0] == res[0][ct. default_binary_vec_field_name][0] @pytest.mark.tags(CaseLabel.L1) @@ -1606,7 +1742,8 @@ class TestUpsertValid(TestcaseBase): 3. upsert data=None expected: raise no exception """ - collection_w = self.init_collection_general(pre_upsert, insert_data=True, is_index=False)[0] + collection_w = self.init_collection_general( + pre_upsert, insert_data=True, is_index=False)[0] assert collection_w.num_entities == ct.default_nb collection_w.upsert(data=None) assert collection_w.num_entities == ct.default_nb @@ -1625,7 +1762,8 @@ class TestUpsertValid(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) collection_w.create_partition("partition_new") cf.insert_data(collection_w) - collection_w.create_index(ct.default_float_vec_field_name, default_index_params) + collection_w.create_index( + ct.default_float_vec_field_name, default_index_params) collection_w.load() # check the ids which will be upserted is in partition _default @@ -1634,8 +1772,10 @@ class TestUpsertValid(TestcaseBase): res0 = collection_w.query(expr, [default_float_name], ["_default"])[0] assert len(res0) == upsert_nb collection_w.flush() - res1 = collection_w.query(expr, [default_float_name], ["partition_new"])[0] - assert collection_w.partition('partition_new')[0].num_entities == ct.default_nb // 2 + res1 = collection_w.query( + expr, [default_float_name], ["partition_new"])[0] + assert collection_w.partition('partition_new')[ + 0].num_entities == ct.default_nb // 2 # upsert ids in partition _default data, float_values = cf.gen_default_data_for_upsert(upsert_nb) @@ -1644,10 +1784,13 @@ class TestUpsertValid(TestcaseBase): # check the result in partition _default(upsert successfully) and others(no missing, nothing new) collection_w.flush() res0 = collection_w.query(expr, [default_float_name], ["_default"])[0] - res2 = collection_w.query(expr, [default_float_name], ["partition_new"])[0] + res2 = collection_w.query( + expr, [default_float_name], ["partition_new"])[0] assert res1 == res2 - assert [res0[i][default_float_name] for i in range(upsert_nb)] == float_values.to_list() - assert collection_w.partition('partition_new')[0].num_entities == ct.default_nb // 2 + assert [res0[i][default_float_name] + for i in range(upsert_nb)] == float_values.to_list() + assert collection_w.partition('partition_new')[ + 0].num_entities == ct.default_nb // 2 @pytest.mark.tags(CaseLabel.L2) # @pytest.mark.skip(reason="issue #22592") @@ -1667,13 +1810,15 @@ class TestUpsertValid(TestcaseBase): # insert data and load collection cf.insert_data(collection_w) - collection_w.create_index(ct.default_float_vec_field_name, default_index_params) + collection_w.create_index( + ct.default_float_vec_field_name, default_index_params) collection_w.load() # check the ids which will be upserted is not in partition 'partition_1' upsert_nb = 100 expr = f"int64 >= 0 && int64 <= {upsert_nb}" - res = collection_w.query(expr, 
[default_float_name], ["partition_1"])[0] + res = collection_w.query( + expr, [default_float_name], ["partition_1"])[0] assert len(res) == 0 # upsert in partition 'partition_1' @@ -1681,8 +1826,10 @@ class TestUpsertValid(TestcaseBase): collection_w.upsert(data, "partition_1") # check the upserted data in 'partition_1' - res1 = collection_w.query(expr, [default_float_name], ["partition_1"])[0] - assert [res1[i][default_float_name] for i in range(upsert_nb)] == float_values.to_list() + res1 = collection_w.query( + expr, [default_float_name], ["partition_1"])[0] + assert [res1[i][default_float_name] + for i in range(upsert_nb)] == float_values.to_list() @pytest.mark.tags(CaseLabel.L1) def test_upsert_same_pk_concurrently(self): @@ -1696,7 +1843,8 @@ class TestUpsertValid(TestcaseBase): # initialize a collection upsert_nb = 1000 collection_w = self.init_collection_general(pre_upsert, True)[0] - data1, float_values1 = cf.gen_default_data_for_upsert(upsert_nb, size=1000) + data1, float_values1 = cf.gen_default_data_for_upsert( + upsert_nb, size=1000) data2, float_values2 = cf.gen_default_data_for_upsert(upsert_nb) # upsert at the same time @@ -1716,7 +1864,8 @@ class TestUpsertValid(TestcaseBase): # check the result exp = f"int64 >= 0 && int64 <= {upsert_nb}" - res = collection_w.query(exp, [default_float_name], consistency_level="Strong")[0] + res = collection_w.query( + exp, [default_float_name], consistency_level="Strong")[0] res = [res[i][default_float_name] for i in range(upsert_nb)] if not (res == float_values1.to_list() or res == float_values2.to_list()): assert False @@ -1761,7 +1910,8 @@ class TestUpsertValid(TestcaseBase): data = cf.gen_default_list_data(upsert_nb, start=i * step) collection_w.upsert(data) # load - collection_w.create_index(ct.default_float_vec_field_name, default_index_params) + collection_w.create_index( + ct.default_float_vec_field_name, default_index_params) collection_w.load() # check the result res = collection_w.query(expr="", output_fields=["count(*)"])[0] @@ -1884,7 +2034,8 @@ class TestUpsertInvalid(TestcaseBase): collection_w = self.init_collection_wrap(name=c_name) error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, expected: " "['int64', 'float', 'varchar', 'float_vector']"} - collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_upsert_pk_type_invalid(self): @@ -1899,7 +2050,8 @@ class TestUpsertInvalid(TestcaseBase): cf.gen_vectors(2, ct.default_dim)] error = {ct.err_code: 1, ct.err_msg: "The data type of field int64 doesn't match, " "expected: INT64, got VARCHAR"} - collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_upsert_data_unmatch(self): @@ -1915,7 +2067,8 @@ class TestUpsertInvalid(TestcaseBase): data = [1, "a", 2.0, vector] error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['int64', 'float', 'varchar', 'float_vector']"} - collection_w.upsert(data=[data], check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert( + data=[data], check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("vector", [[], [1.0, 2.0], "a", 1.0, None]) @@ -1931,7 +2084,8 @@ class TestUpsertInvalid(TestcaseBase): data = [2.0, "a", 
vector] error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['int64', 'float', 'varchar', 'float_vector']"} - collection_w.upsert(data=[data], check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert( + data=[data], check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("dim", [120, 129, 200]) @@ -1942,10 +2096,13 @@ class TestUpsertInvalid(TestcaseBase): 2. upsert with mismatched dim expected: raise exception """ - collection_w = self.init_collection_general(pre_upsert, True, is_binary=True)[0] + collection_w = self.init_collection_general( + pre_upsert, True, is_binary=True)[0] data = cf.gen_default_binary_dataframe_data(dim=dim)[0] - error = {ct.err_code: 1, ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}"} - collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}"} + collection_w.upsert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("dim", [127, 129, 200]) @@ -1958,8 +2115,10 @@ class TestUpsertInvalid(TestcaseBase): """ collection_w = self.init_collection_general(pre_upsert, True)[0] data = cf.gen_default_data_for_upsert(dim=dim)[0] - error = {ct.err_code: 1, ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}"} - collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}"} + collection_w.upsert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("partition_name", ct.get_invalid_strs[7:13]) @@ -2024,12 +2183,15 @@ class TestUpsertInvalid(TestcaseBase): 2. 
upsert data no pk expected: raise exception """ - collection_w = self.init_collection_general(pre_upsert, auto_id=True, is_index=False)[0] - error = {ct.err_code: 1, ct.err_msg: "Upsert don't support autoid == true"} + collection_w = self.init_collection_general( + pre_upsert, auto_id=True, is_index=False)[0] + error = {ct.err_code: 1, + ct.err_msg: "Upsert don't support autoid == true"} float_vec_values = cf.gen_vectors(ct.default_nb, ct.default_dim) data = [[np.float32(i) for i in range(ct.default_nb)], [str(i) for i in range(ct.default_nb)], float_vec_values] - collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert( + data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("default_value", [[], None]) @@ -2044,7 +2206,8 @@ class TestUpsertInvalid(TestcaseBase): schema = cf.gen_collection_schema(fields) collection_w = self.init_collection_wrap(schema=schema) vectors = cf.gen_vectors(ct.default_nb, ct.default_dim) - data = [{"int64": 1, "float_vector": vectors[1], "varchar": default_value, "float": np.float32(1.0)}] + data = [{"int64": 1, "float_vector": vectors[1], + "varchar": default_value, "float": np.float32(1.0)}] collection_w.upsert(data, check_task=CheckTasks.err_res, check_items={ct.err_code: 1, ct.err_msg: "Field varchar don't match in entities[0]"}) diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index 98f57a929a..8841c52632 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -1,3 +1,14 @@ +import utils.util_pymilvus as ut +from utils.util_log import test_log as log +from common.common_type import CaseLabel, CheckTasks +from common import common_type as ct +from common import common_func as cf +from common.code_mapping import CollectionErrorMessage as clem +from common.code_mapping import ConnectionErrorMessage as cem +from base.client_base import TestcaseBase +from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY +import threading +from pymilvus import DefaultConfig from datetime import datetime import time @@ -6,18 +17,7 @@ import random import numpy as np import pandas as pd pd.set_option("expand_frame_repr", False) -from pymilvus import DefaultConfig -import threading -from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY -from base.client_base import TestcaseBase -from common.code_mapping import ConnectionErrorMessage as cem -from common.code_mapping import CollectionErrorMessage as clem -from common import common_func as cf -from common import common_type as ct -from common.common_type import CaseLabel, CheckTasks -from utils.util_log import test_log as log -import utils.util_pymilvus as ut prefix = "query" exp_res = "exp_res" @@ -27,8 +27,10 @@ default_mix_expr = "int64 >= 0 && varchar >= \"0\"" default_expr = f'{ct.default_int64_field_name} >= 0' default_invalid_expr = "varchar >= 0" default_string_term_expr = f'{ct.default_string_field_name} in [\"0\", \"1\"]' -default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}} -binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}} +default_index_params = {"index_type": "IVF_SQ8", + "metric_type": "L2", "params": {"nlist": 64}} +binary_index_params = {"index_type": "BIN_IVF_FLAT", + "metric_type": "JACCARD", "params": {"nlist": 64}} 
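

Reviewer note: before the query cases start, the upsert-then-verify loop that the TestUpsertValid cases above keep repeating can be condensed into one runnable flow. This is a hedged sketch only, assuming a client/server pair new enough to expose Collection.upsert and a server at localhost:19530; the collection name and dimension are illustrative, and the index parameters mirror the default_index_params constant defined just above:

import random
from pymilvus import (connections, Collection, CollectionSchema,
                      FieldSchema, DataType)

connections.connect(host="localhost", port="19530")  # assumed endpoint

schema = CollectionSchema([
    FieldSchema("int64", DataType.INT64, is_primary=True),
    FieldSchema("float", DataType.FLOAT),
    FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128),
])
collection = Collection("upsert_query_demo", schema)

vectors = [[random.random() for _ in range(128)] for _ in range(10)]
collection.insert([list(range(10)), [float(i) for i in range(10)], vectors])
# Upsert overwrites rows sharing a primary key instead of duplicating them,
# which is why the suite checks row counts and float values after each call.
collection.upsert([list(range(10)),
                   [float(i) + 0.5 for i in range(10)], vectors])

collection.create_index("float_vector",
                        {"index_type": "IVF_SQ8", "metric_type": "L2",
                         "params": {"nlist": 64}})
collection.load()
# Strong consistency so the read reflects the upsert without an explicit flush.
res = collection.query("int64 in [0, 1, 2]", output_fields=["float"],
                       consistency_level="Strong")
print(res)  # floats should carry the upserted values: 0.5, 1.5, 2.5

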
default_entities = ut.gen_entities(ut.default_nb, is_normal=True) default_pos = 5 @@ -59,10 +61,12 @@ class TestQueryParams(TestcaseBase): method: query with invalid term expr expected: raise exception """ - collection_w, entities = self.init_collection_general(prefix, insert_data=True, nb=10)[0:2] + collection_w, entities = self.init_collection_general( + prefix, insert_data=True, nb=10)[0:2] term_expr = f'{default_int_field_name} in {entities[:default_pos]}' error = {ct.err_code: 1, ct.err_msg: "unexpected token Identifier"} - collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L0) def test_query(self, enable_dynamic_field): @@ -80,13 +84,16 @@ class TestQueryParams(TestcaseBase): for vector in vectors[0]: vector = vector[ct.default_int64_field_name] int_values.append(vector) - res = [{ct.default_int64_field_name: int_values[i]} for i in range(pos)] + res = [{ct.default_int64_field_name: int_values[i]} + for i in range(pos)] else: - int_values = vectors[0][ct.default_int64_field_name].values.tolist() + int_values = vectors[0][ct.default_int64_field_name].values.tolist( + ) res = vectors[0].iloc[0:pos, :1].to_dict('records') term_expr = f'{ct.default_int64_field_name} in {int_values[:pos]}' - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + collection_w.query( + term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L1) def test_query_no_collection(self): @@ -100,15 +107,15 @@ class TestQueryParams(TestcaseBase): # 1. initialize without data collection_w = self.init_collection_general(prefix)[0] # 2. Drop collection - log.info("test_query_no_collection: drop collection %s" % collection_w.name) + log.info("test_query_no_collection: drop collection %s" % + collection_w.name) collection_w.drop() # 3. 
Search without collection log.info("test_query_no_collection: query without collection ") collection_w.query(default_term_expr, check_task=CheckTasks.err_res, check_items={"err_code": 1, - "err_msg": "DescribeCollection failed: " - "can't find collection: %s" % collection_w.name}) + "err_msg": "collection not found"}) @pytest.mark.tags(CaseLabel.L2) def test_query_empty_collection(self): @@ -119,7 +126,8 @@ class TestQueryParams(TestcaseBase): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() res, _ = collection_w.query(default_term_expr) assert len(res) == 0 @@ -141,14 +149,16 @@ class TestQueryParams(TestcaseBase): ids = insert_res[1].primary_keys pos = 5 res = df.iloc[:pos, :1].to_dict('records') - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # query with all primary keys term_expr_1 = f'{ct.default_int64_field_name} in {ids[:pos]}' for i in range(5): res[i][ct.default_int64_field_name] = ids[i] - self.collection_wrap.query(term_expr_1, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query( + term_expr_1, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) # query with part primary keys term_expr_2 = f'{ct.default_int64_field_name} in {[ids[0], 0]}' @@ -166,7 +176,8 @@ class TestQueryParams(TestcaseBase): expected: query results are de-duplicated """ nb = ct.default_nb - collection_w, insert_data, _, _ = self.init_collection_general(prefix, True, nb, dim=dim)[0:4] + collection_w, insert_data, _, _ = self.init_collection_general( + prefix, True, nb, dim=dim)[0:4] # insert dup data multi times for i in range(dup_times): collection_w.insert(insert_data[0]) @@ -185,12 +196,14 @@ class TestQueryParams(TestcaseBase): expected: query result is empty """ schema = cf.gen_default_collection_schema(auto_id=True) - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_default_dataframe_data(ct.default_nb) df.drop(ct.default_int64_field_name, axis=1, inplace=True) mutation_res, _ = collection_w.insert(data=df) assert collection_w.num_entities == ct.default_nb - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() term_expr = f'{ct.default_int64_field_name} in [0, 1, 2]' res, _ = collection_w.query(term_expr) @@ -203,9 +216,11 @@ class TestQueryParams(TestcaseBase): method: query with expr None expected: raise exception """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"} - collection_w.query(None, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + None, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_non_string_expr(self): @@ 
-214,11 +229,13 @@ class TestQueryParams(TestcaseBase): method: query with non-string expr, eg 1, [] .. expected: raise exception """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] exprs = [1, 2., [], {}, ()] error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"} for expr in exprs: - collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_invalid_string(self): @@ -227,11 +244,13 @@ class TestQueryParams(TestcaseBase): method: query with invalid string expr expected: raise exception """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] error = {ct.err_code: 1, ct.err_msg: "Invalid expression!"} exprs = ["12-s", "中文", "a", " "] for expr in exprs: - collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.skip(reason="repeat with test_query, waiting for other expr") @@ -241,9 +260,11 @@ class TestQueryParams(TestcaseBase): method: query with TermExpr expected: query result is correct """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] res = vectors[0].iloc[:2, :1].to_dict('records') - collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + collection_w.query( + default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_not_existed_field(self): @@ -255,7 +276,8 @@ class TestQueryParams(TestcaseBase): collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) term_expr = 'field in [1, 2]' error = {ct.err_code: 1, ct.err_msg: "fieldName(field) not found"} - collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_non_primary_fields(self): @@ -273,12 +295,14 @@ class TestQueryParams(TestcaseBase): ct.default_float_field_name: pd.Series(data=[np.float32(i) for i in range(ct.default_nb)], dtype="float32"), ct.default_double_field_name: pd.Series(data=[np.double(i) for i in range(ct.default_nb)], dtype="double"), ct.default_string_field_name: pd.Series(data=[str(i) for i in range(ct.default_nb)], dtype="string"), - ct.default_float_vec_field_name: cf.gen_vectors(ct.default_nb, ct.default_dim) + ct.default_float_vec_field_name: cf.gen_vectors( + ct.default_nb, ct.default_dim) }) self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # query by non_primary non_vector scalar field @@ -309,29 +333,36 @@ class TestQueryParams(TestcaseBase): """ self._connect() df = 
cf.gen_default_dataframe_data() - bool_values = pd.Series(data=[True if i % 2 == 0 else False for i in range(ct.default_nb)], dtype="bool") + bool_values = pd.Series( + data=[True if i % 2 == 0 else False for i in range(ct.default_nb)], dtype="bool") df.insert(2, ct.default_bool_field_name, bool_values) self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # output bool field - res, _ = self.collection_wrap.query(default_term_expr, output_fields=[ct.default_bool_field_name]) - assert set(res[0].keys()) == {ct.default_int64_field_name, ct.default_bool_field_name} + res, _ = self.collection_wrap.query(default_term_expr, output_fields=[ + ct.default_bool_field_name]) + assert set(res[0].keys()) == { + ct.default_int64_field_name, ct.default_bool_field_name} # not support filter bool field with expr 'bool in [0/ 1]' not_support_expr = f'{ct.default_bool_field_name} in [0]' - error = {ct.err_code: 1, ct.err_msg: 'error: value \"0\" in list cannot be casted to Bool'} + error = {ct.err_code: 1, + ct.err_msg: 'error: value \"0\" in list cannot be casted to Bool'} self.collection_wrap.query(not_support_expr, output_fields=[ct.default_bool_field_name], check_task=CheckTasks.err_res, check_items=error) # filter bool field by bool term expr for bool_value in [True, False]: - exprs = [f'{ct.default_bool_field_name} in [{bool_value}]', f'{ct.default_bool_field_name} == {bool_value}'] + exprs = [f'{ct.default_bool_field_name} in [{bool_value}]', + f'{ct.default_bool_field_name} == {bool_value}'] for expr in exprs: - res, _ = self.collection_wrap.query(expr, output_fields=[ct.default_bool_field_name]) + res, _ = self.collection_wrap.query( + expr, output_fields=[ct.default_bool_field_name]) assert len(res) == ct.default_nb / 2 for _r in res: assert _r[ct.default_bool_field_name] == bool_value @@ -347,7 +378,8 @@ class TestQueryParams(TestcaseBase): self._connect() # construct collection from dataFrame according to [int64, float, int8, float_vec] df = cf.gen_default_dataframe_data() - int8_values = pd.Series(data=[np.int8(i) for i in range(ct.default_nb)], dtype="int8") + int8_values = pd.Series(data=[np.int8(i) + for i in range(ct.default_nb)], dtype="int8") df.insert(2, ct.default_int8_field_name, int8_values) self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) @@ -359,7 +391,8 @@ class TestQueryParams(TestcaseBase): # int8 range [-128, 127] so when nb=1200, there are many repeated int8 values equal to 0 for i in range(0, ct.default_nb, 256): res.extend(df.iloc[i:i + 1, :-2].to_dict('records')) - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() self.collection_wrap.query(term_expr, output_fields=["float", "int64", "int8", "varchar"], check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -380,8 +413,7 @@ class TestQueryParams(TestcaseBase): # 1. 
initialize with data nb = 1000 collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -410,14 +442,19 @@ class TestQueryParams(TestcaseBase): method: query with wrong keyword term expr expected: raise exception """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] expr_1 = f'{ct.default_int64_field_name} inn [1, 2]' - error_1 = {ct.err_code: 1, ct.err_msg: f'unexpected token Identifier("inn")'} - collection_w.query(expr_1, check_task=CheckTasks.err_res, check_items=error_1) + error_1 = {ct.err_code: 1, + ct.err_msg: f'unexpected token Identifier("inn")'} + collection_w.query( + expr_1, check_task=CheckTasks.err_res, check_items=error_1) expr_3 = f'{ct.default_int64_field_name} in not [1, 2]' - error_3 = {ct.err_code: 1, ct.err_msg: 'right operand of the InExpr must be array'} - collection_w.query(expr_3, check_task=CheckTasks.err_res, check_items=error_3) + error_3 = {ct.err_code: 1, + ct.err_msg: 'right operand of the InExpr must be array'} + collection_w.query( + expr_3, check_task=CheckTasks.err_res, check_items=error_3) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("field", [ct.default_int64_field_name, ct.default_float_field_name]) @@ -432,7 +469,8 @@ class TestQueryParams(TestcaseBase): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() values = df[field].tolist() pos = 100 @@ -454,12 +492,14 @@ class TestQueryParams(TestcaseBase): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() int64_values = df[ct.default_int64_field_name].tolist() term_expr = f'{ct.default_int64_field_name} not in {int64_values[pos:]}' res = df.iloc[:pos, :1].to_dict('records') - self.collection_wrap.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query( + term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L1) def test_query_expr_random_values(self): @@ -474,14 +514,16 @@ class TestQueryParams(TestcaseBase): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == 100 - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # random_values = [random.randint(0, ct.default_nb) for _ in range(4)] random_values = [0, 2, 4, 3] term_expr = 
f'{ct.default_int64_field_name} in {random_values}' res = df.iloc[random_values, :1].to_dict('records') - self.collection_wrap.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query( + term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_not_in_random(self): @@ -496,7 +538,8 @@ class TestQueryParams(TestcaseBase): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == 50 - self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() random_values = [i for i in range(10, 50)] @@ -504,7 +547,8 @@ class TestQueryParams(TestcaseBase): random.shuffle(random_values) term_expr = f'{ct.default_int64_field_name} not in {random_values}' res = df.iloc[:10, :1].to_dict('records') - self.collection_wrap.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query( + term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_non_array_term(self): @@ -516,10 +560,13 @@ class TestQueryParams(TestcaseBase): exprs = [f'{ct.default_int64_field_name} in 1', f'{ct.default_int64_field_name} in "in"', f'{ct.default_int64_field_name} in (mn)'] - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] - error = {ct.err_code: 1, ct.err_msg: "right operand of the InExpr must be array"} + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] + error = {ct.err_code: 1, + ct.err_msg: "right operand of the InExpr must be array"} for expr in exprs: - collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_empty_term_array(self): @@ -529,7 +576,8 @@ class TestQueryParams(TestcaseBase): expected: empty result """ term_expr = f'{ct.default_int64_field_name} in []' - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] res, _ = collection_w.query(term_expr) assert len(res) == 0 @@ -546,7 +594,8 @@ class TestQueryParams(TestcaseBase): error = {ct.err_code: 1, ct.err_msg: "type mismatch"} for values in int_values: term_expr = f'{ct.default_int64_field_name} in {values}' - collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query( + term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_non_constant_array_term(self): @@ -555,12 +604,14 @@ class TestQueryParams(TestcaseBase): method: query with non-constant array expr expected: raise exception """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general( + prefix, insert_data=True)[0:2] constants = [[1], (), {}] error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"} for constant in constants: term_expr = f'{ct.default_int64_field_name} in [{constant}]' - collection_w.query(term_expr, check_task=CheckTasks.err_res, 
                                check_items=error)
+        collection_w.query(
+            term_expr, check_task=CheckTasks.err_res, check_items=error)

     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"])
@@ -571,13 +622,15 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
         limit = 99
         for i in range(ct.default_nb):
-            array[i][json_field] = {"number": i, "list": [m for m in range(i, i + limit)]}
+            array[i][json_field] = {"number": i,
+                                    "list": [m for m in range(i, i + limit)]}

         collection_w.insert(array)

@@ -596,7 +649,8 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=True)[0]

         # 2. insert data
         limit = ct.default_nb // 4
@@ -605,7 +659,8 @@ class TestQueryParams(TestcaseBase):
             data = {
                 ct.default_int64_field_name: i,
                 ct.default_json_field_name: [str(m) for m in range(i, i + limit)],
-                ct.default_float_vec_field_name: cf.gen_vectors(1, ct.default_dim)[0]
+                ct.default_float_vec_field_name: cf.gen_vectors(1, ct.default_dim)[
+                    0]
             }
             array.append(data)
         collection_w.insert(array)

@@ -625,13 +680,15 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
         limit = ct.default_nb // 3
         for i in range(ct.default_nb):
-            array[i][ct.default_json_field_name] = {"number": i, "list": [m for m in range(i, i + limit)]}
+            array[i][ct.default_json_field_name] = {
+                "number": i, "list": [m for m in range(i, i + limit)]}

         collection_w.insert(array)

@@ -651,7 +708,8 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
@@ -659,11 +717,15 @@ class TestQueryParams(TestcaseBase):
         for i in range(ct.default_nb):
             content = {
                 "listInt": [m for m in range(i, i + limit)],  # test for int
-                "listStr": [str(m) for m in range(i, i + limit)],  # test for string
-                "listFlt": [m * 1.0 for m in range(i, i + limit)],  # test for float
+                # test for string
+                "listStr": [str(m) for m in range(i, i + limit)],
+                # test for float
+                "listFlt": [m * 1.0 for m in range(i, i + limit)],
                 "listBool": [bool(i % 2)],  # test for bool
-                "listList": [[i, str(i + 1)], [i * 1.0, i + 1]],  # test for list
-                "listMix": [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]]  # test for mixed data
+                # test for list
+                "listList": [[i, str(i + 1)], [i * 1.0, i + 1]],
+                # test for mixed data
+                "listMix": [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]]
             }
             array[i][ct.default_json_field_name] = content

@@ -717,18 +779,24 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=True)[0]

         # 2. insert data
         array = cf.gen_default_rows_data(with_json=False)
         limit = 10
         for i in range(ct.default_nb):
-            array[i]["listInt"] = [m for m in range(i, i + limit)]  # test for int
-            array[i]["listStr"] = [str(m) for m in range(i, i + limit)]  # test for string
-            array[i]["listFlt"] = [m * 1.0 for m in range(i, i + limit)]  # test for float
+            array[i]["listInt"] = [m for m in range(
+                i, i + limit)]  # test for int
+            array[i]["listStr"] = [str(m) for m in range(
+                i, i + limit)]  # test for string
+            array[i]["listFlt"] = [
+                m * 1.0 for m in range(i, i + limit)]  # test for float
             array[i]["listBool"] = [bool(i % 2)]  # test for bool
-            array[i]["listList"] = [[i, str(i + 1)], [i * 1.0, i + 1]]  # test for list
-            array[i]["listMix"] = [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]]  # test for mixed data
+            array[i]["listList"] = [
+                [i, str(i + 1)], [i * 1.0, i + 1]]  # test for list
+            array[i]["listMix"] = [i, i * 1.1,
+                                   str(i), bool(i % 2), [i, str(i)]]  # test for mixed data

         collection_w.insert(array)

@@ -781,7 +849,8 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
@@ -789,11 +858,15 @@ class TestQueryParams(TestcaseBase):
         for i in range(ct.default_nb):
             content = {
                 "listInt": [m for m in range(i, i + limit)],  # test for int
-                "listStr": [str(m) for m in range(i, i + limit)],  # test for string
-                "listFlt": [m * 1.0 for m in range(i, i + limit)],  # test for float
+                # test for string
+                "listStr": [str(m) for m in range(i, i + limit)],
+                # test for float
+                "listFlt": [m * 1.0 for m in range(i, i + limit)],
                 "listBool": [bool(i % 2)],  # test for bool
-                "listList": [[i, str(i + 1)], [i * 1.0, i + 1]],  # test for list
-                "listMix": [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]]  # test for mixed data
+                # test for list
+                "listList": [[i, str(i + 1)], [i * 1.0, i + 1]],
+                # test for mixed data
+                "listMix": [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]]
             }
             array[i][ct.default_json_field_name] = content

@@ -848,18 +921,24 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=True)[0]

         # 2. insert data
         array = cf.gen_default_rows_data(with_json=False)
         limit = 10
         for i in range(ct.default_nb):
-            array[i]["listInt"] = [m for m in range(i, i + limit)]  # test for int
-            array[i]["listStr"] = [str(m) for m in range(i, i + limit)]  # test for string
-            array[i]["listFlt"] = [m * 1.0 for m in range(i, i + limit)]  # test for float
+            array[i]["listInt"] = [m for m in range(
+                i, i + limit)]  # test for int
+            array[i]["listStr"] = [str(m) for m in range(
+                i, i + limit)]  # test for string
+            array[i]["listFlt"] = [
+                m * 1.0 for m in range(i, i + limit)]  # test for float
             array[i]["listBool"] = [bool(i % 2)]  # test for bool
-            array[i]["listList"] = [[i, str(i + 1)], [i * 1.0, i + 1]]  # test for list
-            array[i]["listMix"] = [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]]  # test for mixed data
+            array[i]["listList"] = [
+                [i, str(i + 1)], [i * 1.0, i + 1]]  # test for list
+            array[i]["listMix"] = [i, i * 1.1,
+                                   str(i), bool(i % 2), [i, str(i)]]  # test for mixed data

         collection_w.insert(array)

@@ -912,12 +991,14 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
         for i in range(ct.default_nb):
-            array[i][json_field] = {"list": [[i, i + 1], [i, i + 2], [i, i + 3]]}
+            array[i][json_field] = {
+                "list": [[i, i + 1], [i, i + 2], [i, i + 3]]}

         collection_w.insert(array)

@@ -938,7 +1019,7 @@ class TestQueryParams(TestcaseBase):
         if request.param == [1, "2", 3]:
             pytest.skip('[1, "2", 3] is valid type for list')
         yield request.param
-
+
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("expr_prefix", ["json_contains_any", "JSON_CONTAINS_ANY",
                                              "json_contains_all", "JSON_CONTAINS_ALL"])
@@ -949,12 +1030,14 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
         for i in range(ct.default_nb):
-            array[i][json_field] = {"number": i, "list": [m for m in range(i, i + 10)]}
+            array[i][json_field] = {"number": i,
+                                    "list": [m for m in range(i, i + 10)]}

         collection_w.insert(array)

@@ -963,7 +1046,8 @@ class TestQueryParams(TestcaseBase):
         expression = f"{expr_prefix}({json_field}['list'], {get_not_list})"
         error = {ct.err_code: 1, ct.err_msg: f"cannot parse expression {expression}, error: "
                                              f"error: {expr_prefix} operation element must be an array"}
-        collection_w.query(expression, check_task=CheckTasks.err_res, check_items=error)
+        collection_w.query(
+            expression, check_task=CheckTasks.err_res, check_items=error)

     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"])
@@ -974,13 +1058,15 @@ class TestQueryParams(TestcaseBase):
         expected: succeed
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]

         # 2. insert data
         array = cf.gen_default_rows_data()
         limit = ct.default_nb // 3
         for i in range(ct.default_nb):
-            array[i][json_field] = {"number": i, "list": [m for m in range(i, i + limit)]}
+            array[i][json_field] = {"number": i,
+                                    "list": [m for m in range(i, i + limit)]}

         collection_w.insert(array)

@@ -1002,11 +1088,14 @@ class TestQueryParams(TestcaseBase):
         collection_w = self.init_collection_general(prefix, True)[0]

         # 2. query with no limit and no offset
-        error = {ct.err_code: 1, ct.err_msg: "empty expression should be used with limit"}
-        collection_w.query("", check_task=CheckTasks.err_res, check_items=error)
+        error = {ct.err_code: 1,
+                 ct.err_msg: "empty expression should be used with limit"}
+        collection_w.query(
+            "", check_task=CheckTasks.err_res, check_items=error)

         # 3. query with offset but no limit
-        collection_w.query("", offset=1, check_task=CheckTasks.err_res, check_items=error)
+        collection_w.query(
+            "", offset=1, check_task=CheckTasks.err_res, check_items=error)

     @pytest.mark.tags(CaseLabel.L2)
     def test_query_empty(self):
@@ -1034,13 +1123,15 @@ class TestQueryParams(TestcaseBase):
         expected: return topK results by order
         """
         # 1. initialize with data
-        collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4]
+        collection_w, _, _, insert_ids = self.init_collection_general(
+            prefix, True, auto_id=auto_id)[0:4]
         exp_ids, res = insert_ids[:limit], []
         for ids in exp_ids:
             res.append({ct.default_int64_field_name: ids})

         # 2. query with limit
-        collection_w.query("", limit=limit, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
+        collection_w.query(
+            "", limit=limit, check_task=CheckTasks.check_query_results, check_items={exp_res: res})

     @pytest.mark.tags(CaseLabel.L2)
     def test_query_expr_empty_pk_string(self):
@@ -1051,18 +1142,22 @@ class TestQueryParams(TestcaseBase):
         """
         # 1. initialize with data
         collection_w, _, _, insert_ids = \
-            self.init_collection_general(prefix, True, primary_field=ct.default_string_field_name)[0:4]
+            self.init_collection_general(
+                prefix, True, primary_field=ct.default_string_field_name)[0:4]

         # string field is sorted by lexicographical order
-        exp_ids, res = ['0', '1', '10', '100', '1000', '1001', '1002', '1003', '1004', '1005'], []
+        exp_ids, res = ['0', '1', '10', '100', '1000',
+                        '1001', '1002', '1003', '1004', '1005'], []
         for ids in exp_ids:
             res.append({ct.default_string_field_name: ids})

         # 2. query with limit
-        collection_w.query("", limit=ct.default_limit, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
+        collection_w.query("", limit=ct.default_limit,
+                           check_task=CheckTasks.check_query_results, check_items={exp_res: res})

         # 2. query with limit + offset
         res = res[5:]
-        collection_w.query("", limit=5, offset=5, check_task=CheckTasks.check_query_results, check_items={exp_res: res})
+        collection_w.query(
+            "", limit=5, offset=5, check_task=CheckTasks.check_query_results, check_items={exp_res: res})

     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("offset", [100, 1000])
@@ -1075,7 +1170,8 @@ class TestQueryParams(TestcaseBase):
         expected: return topK results by order
         """
         # 1. initialize with data
-        collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4]
+        collection_w, _, _, insert_ids = self.init_collection_general(
+            prefix, True, auto_id=auto_id)[0:4]
         exp_ids, res = insert_ids[:limit + offset][offset:], []
         for ids in exp_ids:
             res.append({ct.default_int64_field_name: ids})

@@ -1102,13 +1198,15 @@ class TestQueryParams(TestcaseBase):
         float_value = [np.float32(i) for i in unordered_ids]
         string_value = [str(i) for i in unordered_ids]
         vector_value = cf.gen_vectors(nb=ct.default_nb, dim=ct.default_dim)
-        collection_w.insert([unordered_ids, float_value, string_value, vector_value])
+        collection_w.insert([unordered_ids, float_value,
+                             string_value, vector_value])
         collection_w.load()

         # 3. query with empty expr and check the result
         exp_ids, res = sorted(unordered_ids)[:limit], []
         for ids in exp_ids:
-            res.append({ct.default_int64_field_name: ids, ct.default_string_field_name: str(ids)})
+            res.append({ct.default_int64_field_name: ids,
+                        ct.default_string_field_name: str(ids)})

         collection_w.query("", limit=limit, output_fields=[ct.default_string_field_name],
                            check_task=CheckTasks.check_query_results, check_items={exp_res: res})
@@ -1116,7 +1214,8 @@ class TestQueryParams(TestcaseBase):
         # 4. query with pagination
         exp_ids, res = sorted(unordered_ids)[:limit + offset][offset:], []
         for ids in exp_ids:
-            res.append({ct.default_int64_field_name: ids, ct.default_string_field_name: str(ids)})
+            res.append({ct.default_int64_field_name: ids,
+                        ct.default_string_field_name: str(ids)})

         collection_w.query("", limit=limit, offset=offset, output_fields=[ct.default_string_field_name],
                            check_task=CheckTasks.check_query_results, check_items={exp_res: res})

@@ -1132,16 +1231,22 @@ class TestQueryParams(TestcaseBase):
         collection_w = self.init_collection_general(prefix, True)[0]

         # 2. query with limit > 16384
-        error = {ct.err_code: 1, ct.err_msg: "invalid max query result window, (offset+limit) should be in range [1, 16384]"}
-        collection_w.query("", limit=16385, check_task=CheckTasks.err_res, check_items=error)
+        error = {ct.err_code: 1,
+                 ct.err_msg: "invalid max query result window, (offset+limit) should be in range [1, 16384]"}
+        collection_w.query(
+            "", limit=16385, check_task=CheckTasks.err_res, check_items=error)

         # 3. query with offset + limit > 16384
-        collection_w.query("", limit=1, offset=16384, check_task=CheckTasks.err_res, check_items=error)
-        collection_w.query("", limit=16384, offset=1, check_task=CheckTasks.err_res, check_items=error)
+        collection_w.query("", limit=1, offset=16384,
+                           check_task=CheckTasks.err_res, check_items=error)
+        collection_w.query("", limit=16384, offset=1,
+                           check_task=CheckTasks.err_res, check_items=error)

         # 4. query with limit < 0
-        error = {ct.err_code: 1, ct.err_msg: "invalid max query result window, offset [-1] is invalid, should be gte than 0"}
-        collection_w.query("", limit=2, offset=-1, check_task=CheckTasks.err_res, check_items=error)
+        error = {ct.err_code: 1,
+                 ct.err_msg: "invalid max query result window, offset [-1] is invalid, should be gte than 0"}
+        collection_w.query("", limit=2, offset=-1,
+                           check_task=CheckTasks.err_res, check_items=error)

     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("expression", cf.gen_integer_overflow_expressions())
@@ -1152,13 +1257,16 @@ class TestQueryParams(TestcaseBase):
         expected:
         """
        # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, is_all_data_type=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, is_all_data_type=True)[0]
        start = ct.default_nb // 2
        _vectors = cf.gen_dataframe_all_data_type(start=start)
        # increase the value to cover the int range
-       _vectors["int16"] = pd.Series(data=[np.int16(i*40) for i in range(start, start + ct.default_nb)], dtype="int16")
-       _vectors["int32"] = pd.Series(data=[np.int32(i*2200000) for i in range(start, start + ct.default_nb)], dtype="int32")
+       _vectors["int16"] = pd.Series(data=[np.int16(
+           i*40) for i in range(start, start + ct.default_nb)], dtype="int16")
+       _vectors["int32"] = pd.Series(data=[np.int32(
+           i*2200000) for i in range(start, start + ct.default_nb)], dtype="int32")
        insert_ids = collection_w.insert(_vectors)[0].primary_keys

        # filter result with expression in collection
@@ -1186,7 +1294,8 @@ class TestQueryParams(TestcaseBase):
         collection_w = self.init_collection_general(prefix, insert_data=True,
                                                     enable_dynamic_field=enable_dynamic_field)[0]
         for fields in [None, []]:
-            res, _ = collection_w.query(default_term_expr, output_fields=fields)
+            res, _ = collection_w.query(
+                default_term_expr, output_fields=fields)
             assert res[0].keys() == {ct.default_int64_field_name}

     @pytest.mark.tags(CaseLabel.L0)
@@ -1197,10 +1306,11 @@ class TestQueryParams(TestcaseBase):
         expected: return one field
         """
         collection_w, vectors = self.init_collection_general(prefix, insert_data=True,
-                                                             enable_dynamic_field=
-                                                             enable_dynamic_field)[0:2]
+                                                             enable_dynamic_field=enable_dynamic_field)[0:2]
-        res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name])
-        assert set(res[0].keys()) == {ct.default_int64_field_name, ct.default_float_field_name}
+        res, _ = collection_w.query(default_term_expr, output_fields=[
+                                    ct.default_float_field_name])
+        assert set(res[0].keys()) == {
+            ct.default_int64_field_name, ct.default_float_field_name}

     @pytest.mark.tags(CaseLabel.L1)
     def test_query_output_all_fields(self, enable_dynamic_field, random_primary_key):
@@ -1223,7 +1333,8 @@ class TestQueryParams(TestcaseBase):
         else:
             res = []
             for id in range(2):
-                num = df[0][df[0][ct.default_int64_field_name] == id].index.to_list()[0]
+                num = df[0][df[0][ct.default_int64_field_name] == id].index.to_list()[
+                    0]
                 res.append(df[0].iloc[num].to_dict())
         log.info(res)
         collection_w.load()

@@ -1239,13 +1350,17 @@ class TestQueryParams(TestcaseBase):
         method: specify vec field as output field
         expected: return primary field and vec field
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         df = cf.gen_default_dataframe_data()
         collection_w.insert(df)
         assert collection_w.num_entities == ct.default_nb
-        fields = [[ct.default_float_vec_field_name], [ct.default_int64_field_name, ct.default_float_vec_field_name]]
-        res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        fields = [[ct.default_float_vec_field_name], [
+            ct.default_int64_field_name, ct.default_float_vec_field_name]]
+        res = df.loc[:1, [ct.default_int64_field_name,
+                          ct.default_float_vec_field_name]].to_dict('records')
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         for output_fields in fields:
             collection_w.query(default_term_expr, output_fields=output_fields,
@@ -1261,16 +1376,20 @@ class TestQueryParams(TestcaseBase):
         method: query with one output_field (wildcard)
         expected: query success
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         df = cf.gen_default_dataframe_data()
         collection_w.insert(df)
         assert collection_w.num_entities == ct.default_nb
-        output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields)
+        output_fields = cf.get_wildcard_output_field_names(
+            collection_w, wildcard_output_fields)
         output_fields.append(default_int_field_name)
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         with_vec = True if ct.default_float_vec_field_name in output_fields else False
-        actual_res = collection_w.query(default_term_expr, output_fields=wildcard_output_fields)[0]
+        actual_res = collection_w.query(
+            default_term_expr, output_fields=wildcard_output_fields)[0]
         assert set(actual_res[0].keys()) == set(output_fields)

     @pytest.mark.tags(CaseLabel.L1)
@@ -1285,17 +1404,20 @@ class TestQueryParams(TestcaseBase):
         """
         # init collection with two float vector fields
         schema = cf.gen_schema_multi_vector_fields(vec_fields)
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), schema=schema)
         df = cf.gen_dataframe_multi_vec_fields(vec_fields=vec_fields)
         collection_w.insert(df)
         assert collection_w.num_entities == ct.default_nb

         # query with two vec output_fields
-        output_fields = [ct.default_int64_field_name, ct.default_float_vec_field_name]
+        output_fields = [ct.default_int64_field_name,
+                         ct.default_float_vec_field_name]
         for vec_field in vec_fields:
             output_fields.append(vec_field.name)
         res = df.loc[:1, output_fields].to_dict('records')
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         collection_w.query(default_term_expr, output_fields=output_fields,
                            check_task=CheckTasks.check_query_results,
@@ -1314,17 +1436,20 @@ class TestQueryParams(TestcaseBase):
         """
         # init collection with two float vector fields
         schema = cf.gen_schema_multi_vector_fields(vec_fields)
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), schema=schema)
         df = cf.gen_dataframe_multi_vec_fields(vec_fields=vec_fields)
         collection_w.insert(df)
         assert collection_w.num_entities == ct.default_nb

         # query with two vec output_fields
-        output_fields = [ct.default_int64_field_name, ct.default_float_vec_field_name]
+        output_fields = [ct.default_int64_field_name,
+                         ct.default_float_vec_field_name]
         for vec_field in vec_fields:
             output_fields.append(vec_field.name)
         res = df.loc[:1, output_fields].to_dict('records')
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         collection_w.query(default_term_expr, output_fields=output_fields,
                            check_task=CheckTasks.check_query_results,
@@ -1342,10 +1467,13 @@ class TestQueryParams(TestcaseBase):
         method: specify binary vec field as output field
         expected: return primary field and binary vec field
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True)[0:2]
-        fields = [[ct.default_binary_vec_field_name], [ct.default_int64_field_name, ct.default_binary_vec_field_name]]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True, is_binary=True)[0:2]
+        fields = [[ct.default_binary_vec_field_name], [
+            ct.default_int64_field_name, ct.default_binary_vec_field_name]]
         for output_fields in fields:
-            res, _ = collection_w.query(default_term_expr, output_fields=output_fields)
+            res, _ = collection_w.query(
+                default_term_expr, output_fields=output_fields)
             assert res[0].keys() == set(fields[-1])

     @pytest.mark.tags(CaseLabel.L1)
@@ -1355,8 +1483,10 @@ class TestQueryParams(TestcaseBase):
         method: specify int64 primary field as output field
         expected: return int64 field
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
-        res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name])
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
+        res, _ = collection_w.query(default_term_expr, output_fields=[
+                                    ct.default_int64_field_name])
         assert res[0].keys() == {ct.default_int64_field_name}

     @pytest.mark.tags(CaseLabel.L2)
@@ -1366,7 +1496,8 @@ class TestQueryParams(TestcaseBase):
         method: query with not existed output field
         expected: raise exception
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         error = {ct.err_code: 1, ct.err_msg: 'Field int not exist'}
         output_fields = [["int"], [ct.default_int64_field_name, "int"]]
         for fields in output_fields:
@@ -1381,9 +1512,11 @@ class TestQueryParams(TestcaseBase):
         method: query with invalid field fields
         expected: raise exception
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         output_fields = ["12-s", 1, [1, "2", 3], (1,), {1: 1}]
-        error = {ct.err_code: 0, ct.err_msg: f'Invalid query format. \'output_fields\' must be a list'}
+        error = {ct.err_code: 0,
+                 ct.err_msg: f'Invalid query format. \'output_fields\' must be a list'}
         for fields in output_fields:
             collection_w.query(default_term_expr, output_fields=fields,
                                check_task=CheckTasks.err_res, check_items=error)

@@ -1398,7 +1531,8 @@ class TestQueryParams(TestcaseBase):
         """
         # init collection with fields: int64, float, float_vec, float_vector1
         # collection_w, df = self.init_multi_fields_collection_wrap(cf.gen_unique_str(prefix))
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         df = vectors[0]

         # query with wildcard all fields
@@ -1416,12 +1550,14 @@ class TestQueryParams(TestcaseBase):
         expected: verify query result
         """
         # init collection with fields: int64, float, float_vec
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True, is_index=False)[0:2]
         df = vectors[0]

         # query with output_fields=["*", float_vector)
         res = df.iloc[:2].to_dict('records')
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         collection_w.query(default_term_expr, output_fields=["*", ct.default_float_vec_field_name],
                            check_task=CheckTasks.check_query_results,
@@ -1436,11 +1572,13 @@ class TestQueryParams(TestcaseBase):
         expected: raise exception
         """
         # init collection with fields: int64, float, float_vec
-        collection_w = self.init_collection_general(prefix, insert_data=True, nb=100)[0]
+        collection_w = self.init_collection_general(
+            prefix, insert_data=True, nb=100)[0]
         collection_w.load()

         # query with invalid output_fields
-        error = {ct.err_code: 1, ct.err_msg: f"Field {output_fields[-1]} not exist"}
+        error = {ct.err_code: 1,
+                 ct.err_msg: f"Field {output_fields[-1]} not exist"}
         collection_w.query(default_term_expr, output_fields=output_fields,
                            check_task=CheckTasks.err_res, check_items=error)

@@ -1451,12 +1589,14 @@ class TestQueryParams(TestcaseBase):
         method: create a partition and query
         expected: verify query result
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         partition_w = self.init_partition_wrap(collection_wrap=collection_w)
         df = cf.gen_default_dataframe_data()
         partition_w.insert(df)
         assert collection_w.num_entities == ct.default_nb
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         partition_w.load()
         res = df.iloc[:2, :1].to_dict('records')
         collection_w.query(default_term_expr, partition_names=[partition_w.name],
@@ -1469,12 +1609,14 @@ class TestQueryParams(TestcaseBase):
         method: query on partition and no loading
         expected: raise exception
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         partition_w = self.init_partition_wrap(collection_wrap=collection_w)
         df = cf.gen_default_dataframe_data()
         partition_w.insert(df)
         assert partition_w.num_entities == ct.default_nb
-        error = {ct.err_code: 1, ct.err_msg: f'collection {collection_w.name} was not loaded into memory'}
+        error = {ct.err_code: 1,
+                 ct.err_msg: f'collection {collection_w.name} was not loaded into memory'}
         collection_w.query(default_term_expr, partition_names=[partition_w.name],
                            check_task=CheckTasks.err_res, check_items=error)

@@ -1485,7 +1627,8 @@ class TestQueryParams(TestcaseBase):
         method: query on default partition
         expected: verify query result
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         res = vectors[0].iloc[:2, :1].to_dict('records')
         collection_w.query(default_term_expr, partition_names=[ct.default_partition_name],
                            check_task=CheckTasks.check_query_results, check_items={exp_res: res})

@@ -1499,7 +1642,8 @@ class TestQueryParams(TestcaseBase):
         """
         # insert [0, half) into partition_w, [half, nb) into _default
         half = ct.default_nb // 2
-        collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(
+            half)

         # query from empty partition_names
         term_expr = f'{ct.default_int64_field_name} in [0, {half}, {ct.default_nb}-1]'

@@ -1514,12 +1658,15 @@ class TestQueryParams(TestcaseBase):
         method: query on an empty collection
         expected: empty query result
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         partition_w = self.init_partition_wrap(collection_wrap=collection_w)
         assert partition_w.is_empty
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         partition_w.load()
-        res, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
+        res, _ = collection_w.query(
+            default_term_expr, partition_names=[partition_w.name])
         assert len(res) == 0

     @pytest.mark.tags(CaseLabel.L2)
@@ -1530,10 +1677,12 @@ class TestQueryParams(TestcaseBase):
         expected: raise exception
         """
         collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix))
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         partition_names = cf.gen_unique_str()
-        error = {ct.err_code: 1, ct.err_msg: f'PartitionName: {partition_names} not found'}
+        error = {ct.err_code: 1,
+                 ct.err_msg: f'PartitionName: {partition_names} not found'}
         collection_w.query(default_term_expr, partition_names=[partition_names],
                            check_task=CheckTasks.err_res, check_items=error)

@@ -1618,7 +1767,8 @@ class TestQueryParams(TestcaseBase):
         expected: query successfully and verify query result
         """
         # create collection, insert default_nb, load collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[offset: pos + offset]}'

@@ -1661,7 +1811,8 @@ class TestQueryParams(TestcaseBase):
         """
         # 1. initialize with data
         nb = 1000
-        collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb)[0:4]
+        collection_w, _vectors, _, insert_ids = self.init_collection_general(
+            prefix, True, nb)[0:4]

         # filter result with expression in collection
         _vectors = _vectors[0]

@@ -1687,12 +1838,14 @@ class TestQueryParams(TestcaseBase):
         method: create a partition and query with different offset
         expected: verify query result
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         partition_w = self.init_partition_wrap(collection_wrap=collection_w)
         df = cf.gen_default_dataframe_data()
         partition_w.insert(df)
         assert collection_w.num_entities == ct.default_nb
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         partition_w.load()
         res = df.iloc[:2, :1].to_dict('records')
         query_params = {"offset": offset, "limit": 10}

@@ -1706,11 +1859,13 @@ class TestQueryParams(TestcaseBase):
         method: create a partition and query with pagination
         expected: verify query result
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         df = cf.gen_default_dataframe_data()
         collection_w.insert(df)
         assert collection_w.num_entities == ct.default_nb
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         res = df.iloc[:2, :1].to_dict('records')
         query_params = {"offset": offset, "limit": 10}

@@ -1725,7 +1880,8 @@ class TestQueryParams(TestcaseBase):
                 compare the result with query without pagination params
         expected: query successfully
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[offset: pos + offset]}'

@@ -1748,7 +1904,8 @@ class TestQueryParams(TestcaseBase):
         expected: return an empty list
         """
         # create collection, insert default_nb, load collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}'

@@ -1764,7 +1921,8 @@ class TestQueryParams(TestcaseBase):
         expected: raise exception
         """
         # create collection, insert default_nb, load collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}'

@@ -1782,7 +1940,8 @@ class TestQueryParams(TestcaseBase):
         expected: raise exception
         """
         # create collection, insert default_nb, load collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}'

@@ -1801,7 +1960,8 @@ class TestQueryParams(TestcaseBase):
         expected: raise exception
         """
         # create collection, insert default_nb, load collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}'

@@ -1819,7 +1979,8 @@ class TestQueryParams(TestcaseBase):
         expected: raise exception
         """
         # create collection, insert default_nb, load collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         pos = 10
         term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}'

@@ -1841,7 +2002,8 @@ class TestQueryParams(TestcaseBase):
         upsert_nb = 1000
         expr = f"int64 >= 0 && int64 <= {upsert_nb}"
         collection_w = self.init_collection_general(prefix, True)[0]
-        res1 = collection_w.query(expr, output_fields=[default_float_field_name])[0]
+        res1 = collection_w.query(
+            expr, output_fields=[default_float_field_name])[0]

         def do_upsert():
             data = cf.gen_default_data_for_upsert(upsert_nb)[0]
@@ -1849,7 +2011,8 @@ class TestQueryParams(TestcaseBase):

         t = threading.Thread(target=do_upsert, args=())
         t.start()
-        res2 = collection_w.query(expr, output_fields=[default_float_field_name])[0]
+        res2 = collection_w.query(
+            expr, output_fields=[default_float_field_name])[0]
         t.join()
         assert [res1[i][default_float_field_name] for i in range(upsert_nb)] == \
                [res2[i][default_float_field_name] for i in range(upsert_nb)]

@@ -1871,13 +2034,16 @@ class TestQueryOperation(TestcaseBase):
         """

         # init a collection with default connection
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))

         # remove default connection
-        self.connection_wrap.remove_connection(alias=DefaultConfig.DEFAULT_USING)
+        self.connection_wrap.remove_connection(
+            alias=DefaultConfig.DEFAULT_USING)

         # list connection to check
-        self.connection_wrap.list_connections(check_task=ct.CheckTasks.ccr, check_items={ct.list_content: []})
+        self.connection_wrap.list_connections(
+            check_task=ct.CheckTasks.ccr, check_items={ct.list_content: []})

         # query after remove default connection
         collection_w.query(default_term_expr, check_task=CheckTasks.err_res,

@@ -1915,11 +2081,13 @@ class TestQueryOperation(TestcaseBase):
         """

         # init a collection and insert data
-        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:3]

         # query the first row of data
         check_vec = vectors[0].iloc[:, [0]][0:1].to_dict('records')
-        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={
+                           exp_res: check_vec})

     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
@@ -1936,7 +2104,8 @@ class TestQueryOperation(TestcaseBase):

         # query the first row of data
         check_vec = vectors[0].iloc[:, [0]][0:1].to_dict('records')
-        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={
+                           exp_res: check_vec})

     @pytest.mark.tags(CaseLabel.L2)
     def test_query_expr_all_term_array(self):
@@ -1947,15 +2116,18 @@ class TestQueryOperation(TestcaseBase):
         """

         # init a collection and insert data
-        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:3]

         # data preparation
         int_values = vectors[0][ct.default_int64_field_name].values.tolist()
         term_expr = f'{ct.default_int64_field_name} in {int_values}'
-        check_vec = vectors[0].iloc[:, [0]][0:len(int_values)].to_dict('records')
+        check_vec = vectors[0].iloc[:, [0]][0:len(
+            int_values)].to_dict('records')

         # query all array value
-        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={
+                           exp_res: check_vec})

     @pytest.mark.tags(CaseLabel.L1)
     def test_query_expr_half_term_array(self):
@@ -1966,7 +2138,8 @@ class TestQueryOperation(TestcaseBase):
         """

         half = ct.default_nb // 2
-        collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(
+            half)

         int_values = df_default[ct.default_int64_field_name].values.tolist()
         term_expr = f'{ct.default_int64_field_name} in {int_values}'

@@ -1980,7 +2153,8 @@ class TestQueryOperation(TestcaseBase):
         method: query with repeated array value
         expected: return hit entities, no repeated
         """
-        collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3]
+        collection_w, vectors, binary_raw_vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:3]
         int_values = [0, 0, 0, 0]
         term_expr = f'{ct.default_int64_field_name} in {int_values}'
         res, _ = collection_w.query(term_expr)

@@ -1995,12 +2169,15 @@ class TestQueryOperation(TestcaseBase):
         2.query with dup term array
         expected: todo
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         df = cf.gen_default_dataframe_data(nb=100)
         df[ct.default_int64_field_name] = 0
         mutation_res, _ = collection_w.insert(df)
-        assert mutation_res.primary_keys == df[ct.default_int64_field_name].tolist()
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        assert mutation_res.primary_keys == df[ct.default_int64_field_name].tolist(
+        )
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         term_expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}'
         res = df.iloc[:, :2].to_dict('records')

@@ -2026,8 +2203,10 @@ class TestQueryOperation(TestcaseBase):

         int_values = [0]
         term_expr = f'{ct.default_int64_field_name} in {int_values}'
-        check_vec = vectors[0].iloc[:, [0]][0:len(int_values)].to_dict('records')
-        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+        check_vec = vectors[0].iloc[:, [0]][0:len(
+            int_values)].to_dict('records')
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={
+                           exp_res: check_vec})

     @pytest.mark.tags(CaseLabel.L1)
     def test_query_after_search(self):
@@ -2044,7 +2223,8 @@ class TestQueryOperation(TestcaseBase):
             self.init_collection_general(prefix, True, nb_old)[0:4]

         # 2. search for original data after load
-        vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)]
+        vectors_s = [[random.random() for _ in range(ct.default_dim)]
+                     for _ in range(ct.default_nq)]
         collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name,
                             ct.default_search_params, limit, "int64 >= 0",
                             check_task=CheckTasks.check_search_results,

@@ -2055,7 +2235,8 @@ class TestQueryOperation(TestcaseBase):

         term_expr = f'{ct.default_int64_field_name} in [0, 1]'
         check_vec = vectors[0].iloc[:, [0]][0:2].to_dict('records')
-        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={
+                           exp_res: check_vec})

     @pytest.mark.tags(CaseLabel.L1)
     def test_query_output_vec_field_after_index(self):
@@ -2064,14 +2245,17 @@ class TestQueryOperation(TestcaseBase):
         method: create index and specify vec field as output field
         expected: return primary field and vec field
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         df = cf.gen_default_dataframe_data(nb=5000)
         collection_w.insert(df)
         assert collection_w.num_entities == 5000
         fields = [ct.default_int64_field_name, ct.default_float_vec_field_name]
-        collection_w.create_index(ct.default_float_vec_field_name, default_index_params)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, default_index_params)
         assert collection_w.has_index()[0]
-        res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records')
+        res = df.loc[:1, [ct.default_int64_field_name,
+                          ct.default_float_vec_field_name]].to_dict('records')
         collection_w.load()
         error = {ct.err_code: 1, ct.err_msg: 'not allowed'}
         collection_w.query(default_term_expr, output_fields=fields,

@@ -2086,12 +2270,15 @@ class TestQueryOperation(TestcaseBase):
         expected: return primary field and vec field
         """
         collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=False)[
-            0:2]
-        fields = [ct.default_int64_field_name, ct.default_binary_vec_field_name]
-        collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
+                0:2]
+        fields = [ct.default_int64_field_name,
+                  ct.default_binary_vec_field_name]
+        collection_w.create_index(
+            ct.default_binary_vec_field_name, binary_index_params)
         assert collection_w.has_index()[0]
         collection_w.load()
-        res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name])
+        res, _ = collection_w.query(default_term_expr, output_fields=[
+                                    ct.default_binary_vec_field_name])
         assert res[0].keys() == set(fields)

     @pytest.mark.tags(CaseLabel.L2)
@@ -2106,7 +2293,8 @@ class TestQueryOperation(TestcaseBase):
         self._connect()

         # init collection
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))

         # init partition
         partition_w = self.init_partition_wrap(collection_wrap=collection_w)

@@ -2119,12 +2307,15 @@ class TestQueryOperation(TestcaseBase):
         assert collection_w.num_entities == ct.default_nb

         # load partition
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         partition_w.load()

         # query twice
-        res_one, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
-        res_two, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name])
+        res_one, _ = collection_w.query(
+            default_term_expr, partition_names=[partition_w.name])
+        res_two, _ = collection_w.query(
+            default_term_expr, partition_names=[partition_w.name])
         assert res_one == res_two

     @pytest.mark.tags(CaseLabel.L2)
@@ -2136,7 +2327,8 @@ class TestQueryOperation(TestcaseBase):
         expected: query result is empty
         """
         half = ct.default_nb // 2
-        collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(
+            half)

         term_expr = f'{ct.default_int64_field_name} in [{half}]'
         # half entity in _default partition rather than partition_w

@@ -2152,11 +2344,13 @@ class TestQueryOperation(TestcaseBase):
         expected: query results from two partitions
         """
         half = ct.default_nb // 2
-        collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(
+            half)

         term_expr = f'{ct.default_int64_field_name} in [{half - 1}, {half}]'
         # half entity in _default, half-1 entity in partition_w
-        res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
+        res, _ = collection_w.query(term_expr, partition_names=[
+                                    ct.default_partition_name, partition_w.name])
         assert len(res) == 2

     @pytest.mark.tags(CaseLabel.L2)
@@ -2168,11 +2362,13 @@ class TestQueryOperation(TestcaseBase):
         expected: query from two partitions and get single result
         """
         half = ct.default_nb // 2
-        collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(half)
+        collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half(
+            half)

         term_expr = f'{ct.default_int64_field_name} in [{half}]'
         # half entity in _default
-        res, _ = collection_w.query(term_expr, partition_names=[ct.default_partition_name, partition_w.name])
+        res, _ = collection_w.query(term_expr, partition_names=[
+                                    ct.default_partition_name, partition_w.name])
         assert len(res) == 1
         assert res[0][ct.default_int64_field_name] == half

@@ -2186,9 +2382,11 @@ class TestQueryOperation(TestcaseBase):
         4.query
         expected: Data can be queried
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         # load collection
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         tmp_nb = 100
         df = cf.gen_default_dataframe_data(tmp_nb)

@@ -2258,7 +2456,8 @@ class TestQueryString(TestcaseBase):
         expected: query successfully
         """

-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True)[0:2]
         res = vectors[0].iloc[:2, :3].to_dict('records')
         output_fields = [default_float_field_name, default_string_field_name]
         collection_w.query(default_string_term_expr, output_fields=output_fields,

@@ -2274,7 +2473,8 @@ class TestQueryString(TestcaseBase):
         """
         collection_w, vectors = self.init_collection_general(prefix, insert_data=True,
                                                              primary_field=ct.default_string_field_name)[0:2]
-        res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name])
+        res, _ = collection_w.query(expression, output_fields=[
+                                    ct.default_string_field_name])
         assert res[0].keys() == {ct.default_string_field_name}

     @pytest.mark.tags(CaseLabel.L1)
@@ -2301,7 +2501,8 @@ class TestQueryString(TestcaseBase):
               query with invalid expr
         expected: Raise exception
         """
-        collection_w = self.init_collection_general(prefix, insert_data=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, insert_data=True)[0]
         collection_w.query(expression, check_task=CheckTasks.err_res,
                            check_items={ct.err_code: 1, ct.err_msg: "type mismatch"})

@@ -2313,11 +2514,13 @@ class TestQueryString(TestcaseBase):
         expected: verify query successfully
         """
         collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=False)[
-            0:2]
-        collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params)
+                0:2]
+        collection_w.create_index(
+            ct.default_binary_vec_field_name, binary_index_params)
         collection_w.load()
         assert collection_w.has_index()[0]
-        res, _ = collection_w.query(default_string_term_expr, output_fields=[ct.default_binary_vec_field_name])
+        res, _ = collection_w.query(default_string_term_expr, output_fields=[
+                                    ct.default_binary_vec_field_name])
         assert len(res) == 2

     @pytest.mark.tags(CaseLabel.L1)
@@ -2331,7 +2534,8 @@ class TestQueryString(TestcaseBase):
                                                              primary_field=ct.default_string_field_name)[0:2]
         res = vectors[0].iloc[:1, :3].to_dict('records')
         expression = 'varchar like "0%"'
-        output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
+        output_fields = [default_int_field_name,
+                         default_float_field_name, default_string_field_name]
         collection_w.query(expression, output_fields=output_fields,
                            check_task=CheckTasks.check_query_results, check_items={exp_res: res})

@@ -2342,10 +2546,12 @@ class TestQueryString(TestcaseBase):
         method: specify string primary field, use invalid prefix string expr
         expected: raise error
         """
-        collection_w = self.init_collection_general(prefix, insert_data=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, insert_data=True)[0]
         expression = 'float like "0%"'
         collection_w.query(expression, check_task=CheckTasks.err_res,
-                           check_items={ct.err_code: 1, ct.err_msg: "like operation on non-string field is unsupported"}
+                           check_items={
+                               ct.err_code: 1, ct.err_msg: "like operation on non-string field is unsupported"}
                            )

     @pytest.mark.tags(CaseLabel.L1)
@@ -2356,10 +2562,12 @@ class TestQueryString(TestcaseBase):
         expected: verify query successfully
         """
         collection_w = \
-            self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0]
+            self.init_collection_general(
+                prefix, insert_data=True, primary_field=ct.default_string_field_name)[0]
         res = []
         expression = 'float > int64'
-        output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
+        output_fields = [default_int_field_name,
+                         default_float_field_name, default_string_field_name]
         collection_w.query(expression, output_fields=output_fields,
                            check_task=CheckTasks.check_query_results, check_items={exp_res: res})

@@ -2371,7 +2579,8 @@ class TestQueryString(TestcaseBase):
         expected: raise error
         """
         collection_w = \
-            self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0]
+            self.init_collection_general(
+                prefix, insert_data=True, primary_field=ct.default_string_field_name)[0]
         expression = 'varchar == int64'
         collection_w.query(expression, check_task=CheckTasks.err_res,
                            check_items={ct.err_code: 1, ct.err_msg: f' cannot parse expression:{expression}'})

@@ -2384,7 +2593,8 @@ class TestQueryString(TestcaseBase):
         method: multi threads insert, and query, compare queried data with original
         expected: verify data consistency
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         thread_num = 4
         threads = []
         primary_keys = []
@@ -2392,7 +2602,8 @@ class TestQueryString(TestcaseBase):

         # prepare original data for parallel insert
         for i in range(thread_num):
-            df = cf.gen_default_dataframe_data(ct.default_nb, start=i * ct.default_nb)
+            df = cf.gen_default_dataframe_data(
+                ct.default_nb, start=i * ct.default_nb)
            df_list.append(df)
            primary_key = df[ct.default_int64_field_name].values.tolist()
            primary_keys.append(primary_key)

@@ -2412,7 +2623,8 @@ class TestQueryString(TestcaseBase):
         assert collection_w.num_entities == ct.default_nb * thread_num

         # Check data consistency after parallel insert
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()
         df_dict_list = []
         for df in df_list:

@@ -2436,8 +2648,10 @@ class TestQueryString(TestcaseBase):
         """
         # 1. create a collection
         schema = cf.gen_string_pk_default_collection_schema()
-        collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix), schema=schema)
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w = self.init_collection_wrap(
+            cf.gen_unique_str(prefix), schema=schema)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()

         nb = 3000
@@ -2448,7 +2662,8 @@ class TestQueryString(TestcaseBase):
         assert collection_w.num_entities == nb

         string_exp = "varchar >= \"\""
-        output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
+        output_fields = [default_int_field_name,
+                         default_float_field_name, default_string_field_name]
         res, _ = collection_w.query(string_exp, output_fields=output_fields)

         assert len(res) == 1

@@ -2463,7 +2678,8 @@ class TestQueryString(TestcaseBase):
         expected: query successfully
         """
         # 1. create a collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=False, is_index=False)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=False, is_index=False)[0:2]

         nb = 3000
         df = cf.gen_default_list_data(nb)
@@ -2472,11 +2688,13 @@ class TestQueryString(TestcaseBase):
         collection_w.insert(df)
         assert collection_w.num_entities == nb

-        collection_w.create_index(ct.default_float_vec_field_name, default_index_params)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, default_index_params)
         assert collection_w.has_index()[0]
         collection_w.load()

-        output_fields = [default_int_field_name, default_float_field_name, default_string_field_name]
+        output_fields = [default_int_field_name,
+                         default_float_field_name, default_string_field_name]

         expr = "varchar == \"\""
         res, _ = collection_w.query(expr, output_fields=output_fields)

@@ -2490,17 +2708,21 @@ class TestQueryString(TestcaseBase):
         method: create a collection and build diskann index
         expected: verify query result
         """
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=True, is_index=False)[0:2]

-        collection_w.create_index(ct.default_float_vec_field_name, ct.default_diskann_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, ct.default_diskann_index)
         assert collection_w.has_index()[0]

         collection_w.load()
         int_values = [0]
         term_expr = f'{ct.default_int64_field_name} in {int_values}'
-        check_vec = vectors[0].iloc[:, [0]][0:len(int_values)].to_dict('records')
-        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
+        check_vec = vectors[0].iloc[:, [0]][0:len(
+            int_values)].to_dict('records')
+        collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={
+                           exp_res: check_vec})

     @pytest.mark.tags(CaseLabel.L2)
     def test_query_with_create_diskann_with_string_pk(self):
@@ -2512,7 +2734,8 @@ class TestQueryString(TestcaseBase):
         collection_w, vectors = self.init_collection_general(prefix, insert_data=True,
                                                              primary_field=ct.default_string_field_name,
                                                              is_index=False)[0:2]
-        collection_w.create_index(ct.default_float_vec_field_name, ct.default_diskann_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, ct.default_diskann_index)
         assert collection_w.has_index()[0]
         collection_w.load()
         res = vectors[0].iloc[:, 1:3].to_dict('records')

@@ -2530,7 +2753,8 @@ class TestQueryString(TestcaseBase):
         expected: query successfully
         """
         # 1. create a collection
-        collection_w, vectors = self.init_collection_general(prefix, insert_data=False, is_index=False)[0:2]
+        collection_w, vectors = self.init_collection_general(
+            prefix, insert_data=False, is_index=False)[0:2]

         nb = 3000
         df = cf.gen_default_list_data(nb)
@@ -2539,10 +2763,12 @@ class TestQueryString(TestcaseBase):
         collection_w.insert(df)
         assert collection_w.num_entities == nb

-        collection_w.create_index(ct.default_float_vec_field_name, default_index_params)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, default_index_params)
         assert collection_w.has_index()[0]
         index_params = {}
-        collection_w.create_index(ct.default_int64_field_name, index_params=index_params)
+        collection_w.create_index(
+            ct.default_int64_field_name, index_params=index_params)

         collection_w.load()

@@ -2570,9 +2796,11 @@ class TestQueryCount(TestcaseBase):
         4. verify count
         expected: expected count
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), consistency_level=consistency_level)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), consistency_level=consistency_level)
         # load collection
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()

         df = cf.gen_default_dataframe_data()

@@ -2597,9 +2825,11 @@ class TestQueryCount(TestcaseBase):
         method:
         expected:
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         # load collection
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()

         # insert

@@ -2618,7 +2848,8 @@ class TestQueryCount(TestcaseBase):
         method: count without loading
         expected: exception
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         collection_w.query(expr=default_term_expr, output_fields=[ct.default_count_output],
                            check_task=CheckTasks.err_res,
                            check_items={"err_code": 1,

@@ -2635,8 +2866,10 @@ class TestQueryCount(TestcaseBase):
         expected: verify count
         """
         # create
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()

         # insert duplicate ids

@@ -2670,7 +2903,8 @@ class TestQueryCount(TestcaseBase):
         """
         half = ct.default_nb // 2
         # insert [0, half) into partition_w, [half, nb) into _default
-        collection_w, p1, _, _ = self.insert_entities_into_two_partitions_in_half(half=half)
+        collection_w, p1, _, _ = self.insert_entities_into_two_partitions_in_half(
+            half=half)

         # query count p1, [p1, _default]
         for p_name in [p1.name, ct.default_partition_name]:

@@ -2687,7 +2921,8 @@ class TestQueryCount(TestcaseBase):
                            check_items={exp_res: [{count: 0}]}
                            )
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output],
-                           partition_names=[p1.name, ct.default_partition_name],
+                           partition_names=[
+                               p1.name, ct.default_partition_name],
                            check_task=CheckTasks.check_query_results,
                            check_items={exp_res: [{count: half}]}
                            )

@@ -2718,7 +2953,8 @@ class TestQueryCount(TestcaseBase):
         """
         # init partitions: _default and p1
         p1 = "p1"
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
         collection_w.create_partition(p1)

         df = cf.gen_default_dataframe_data()
@@ -2726,7 +2962,8 @@ class TestQueryCount(TestcaseBase):
         collection_w.insert(df, partition_name=p1)

         # index and load
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()

         # count

@@ -2740,7 +2977,8 @@ class TestQueryCount(TestcaseBase):
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output],
                            partition_names=[p1],
                            check_task=CheckTasks.check_query_results,
-                           check_items={exp_res: [{count: ct.default_nb - delete_res.delete_count}]}
+                           check_items={
+                               exp_res: [{count: ct.default_nb - delete_res.delete_count}]}
                            )

     @pytest.mark.tags(CaseLabel.L1)
@@ -2755,7 +2993,8 @@ class TestQueryCount(TestcaseBase):
         """
         tmp_nb = 100
         # create -> insert -> index -> load -> count sealed
-        collection_w = self.init_collection_general(insert_data=True, nb=tmp_nb)[0]
+        collection_w = self.init_collection_general(
+            insert_data=True, nb=tmp_nb)[0]
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output],
                            check_task=CheckTasks.check_query_results,
                            check_items={exp_res: [{count: tmp_nb}]}
                            )

@@ -2778,8 +3017,10 @@ class TestQueryCount(TestcaseBase):
         expected: verify count
         """
         # create -> index -> load
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix))
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix))
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
         collection_w.load()

         # flush while count
@@ -2792,7 +3033,7 @@ class TestQueryCount(TestcaseBase):
                                                          "output_fields": [ct.default_count_output],
                                                          "check_task": CheckTasks.check_query_results,
                                                          "check_items": {exp_res: [{count: ct.default_nb}]}
-                                                         })
+                                                        })

         t_flush.start()
         t_count.start()

@@ -2816,7 +3057,8 @@ class TestQueryCount(TestcaseBase):
         insert_res, _ = collection_w.insert(df)

         # delete growing and sealed ids -> count
-        collection_w.delete(f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}")
+        collection_w.delete(
+            f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}")
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output],
                            check_task=CheckTasks.check_query_results,
                            check_items={exp_res: [{count: tmp_nb}]}
                            )

@@ -2827,7 +3069,8 @@ class TestQueryCount(TestcaseBase):
         collection_w.insert(df_same)
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output],
                            check_task=CheckTasks.check_query_results,
-                           check_items={exp_res: [{count: ct.default_nb + tmp_nb}]}
+                           check_items={
+                               exp_res: [{count: ct.default_nb + tmp_nb}]}
                            )

     @pytest.mark.tags(CaseLabel.L1)
@@ -2839,7 +3082,8 @@ class TestQueryCount(TestcaseBase):
         3. count
         expected: verify count
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), shards_num=1)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), shards_num=1)

         # init two segments
         tmp_nb = 100

@@ -2849,12 +3093,14 @@ class TestQueryCount(TestcaseBase):
         collection_w.insert(df)
         collection_w.flush()

-        collection_w.create_index(ct.default_float_vec_field_name, ct.default_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, ct.default_index)
         collection_w.compact()
         collection_w.wait_for_compaction_completed()
         collection_w.load()

-        segment_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name)
+        segment_info, _ = self.utility_wrap.get_query_segment_info(
+            collection_w.name)
         assert len(segment_info) == 1

         # count after compact

@@ -2872,8 +3118,10 @@ class TestQueryCount(TestcaseBase):
         expected: verify count
         """
         # create -> index -> insert
-        collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix), shards_num=1)
-        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
+        collection_w = self.init_collection_wrap(
+            cf.gen_unique_str(prefix), shards_num=1)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, index_params=ct.default_flat_index)

         df = cf.gen_default_dataframe_data()
         insert_res, _ = collection_w.insert(df)

@@ -2901,7 +3149,8 @@ class TestQueryCount(TestcaseBase):
         2. compact while count
         expected: verify count
         """
-        collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), shards_num=1)
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), shards_num=1)

         # init 2 segments
         tmp_nb = 100

@@ -2911,7 +3160,8 @@ class TestQueryCount(TestcaseBase):
         collection_w.flush()

         # compact while count
-        collection_w.create_index(ct.default_float_vec_field_name, ct.default_index)
+        collection_w.create_index(
+            ct.default_float_vec_field_name, ct.default_index)
         collection_w.load()

         t_compact = threading.Thread(target=collection_w.compact, args=())
@@ -2920,7 +3170,7 @@ class TestQueryCount(TestcaseBase):
                                                          "output_fields": [ct.default_count_output],
                                                          "check_task": CheckTasks.check_query_results,
                                                          "check_items": {exp_res: [{count: tmp_nb * 10}]}
-                                                         })
+                                                        })

         t_compact.start()
         t_count.start()

@@ -2964,7 +3214,8 @@ class TestQueryCount(TestcaseBase):
         # count with limit
         collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], limit=10,
                            check_task=CheckTasks.err_res,
-                           check_items={ct.err_code: 1, ct.err_msg: "count entities with pagination is not allowed"}
+                           check_items={
+                               ct.err_code: 1, ct.err_msg: "count entities with pagination is not allowed"}
                            )
         # count with pagination params
         collection_w.query(default_expr, output_fields=[ct.default_count_output], offset=10, limit=10,

@@ -2992,7 +3243,8 @@ class TestQueryCount(TestcaseBase):
         # new insert partitions and count
         p_name = cf.gen_unique_str("p_alias")
         collection_w_alias.create_partition(p_name)
-        collection_w_alias.insert(cf.gen_default_dataframe_data(start=ct.default_nb), partition_name=p_name)
+        collection_w_alias.insert(cf.gen_default_dataframe_data(
+            start=ct.default_nb), partition_name=p_name)
         collection_w_alias.query(expr=default_expr, output_fields=[ct.default_count_output],
                                  check_task=CheckTasks.check_query_results,
                                  check_items={exp_res: [{count: ct.default_nb * 2}]})

@@ -3001,7 +3253,8 @@ class TestQueryCount(TestcaseBase):
         collection_w_alias.drop_partition(p_name, check_task=CheckTasks.err_res,
                                           check_items={ct.err_code: 1,
                                                        ct.err_msg: "cannot drop the
collection via alias"}) - self.partition_wrap.init_partition(collection_w_alias.collection, p_name) + self.partition_wrap.init_partition( + collection_w_alias.collection, p_name) self.partition_wrap.release() collection_w_alias.drop_partition(p_name) @@ -3012,7 +3265,8 @@ class TestQueryCount(TestcaseBase): check_items={exp_res: [{count: ct.default_nb}]}) # alias delete and count - collection_w_alias.delete(f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}") + collection_w_alias.delete( + f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}") collection_w_alias.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, check_items={exp_res: [{count: 0}]}) @@ -3036,7 +3290,8 @@ class TestQueryCount(TestcaseBase): if is_growing: # create -> index -> load -> insert -> delete collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.insert(cf.gen_default_dataframe_data()) @@ -3052,7 +3307,8 @@ class TestQueryCount(TestcaseBase): single_expr = f'{ct.default_int64_field_name} in [0]' collection_w.delete(single_expr) - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # upsert deleted id @@ -3088,12 +3344,14 @@ class TestQueryCount(TestcaseBase): """ # init collection and insert same ids tmp_nb = 100 - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(nb=tmp_nb) df[ct.default_int64_field_name] = 0 collection_w.insert(df) - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # upsert id and count @@ -3111,7 +3369,8 @@ class TestQueryCount(TestcaseBase): check_items={exp_res: [{count: tmp_nb - delete_res.delete_count}]}) # upsert deleted id and count - df_deleted = cf.gen_default_dataframe_data(nb=delete_res.delete_count, start=0) + df_deleted = cf.gen_default_dataframe_data( + nb=delete_res.delete_count, start=0) collection_w.upsert(df_deleted) collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, @@ -3144,8 +3403,10 @@ class TestQueryCount(TestcaseBase): expected: verify count 0 """ # create -> index -> load - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix)) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # insert @@ -3163,7 +3424,8 @@ class TestQueryCount(TestcaseBase): expected: verify count """ # create -> insert -> index -> load - collection_w, _vectors, _, insert_ids = self.init_collection_general(insert_data=True)[0:4] + collection_w, _vectors, _, insert_ids = self.init_collection_general( + insert_data=True)[0:4] # filter result with expression in collection _vectors = 
_vectors[0] @@ -3191,7 +3453,8 @@ class TestQueryCount(TestcaseBase): """ # create -> insert -> index -> load collection_w, _vectors, _, insert_ids = \ - self.init_collection_general(insert_data=True, is_all_data_type=True)[0:4] + self.init_collection_general( + insert_data=True, is_all_data_type=True)[0:4] # filter result with expression in collection filter_ids = [] @@ -3220,7 +3483,8 @@ class TestQueryCount(TestcaseBase): expected: verify count """ # create -> insert -> index -> load - collection_w, _vectors, _, insert_ids = self.init_collection_general(insert_data=True)[0:4] + collection_w, _vectors, _, insert_ids = self.init_collection_general( + insert_data=True)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -3246,7 +3510,8 @@ class TestQueryCount(TestcaseBase): expected: verify count """ # create -> insert -> index -> load - collection_w = self.init_collection_general(insert_data=True, is_all_data_type=True)[0] + collection_w = self.init_collection_general( + insert_data=True, is_all_data_type=True)[0] # count with expr expression = "int64 >= 0 && int32 >= 1999 && int16 >= 0 && int8 >= 0 && float <= 1999.0 && double >= 0" @@ -3265,12 +3530,14 @@ class TestQueryCount(TestcaseBase): # create -> insert -> index -> load fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), cf.gen_float_vec_field()] - schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") + schema = cf.gen_collection_schema( + fields=fields, primary_field="int64_1") collection_w = self.init_collection_wrap(schema=schema) nb, res = 10, 0 int_values = [random.randint(0, nb) for _ in range(nb)] - data = [[i for i in range(nb)], int_values, cf.gen_vectors(nb, ct.default_dim)] + data = [[i for i in range(nb)], int_values, + cf.gen_vectors(nb, ct.default_dim)] collection_w.insert(data) collection_w.create_index(ct.default_float_vec_field_name) collection_w.load() @@ -3302,8 +3569,10 @@ class TestQueryIterator(TestcaseBase): """ # 1. initialize with data batch_size = 100 - collection_w = self.init_collection_general(prefix, True, is_index=False)[0] - collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"}) + collection_w = self.init_collection_general( + prefix, True, is_index=False)[0] + collection_w.create_index( + ct.default_float_vec_field_name, {"metric_type": "L2"}) collection_w.load() # 2. search iterator expr = "int64 >= 0" @@ -3338,8 +3607,10 @@ class TestQueryIterator(TestcaseBase): """ # 1. initialize with data batch_size = 300 - collection_w = self.init_collection_general(prefix, True, is_index=False)[0] - collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"}) + collection_w = self.init_collection_general( + prefix, True, is_index=False)[0] + collection_w.create_index( + ct.default_float_vec_field_name, {"metric_type": "L2"}) collection_w.load() # 2. search iterator expr = "int64 >= 0" @@ -3359,8 +3630,10 @@ class TestQueryIterator(TestcaseBase): """ # 1. initialize with data offset = 500 - collection_w = self.init_collection_general(prefix, True, is_index=False)[0] - collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"}) + collection_w = self.init_collection_general( + prefix, True, is_index=False)[0] + collection_w.create_index( + ct.default_float_vec_field_name, {"metric_type": "L2"}) collection_w.load() # 2. search iterator expr = "int64 >= 0" @@ -3402,7 +3675,8 @@ class TestQueryIterator(TestcaseBase): # 2. 
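(Editorial aside: two patterns dominate the hunks above and below, a server-side row count via the "count(*)" output field and batched reads via query_iterator. A minimal standalone sketch of both against the plain pymilvus client rather than the test wrappers; the host, collection name, and field names are illustrative assumptions, not part of this patch.)

from pymilvus import connections, Collection

connections.connect(host="localhost", port="19530")  # assumed local deployment
collection = Collection("demo")  # assumed: already created, indexed, and loaded

# count(*): the server returns the count as a single-row query result
res = collection.query(expr="int64 >= 0", output_fields=["count(*)"])
print(res[0]["count(*)"])

# query_iterator: pull matching rows in fixed-size batches instead of one big query
it = collection.query_iterator(batch_size=100, expr="int64 >= 0",
                               output_fields=["int64"])
total = 0
while True:
    batch = it.next()
    if not batch:        # an empty batch signals exhaustion
        it.close()
        break
    total += len(batch)
print("rows iterated:", total)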
search iterator expr = "int64 >= 0" error = {"err_code": 1, "err_msg": "batch size cannot be less than zero"} - collection_w.query_iterator(batch_size=-1, expr=expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query_iterator( + batch_size=-1, expr=expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L0) @pytest.mark.parametrize("batch_size", [100, 500]) @@ -3414,7 +3688,8 @@ class TestQueryIterator(TestcaseBase): expected: return topK results by order """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general( + prefix, True, auto_id=auto_id)[0:4] # 2. query with limit collection_w.query_iterator(batch_size=batch_size, @@ -3433,7 +3708,8 @@ class TestQueryIterator(TestcaseBase): expected: return topK results by order """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, random_primary_key=True)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general( + prefix, True, random_primary_key=True)[0:4] # 3. query with empty expr and check the result exp_ids = sorted(insert_ids) diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 33a16d7c2c..398565563f 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -1,3 +1,15 @@ +from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY +from common.constants import * +from utils.util_pymilvus import * +from common.common_type import CaseLabel, CheckTasks +from common import common_type as ct +from common import common_func as cf +from utils.util_log import test_log as log +from base.client_base import TestcaseBase +import heapq +from time import sleep +from decimal import Decimal, getcontext +import decimal import multiprocessing import numbers import random @@ -6,19 +18,7 @@ import threading import pytest import pandas as pd pd.set_option("expand_frame_repr", False) -import decimal -from decimal import Decimal, getcontext -from time import sleep -import heapq -from base.client_base import TestcaseBase -from utils.util_log import test_log as log -from common import common_func as cf -from common import common_type as ct -from common.common_type import CaseLabel, CheckTasks -from utils.util_pymilvus import * -from common.constants import * -from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY prefix = "search_collection" search_num = 10 @@ -45,8 +45,10 @@ default_float_field_name = ct.default_float_field_name default_bool_field_name = ct.default_bool_field_name default_string_field_name = ct.default_string_field_name default_json_field_name = ct.default_json_field_name -default_index_params = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}} -vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] +default_index_params = {"index_type": "IVF_SQ8", + "metric_type": "COSINE", "params": {"nlist": 64}} +vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] range_search_supported_index = ct.all_index_types[:6] range_search_supported_index_params = ct.default_index_params[:6] uid = "test_search" @@ -58,7 +60,8 @@ search_param = {"nprobe": 1} entity = gen_entities(1, is_normal=True) entities = 
gen_entities(default_nb, is_normal=True) raw_vectors, binary_entities = gen_binary_entities(default_nb) -default_query, _ = gen_search_vectors_params(field_name, entities, default_top_k, nq) +default_query, _ = gen_search_vectors_params( + field_name, entities, default_top_k, nq) index_name1 = cf.gen_unique_str("float") index_name2 = cf.gen_unique_str("varhar") half_nb = ct.default_nb // 2 @@ -196,7 +199,7 @@ class TestCollectionSearchInvalid(TestcaseBase): default_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, check_items={"err_code": 1, - "err_msg": "collection %s doesn't exist!" % collection_w.name}) + "err_msg": "collection not found"}) @pytest.mark.tags(CaseLabel.L2) def test_search_param_missing(self): @@ -247,7 +250,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 2. search with invalid dim log.info("test_search_param_invalid_dim: searching with invalid dim") wrong_dim = 129 - vectors = [[random.random() for _ in range(wrong_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(wrong_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, @@ -304,9 +308,11 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. search with invalid metric_type - log.info("test_search_param_invalid_metric_type: searching with invalid metric_type") + log.info( + "test_search_param_invalid_metric_type: searching with invalid metric_type") invalid_metric = get_invalid_metric_type - search_params = {"metric_type": invalid_metric, "params": {"nprobe": 10}} + search_params = {"metric_type": invalid_metric, + "params": {"nprobe": 10}} collection_w.search(vectors[:default_nq], default_search_field, search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, @@ -329,7 +335,8 @@ class TestCollectionSearchInvalid(TestcaseBase): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, 5000, is_index=False)[0:4] # 2. create index and load - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search @@ -337,7 +344,8 @@ class TestCollectionSearchInvalid(TestcaseBase): message = "Search params check failed" for invalid_search_param in invalid_search_params: if index == invalid_search_param["index_type"]: - search_params = {"metric_type": "L2", "params": invalid_search_param["search_params"]} + search_params = {"metric_type": "L2", + "params": invalid_search_param["search_params"]} collection_w.search(vectors[:default_nq], default_search_field, search_params, default_limit, default_search_exp, @@ -355,13 +363,16 @@ class TestCollectionSearchInvalid(TestcaseBase): expected: raise exception and report the error """ # 1. initialize with data - collection_w = self.init_collection_general(prefix, True, 3000, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, 3000, is_index=False)[0] # 2. create annoy index and load - index_annoy = {"index_type": "ANNOY", "params": {"n_trees": 512}, "metric_type": "L2"} + index_annoy = {"index_type": "ANNOY", "params": { + "n_trees": 512}, "metric_type": "L2"} collection_w.create_index("float_vector", index_annoy) collection_w.load() # 3. 
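(Editorial aside, for context on the negative cases in these hunks: outside the CheckTasks harness, an invalid search parameter surfaces as a raised MilvusException carrying the same code/message the tests assert on. A minimal sketch, assuming a loaded 128-dim collection named "demo"; all names are illustrative.)

import random
from pymilvus import connections, Collection
from pymilvus.exceptions import MilvusException

connections.connect(host="localhost", port="19530")
collection = Collection("demo")
vec = [[random.random() for _ in range(128)]]
try:
    # an out-of-range nprobe is rejected by the server-side params check
    bad_params = {"metric_type": "L2", "params": {"nprobe": 0}}
    collection.search(vec, "float_vector", bad_params, limit=10)
except MilvusException as e:
    print(e.code, e.message)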
search - annoy_search_param = {"index_type": "ANNOY", "search_params": {"search_k": search_k}} + annoy_search_param = {"index_type": "ANNOY", + "search_params": {"search_k": search_k}} collection_w.search(vectors[:default_nq], default_search_field, annoy_search_param, default_limit, default_search_exp, @@ -444,7 +455,8 @@ class TestCollectionSearchInvalid(TestcaseBase): dim = 1 fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), cf.gen_float_vec_field(dim=dim)] - schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") + schema = cf.gen_collection_schema( + fields=fields, primary_field="int64_1") collection_w = self.init_collection_wrap(schema=schema) # 2. insert data @@ -454,11 +466,14 @@ class TestCollectionSearchInvalid(TestcaseBase): collection_w.insert(dataframe) # 3. search with expression - log.info("test_search_with_expression: searching with expression: %s" % expression) - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + log.info( + "test_search_with_expression: searching with expression: %s" % expression) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() expression = expression.replace("&&", "and").replace("||", "or") - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, check_task=CheckTasks.err_res, @@ -494,7 +509,8 @@ class TestCollectionSearchInvalid(TestcaseBase): expected: raise exception and report the error """ # 1. initialize with data - collection_w = self.init_collection_general(prefix, is_all_data_type=True)[0] + collection_w = self.init_collection_general( + prefix, is_all_data_type=True)[0] # 2 search with invalid bool expr invalid_search_expr_bool = f"{default_bool_field_name} == {get_invalid_expr_bool_value}" log.info("test_search_param_invalid_expr_bool: searching with " @@ -513,8 +529,10 @@ class TestCollectionSearchInvalid(TestcaseBase): method: test search invalid bool expected: searched failed """ - collection_w = self.init_collection_general(prefix, True, is_all_data_type=True)[0] - log.info("test_search_with_expression: searching with expression: %s" % expression) + collection_w = self.init_collection_general( + prefix, True, is_all_data_type=True)[0] + log.info( + "test_search_with_expression: searching with expression: %s" % expression) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, expression, check_task=CheckTasks.err_res, @@ -531,11 +549,14 @@ class TestCollectionSearchInvalid(TestcaseBase): expected: searched failed """ collection_w = self.init_collection_general(prefix, is_index=False)[0] - index_param = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "L2", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() - log.info("test_search_with_expression: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info( + "test_search_with_expression: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] search_res, _ = 
collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, expression, check_task=CheckTasks.err_res, @@ -554,7 +575,8 @@ class TestCollectionSearchInvalid(TestcaseBase): collection_w = self.init_collection_general(prefix)[0] # 2. search the invalid partition partition_name = get_invalid_partition - err_msg = "`partition_name_array` value {} is illegal".format(partition_name) + err_msg = "`partition_name_array` value {} is illegal".format( + partition_name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, partition_name, check_task=CheckTasks.err_res, @@ -571,7 +593,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. search - log.info("test_search_with_output_fields_invalid_type: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_fields_invalid_type: Searching collection %s" % + collection_w.name) output_fields = get_invalid_output_fields err_msg = "`output_fields` value {} is illegal".format(output_fields) collection_w.search(vectors[:default_nq], default_search_field, @@ -614,8 +637,10 @@ class TestCollectionSearchInvalid(TestcaseBase): """ # 1. initialize with data partition_num = 1 - collection_w = self.init_collection_general(prefix, True, 10, partition_num, is_index=False)[0] - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_general( + prefix, True, 10, partition_num, is_index=False)[0] + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) par = collection_w.partitions par_name = par[partition_num].name par[partition_num].load() @@ -655,7 +680,8 @@ class TestCollectionSearchInvalid(TestcaseBase): check_items={"err_code": 1, "err_msg": err_msg}) # 3. search collection without data after load - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -689,7 +715,8 @@ class TestCollectionSearchInvalid(TestcaseBase): collection_w = self.init_collection_general(prefix, partition_num=1)[0] par = collection_w.partitions # 2. search collection without data after load - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -717,14 +744,16 @@ class TestCollectionSearchInvalid(TestcaseBase): """ # 1. initialize with data partition_num = 1 - collection_w = self.init_collection_general(prefix, True, 1000, partition_num, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, 1000, partition_num, is_index=False)[0] # 2. 
delete partitions log.info("test_search_partition_deleted: deleting a partition") par = collection_w.partitions deleted_par_name = par[partition_num].name collection_w.drop_partition(deleted_par_name) log.info("test_search_partition_deleted: deleted a partition") - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 3. search after delete partitions log.info("test_search_partition_deleted: searching deleted partition") @@ -753,13 +782,17 @@ class TestCollectionSearchInvalid(TestcaseBase): if params.get("m"): if (default_dim % params["m"]) != 0: params["m"] = default_dim // 4 - log.info("test_search_different_index_invalid_params: Creating index-%s" % index) - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + log.info( + "test_search_different_index_invalid_params: Creating index-%s" % index) + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) - log.info("test_search_different_index_invalid_params: Created index-%s" % index) + log.info( + "test_search_different_index_invalid_params: Created index-%s" % index) collection_w.load() # 3. search - log.info("test_search_different_index_invalid_params: Searching after creating index-%s" % index) + log.info( + "test_search_different_index_invalid_params: Searching after creating index-%s" % index) search_params = cf.gen_invalid_search_param(index) collection_w.search(vectors, default_search_field, search_params[0], default_limit, @@ -778,12 +811,14 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix, is_index=False)[0] # 2. create index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) # 3. search the non exist partition partition_name = "search_non_exist" collection_w.search(vectors[:default_nq], default_search_field, default_search_params, - default_limit, default_search_exp, [partition_name], + default_limit, default_search_exp, [ + partition_name], check_task=CheckTasks.err_res, check_items={"err_code": 1, "err_msg": "PartitonName: %s not found" % partition_name}) @@ -799,7 +834,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # initialize with data collection_w = self.init_collection_general(prefix, True)[0] # search - vectors = [[random.random() for _ in range(default_dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -817,9 +853,11 @@ class TestCollectionSearchInvalid(TestcaseBase): expected: raise exception and report the error """ # 1. initialize with binary data - collection_w = self.init_collection_general(prefix, True, is_binary=True, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, is_binary=True, is_index=False)[0] # 2. 
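(Editorial aside: the binary-vector case that follows builds a BIN_IVF_FLAT/JACCARD index and then shows that an L2 search against binary data is rejected. For contrast, a minimal valid JACCARD search, assuming a collection "demo_bin" with a 128-bit "binary_vector" field; names and dimension are illustrative.)

import numpy as np
from pymilvus import connections, Collection

connections.connect(host="localhost", port="19530")
collection = Collection("demo_bin")  # assumed existing binary-vector collection
bits = np.random.randint(0, 2, size=(1, 128)).astype(np.uint8)
bvec = [np.packbits(bits).tobytes()]  # 128 bits packed into 16 bytes per vector
bin_index = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD",
             "params": {"nlist": 128}}
collection.create_index("binary_vector", bin_index)
collection.load()
res = collection.search(bvec, "binary_vector",
                        {"metric_type": "JACCARD", "params": {"nprobe": 10}},
                        limit=10)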
create index - default_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": "BIN_IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) # 3. search with exception binary_vectors = cf.gen_binary_vectors(3000, default_dim)[1] @@ -838,9 +876,11 @@ class TestCollectionSearchInvalid(TestcaseBase): expected: raise exception and report error """ # 1. initialize with binary data - collection_w = self.init_collection_general(prefix, True, is_binary=True)[0] + collection_w = self.init_collection_general( + prefix, True, is_binary=True)[0] # 2. search and assert - query_raw_vector, binary_vectors = cf.gen_binary_vectors(2, default_dim) + query_raw_vector, binary_vectors = cf.gen_binary_vectors( + 2, default_dim) search_params = {"metric_type": "L2", "params": {"nprobe": 10}} collection_w.search(binary_vectors[:default_nq], "binary_vector", search_params, default_limit, "int64 >= 0", @@ -858,7 +898,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. search - log.info("test_search_with_output_fields_not_exist: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_fields_not_exist: Searching collection %s" % + collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, output_fields=["int63"], @@ -879,7 +920,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix, True)[0] # 2. search - log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name) + log.info("test_search_output_field_vector: Searching collection %s" % + collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, output_fields=output_fields) @@ -895,10 +937,12 @@ class TestCollectionSearchInvalid(TestcaseBase): expected: raise exception and report the error """ # 1. create a collection and insert data - collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, is_index=False)[0] # 2. create an index which doesn't output vectors - default_index = {"index_type": index, "params": param, "metric_type": "L2"} + default_index = {"index_type": index, + "params": param, "metric_type": "L2"} collection_w.create_index(field_name, default_index) # 3. load and search @@ -921,7 +965,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. search - log.info("test_search_output_field_invalid_wildcard: Searching collection %s" % collection_w.name) + log.info("test_search_output_field_invalid_wildcard: Searching collection %s" % + collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, output_fields=output_fields, @@ -949,8 +994,10 @@ class TestCollectionSearchInvalid(TestcaseBase): collection_w.insert(data) # 3. 
search with param ignore_growing=True - search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "ignore_growing": ignore_growing} - vector = [[random.random() for _ in range(default_dim)] for _ in range(nq)] + search_params = {"metric_type": "L2", "params": { + "nprobe": 10}, "ignore_growing": ignore_growing} + vector = [[random.random() for _ in range(default_dim)] + for _ in range(nq)] collection_w.search(vector[:default_nq], default_search_field, search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, @@ -967,7 +1014,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. search with invalid guaranteeetimestamp - log.info("test_search_param_invalid_guarantee_timestamp: searching with invalid guarantee timestamp") + log.info( + "test_search_param_invalid_guarantee_timestamp: searching with invalid guarantee timestamp") invalid_guarantee_time = get_invalid_guarantee_timestamp collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -988,7 +1036,8 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. search - log.info("test_search_invalid_round_decimal: Searching collection %s" % collection_w.name) + log.info("test_search_invalid_round_decimal: Searching collection %s" % + collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, round_decimal=round_decimal, @@ -1006,9 +1055,11 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. range search - log.info("test_range_search_invalid_radius: Range searching collection %s" % collection_w.name) + log.info("test_range_search_invalid_radius: Range searching collection %s" % + collection_w.name) radius = get_invalid_range_search_paras - range_search_params = {"metric_type": "L2", "params": {"nprobe": 10, "radius": radius, "range_filter": 0}} + range_search_params = {"metric_type": "L2", "params": { + "nprobe": 10, "radius": radius, "range_filter": 0}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -1026,9 +1077,11 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. range search - log.info("test_range_search_invalid_range_filter: Range searching collection %s" % collection_w.name) + log.info("test_range_search_invalid_range_filter: Range searching collection %s" % + collection_w.name) range_filter = get_invalid_range_search_paras - range_search_params = {"metric_type": "L2", "params": {"nprobe": 10, "radius": 1, "range_filter": range_filter}} + range_search_params = {"metric_type": "L2", "params": { + "nprobe": 10, "radius": 1, "range_filter": range_filter}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -1046,8 +1099,10 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. 
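(Editorial aside: the range-search cases below pin down the parameter ordering. With a distance metric such as L2, smaller is closer, so a valid range needs radius > range_filter; with a similarity metric such as IP, larger is closer, so it needs radius < range_filter. These are exactly the combinations the invalid tests flip. A sketch of valid parameter sets, reusing the connection and names from the sketches above; values are illustrative.)

# L2: keep hits with range_filter <= distance < radius
l2_range = {"metric_type": "L2",
            "params": {"nprobe": 10, "radius": 10.0, "range_filter": 0.0}}
# IP: keep hits with radius < similarity <= range_filter
ip_range = {"metric_type": "IP",
            "params": {"nprobe": 10, "radius": 0.0, "range_filter": 10.0}}
res = collection.search(vec, "float_vector", l2_range, limit=10,
                        expr="int64 >= 0")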
range search - log.info("test_range_search_invalid_radius_range_filter_L2: Range searching collection %s" % collection_w.name) - range_search_params = {"metric_type": "L2", "params": {"nprobe": 10, "radius": 1, "range_filter": 10}} + log.info("test_range_search_invalid_radius_range_filter_L2: Range searching collection %s" % + collection_w.name) + range_search_params = {"metric_type": "L2", "params": { + "nprobe": 10, "radius": 1, "range_filter": 10}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -1065,8 +1120,10 @@ class TestCollectionSearchInvalid(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. range search - log.info("test_range_search_invalid_radius_range_filter_IP: Range searching collection %s" % collection_w.name) - range_search_params = {"metric_type": "IP", "params": {"nprobe": 10, "radius": 10, "range_filter": 1}} + log.info("test_range_search_invalid_radius_range_filter_IP: Range searching collection %s" % + collection_w.name) + range_search_params = {"metric_type": "IP", "params": { + "nprobe": 10, "radius": 10, "range_filter": 1}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -1090,12 +1147,14 @@ class TestCollectionSearchInvalid(TestcaseBase): partition_num=1, dim=default_dim, is_index=False)[0:5] # 2. create index and load - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. range search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] for search_param in search_params: search_param["params"]["radius"] = 1000 search_param["params"]["range_filter"] = 0 @@ -1121,12 +1180,14 @@ class TestCollectionSearchInvalid(TestcaseBase): partition_num=1, dim=default_dim, is_index=False)[0:5] # 2. create index and load - default_index = {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": metric} + default_index = {"index_type": "BIN_FLAT", + "params": {"nlist": 128}, "metric_type": metric} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. range search search_params = cf.gen_search_param("BIN_FLAT", metric_type=metric) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] for search_param in search_params: search_param["params"]["radius"] = 1000 search_param["params"]["range_filter"] = 0 @@ -1157,7 +1218,8 @@ class TestCollectionSearchInvalid(TestcaseBase): = self.init_collection_general(prefix, True, default_nb, is_binary=True, dim=dim, is_index=False)[0:5] # 2. 
create index - default_index = {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": metric} + default_index = {"index_type": "BIN_FLAT", + "params": {"nlist": 128}, "metric_type": metric} collection_w.create_index("binary_vector", default_index, check_task=CheckTasks.err_res, check_items={"err_code": 1, @@ -1179,16 +1241,20 @@ class TestCollectionSearchInvalid(TestcaseBase): enable_dynamic_field=True)[0] # create index - index_params_one = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}} - collection_w.create_index(ct.default_float_vec_field_name, index_params_one, index_name=index_name1) + index_params_one = {"index_type": "IVF_SQ8", + "metric_type": "COSINE", "params": {"nlist": 64}} + collection_w.create_index( + ct.default_float_vec_field_name, index_params_one, index_name=index_name1) index_params_two = {} - collection_w.create_index(ct.default_string_field_name, index_params=index_params_two, index_name=index_name2) + collection_w.create_index( + ct.default_string_field_name, index_params=index_params_two, index_name=index_name2) assert collection_w.has_index(index_name=index_name2) collection_w.load() # delete entity expr = 'float >= int64' # search with id 0 vectors - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, expr, @@ -1280,7 +1346,8 @@ class TestCollectionSearch(TestcaseBase): auto_id = True # 1. initialize with data collection_w, _, _, insert_ids, time_stamp = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_flush=True)[0:5] + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=dim, is_flush=True)[0:5] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] search_params = {"params": {"nprobe": 10}} # 2. search after insert @@ -1305,8 +1372,9 @@ class TestCollectionSearch(TestcaseBase): auto_id = True # 1. initialize with data collection_w, _, _, insert_ids, time_stamp = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_flush=True)[0:5] - # 2. search after insert + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=dim, is_flush=True)[0:5] + # 2. search after insert vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] collection_w.search(vectors[:nq], "", default_search_params, default_limit, @@ -1336,7 +1404,8 @@ class TestCollectionSearch(TestcaseBase): else: vectors = np.array(_vectors[0]).tolist() vectors = [vectors[i][-1] for i in range(nq)] - log.info("test_search_with_hit_vectors: searching collection %s" % collection_w.name) + log.info("test_search_with_hit_vectors: searching collection %s" % + collection_w.name) search_res, _ = collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -1359,10 +1428,13 @@ class TestCollectionSearch(TestcaseBase): # 1. initialize collection with random primary key collection_w, _vectors, _, insert_ids, time_stamp = \ - self.init_collection_general(prefix, True, 10, random_primary_key=random_primary_key)[0:5] + self.init_collection_general( + prefix, True, 10, random_primary_key=random_primary_key)[0:5] # 2. 
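(Editorial aside: the TestCollectionSearch cases in the surrounding hunks all repeat one create -> insert -> index -> load -> search flow through the test wrappers. For reference, the same flow condensed against the plain pymilvus API; the collection name and dimension are illustrative assumptions.)

import random
from pymilvus import (connections, Collection, CollectionSchema,
                      FieldSchema, DataType)

connections.connect(host="localhost", port="19530")
schema = CollectionSchema([
    FieldSchema("int64", DataType.INT64, is_primary=True),
    FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128),
])
collection = Collection("demo_flow", schema)
collection.insert([list(range(1000)),
                   [[random.random() for _ in range(128)] for _ in range(1000)]])
collection.create_index("float_vector", {"index_type": "IVF_SQ8",
                                         "metric_type": "COSINE",
                                         "params": {"nlist": 64}})
collection.load()
res = collection.search([[random.random() for _ in range(128)]], "float_vector",
                        {"metric_type": "COSINE", "params": {"nprobe": 10}},
                        limit=10, expr="int64 >= 0")
print(res[0].ids)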
search - log.info("test_search_random_primary_key: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info("test_search_random_primary_key: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -1398,7 +1470,8 @@ class TestCollectionSearch(TestcaseBase): insert_res, _ = collection_w.insert(insert_data[0]) insert_ids.extend(insert_res.primary_keys) # search - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -1424,7 +1497,8 @@ class TestCollectionSearch(TestcaseBase): expected: search successfully """ # initialize with data - collection_w, insert_data, _, insert_ids = self.init_collection_general(prefix, True)[0:4] + collection_w, insert_data, _, insert_ids = self.init_collection_general(prefix, True)[ + 0:4] # search collection_w.search(vectors[:nq], default_search_field, search_params, default_limit, @@ -1444,7 +1518,8 @@ class TestCollectionSearch(TestcaseBase): """ # 1. create a collection, insert data and flush nb = 10 - collection_w = self.init_collection_general(prefix, True, nb, dim=dim, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, nb, dim=dim, is_index=False)[0] # 2. insert data and flush again for two segments data = cf.gen_default_dataframe_data(nb=nb, dim=dim, start=nb) @@ -1452,11 +1527,13 @@ class TestCollectionSearch(TestcaseBase): collection_w.flush() # 3. create index and load - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 4. get inserted original data - inserted_vectors = collection_w.query(expr="int64 >= 0", output_fields=[ct.default_float_vec_field_name]) + inserted_vectors = collection_w.query(expr="int64 >= 0", output_fields=[ + ct.default_float_vec_field_name]) original_vectors = [] for single in inserted_vectors[0]: single_vector = single[ct.default_float_vec_field_name] @@ -1477,10 +1554,10 @@ class TestCollectionSearch(TestcaseBase): default_search_params, limit, check_task=CheckTasks.check_search_results, check_items={ - "nq": 1, - "limit": limit, - "ids": list(distances_index_max) - }) + "nq": 1, + "limit": limit, + "ids": list(distances_index_max) + }) @pytest.mark.tags(CaseLabel.L1) def test_search_with_empty_vectors(self, dim, auto_id, _async, enable_dynamic_field): @@ -1540,12 +1617,15 @@ class TestCollectionSearch(TestcaseBase): enable_dynamic_field=enable_dynamic_field)[0:4] # 2. rename collection new_collection_name = cf.gen_unique_str(prefix + "new") - self.utility_wrap.rename_collection(collection_w.name, new_collection_name) + self.utility_wrap.rename_collection( + collection_w.name, new_collection_name) collection_w = self.init_collection_general(auto_id=auto_id, dim=dim, name=new_collection_name, enable_dynamic_field=enable_dynamic_field)[0] # 3. 
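(Editorial aside: the hunk above renames a collection mid-test via utility_wrap. Outside the harness the equivalent is utility.rename_collection, after which the collection must be re-opened under its new name before searching. A short sketch, names illustrative and carried over from the flow sketch above.)

from pymilvus import utility, Collection

utility.rename_collection("demo_flow", "demo_flow_new")
collection = Collection("demo_flow_new")  # re-open the handle under the new name
collection.load()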
search - log.info("test_search_normal_default_params: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info("test_search_normal_default_params: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_params, default_limit, default_search_exp, _async=_async, @@ -1574,7 +1654,8 @@ class TestCollectionSearch(TestcaseBase): auto_id=auto_id, dim=dim)[0:4] # 2. search all the partitions before partition deletion vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] - log.info("test_search_before_after_delete: searching before deleting partitions") + log.info( + "test_search_before_after_delete: searching before deleting partitions") collection_w.search(vectors[:nq], default_search_field, default_search_params, limit, default_search_exp, _async=_async, @@ -1591,10 +1672,12 @@ class TestCollectionSearch(TestcaseBase): entity_num = nb - deleted_entity_num collection_w.drop_partition(par[partition_num].name) log.info("test_search_before_after_delete: deleted a partition") - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 4. search non-deleted part after delete partitions - log.info("test_search_before_after_delete: searching after deleting partitions") + log.info( + "test_search_before_after_delete: searching after deleting partitions") collection_w.search(vectors[:nq], default_search_field, default_search_params, limit, default_search_exp, _async=_async, @@ -1618,14 +1701,16 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb, 1, auto_id=auto_id, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. release collection - log.info("test_search_collection_after_release_load: releasing collection %s" % collection_w.name) + log.info("test_search_collection_after_release_load: releasing collection %s" % + collection_w.name) collection_w.release() - log.info("test_search_collection_after_release_load: released collection %s" % collection_w.name) + log.info("test_search_collection_after_release_load: released collection %s" % + collection_w.name) # 3. Search the pre-released collection after load - log.info("test_search_collection_after_release_load: loading collection %s" % collection_w.name) + log.info("test_search_collection_after_release_load: loading collection %s" % + collection_w.name) collection_w.load() log.info("test_search_collection_after_release_load: searching after load") vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] @@ -1653,7 +1738,8 @@ class TestCollectionSearch(TestcaseBase): insert_ids = cf.insert_data(collection_w, nb, auto_id=auto_id, dim=dim, enable_dynamic_field=enable_dynamic_field)[3] # 3. load data - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 4. flush and load collection_w.num_entities @@ -1733,9 +1819,11 @@ class TestCollectionSearch(TestcaseBase): collection_w.insert(dataframe) # 2. 
load and search - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, _async=_async, @@ -1757,17 +1845,20 @@ class TestCollectionSearch(TestcaseBase): """ # 1. connect, create collection and insert data self._connect() - collection_w = self.init_collection_general(prefix, False, dim=dim, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, False, dim=dim, is_index=False)[0] dataframe = cf.gen_default_dataframe_data(dim=dim, start=-1500) collection_w.insert(dataframe) # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) # 3. load and search collection_w.load() - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, _async=_async, @@ -1789,8 +1880,10 @@ class TestCollectionSearch(TestcaseBase): dim=max_dim)[0:4] # 2. search nq = 2 - log.info("test_search_max_dim: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(max_dim)] for _ in range(nq)] + log.info("test_search_max_dim: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(max_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, nq, default_search_exp, _async=_async, @@ -1814,8 +1907,10 @@ class TestCollectionSearch(TestcaseBase): enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search nq = 2 - log.info("test_search_min_dim: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(min_dim)] for _ in range(nq)] + log.info("test_search_min_dim: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(min_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, nq, default_search_exp, _async=_async, @@ -1833,9 +1928,12 @@ class TestCollectionSearch(TestcaseBase): method: create collection, insert, load and search with different nq ∈ [1, 16384] expected: search successfully with different nq """ - collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, nb=20000)[0:4] - log.info("test_search_max_nq: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(nq)] + collection_w, _, _, insert_ids = self.init_collection_general( + prefix, True, nb=20000)[0:4] + log.info("test_search_max_nq: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, check_task=CheckTasks.check_search_results, @@ -1854,19 +1952,24 @@ class TestCollectionSearch(TestcaseBase): self._connect() # 1. 
create collection name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=name, shards_num=shards_num) + collection_w = self.init_collection_wrap( + name=name, shards_num=shards_num) # 2. rename collection new_collection_name = cf.gen_unique_str(prefix + "new") - self.utility_wrap.rename_collection(collection_w.name, new_collection_name) - collection_w = self.init_collection_wrap(name=new_collection_name, shards_num=shards_num) + self.utility_wrap.rename_collection( + collection_w.name, new_collection_name) + collection_w = self.init_collection_wrap( + name=new_collection_name, shards_num=shards_num) # 3. insert dataframe = cf.gen_default_dataframe_data() collection_w.insert(dataframe) # 4. create index and load - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 5. search - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -1890,14 +1993,15 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] HNSW_index_params = {"M": M, "efConstruction": efConstruction} - HNSW_index = {"index_type": "HNSW", "params": HNSW_index_params, "metric_type": "L2"} + HNSW_index = {"index_type": "HNSW", + "params": HNSW_index_params, "metric_type": "L2"} collection_w.create_index("float_vector", HNSW_index) collection_w.load() search_param = {"metric_type": "L2", "params": {"ef": 32768}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, default_search_exp, _async=_async, @@ -1922,14 +2026,18 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] - HNSW_index_params = {"M": M, "efConstruction": efConstruction, "nlist": 100} # nlist is of no use - HNSW_index = {"index_type": "HNSW", "params": HNSW_index_params, "metric_type": "L2"} + enable_dynamic_field=enable_dynamic_field)[0:5] + # nlist is of no use + HNSW_index_params = { + "M": M, "efConstruction": efConstruction, "nlist": 100} + HNSW_index = {"index_type": "HNSW", + "params": HNSW_index_params, "metric_type": "L2"} collection_w.create_index("float_vector", HNSW_index) collection_w.load() - search_param = {"metric_type": "L2", "params": {"ef": 32768, "nprobe": 10}} # nprobe is of no use - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + search_param = {"metric_type": "L2", "params": { + "ef": 32768, "nprobe": 10}} # nprobe is of no use + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, default_search_exp, _async=_async, @@ -1956,14 +2064,15 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + 
enable_dynamic_field=enable_dynamic_field)[0:5] HNSW_index_params = {"M": M, "efConstruction": efConstruction} - HNSW_index = {"index_type": "HNSW", "params": HNSW_index_params, "metric_type": "L2"} + HNSW_index = {"index_type": "HNSW", + "params": HNSW_index_params, "metric_type": "L2"} collection_w.create_index("float_vector", HNSW_index) collection_w.load() search_param = {"metric_type": "L2", "params": {"ef": ef}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_param, limit, default_search_exp, _async=_async, @@ -1989,8 +2098,7 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. create index and load if params.get("m"): if (dim % params["m"]) != 0: @@ -1998,12 +2106,14 @@ class TestCollectionSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} + default_index = {"index_type": index, + "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search search_params = cf.gen_search_param(index, "COSINE") - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) limit = default_limit @@ -2038,8 +2148,7 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. create index and load if params.get("m"): if (dim % params["m"]) != 0: @@ -2047,12 +2156,14 @@ class TestCollectionSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2074,7 +2185,8 @@ class TestCollectionSearch(TestcaseBase): expected: search successfully """ # 1. initialize with data - collection_w = self.init_collection_general(prefix, True, auto_id=auto_id, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, auto_id=auto_id, is_index=False)[0] # 2. 
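# Condensed sketch of the HNSW flow the ef-range tests above exercise, assuming
# pymilvus is installed, a Milvus server on localhost:19530, and a hypothetical
# pre-created collection "hnsw_demo" with a 128-dim "float_vector" field:
import random
from pymilvus import connections, Collection

connections.connect(host="localhost", port="19530")
coll = Collection("hnsw_demo")
# M caps the graph degree; efConstruction sizes the build-time candidate queue.
coll.create_index("float_vector", {"index_type": "HNSW",
                                   "params": {"M": 48, "efConstruction": 500},
                                   "metric_type": "L2"})
coll.load()
# ef sizes the query-time candidate queue and must cover the requested limit;
# the tests probe both ends of its range (up to 32768).
res = coll.search([[random.random() for _ in range(128)]], "float_vector",
                  {"metric_type": "L2", "params": {"ef": 64}}, limit=10)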
create index and load collection_w.create_index("float_vector", {}) collection_w.load() @@ -2087,7 +2199,7 @@ class TestCollectionSearch(TestcaseBase): check_items={"nq": default_nq, "limit": default_limit, "_async": _async}) - + @pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", @@ -2109,12 +2221,14 @@ class TestCollectionSearch(TestcaseBase): params["m"] = min_dim if params.get("PQM"): params["PQM"] = min_dim - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(min_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(min_dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2125,7 +2239,7 @@ class TestCollectionSearch(TestcaseBase): "ids": insert_ids, "limit": default_limit, "_async": _async}) - + @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", zip(ct.all_index_types[8:10], @@ -2146,12 +2260,14 @@ class TestCollectionSearch(TestcaseBase): params["m"] = min_dim if params.get("PQM"): params["PQM"] = min_dim - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(min_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(min_dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2168,7 +2284,7 @@ class TestCollectionSearch(TestcaseBase): @pytest.mark.parametrize("index, params", zip(ct.all_index_types[:6], ct.default_index_params[:6])) - def test_search_after_index_different_metric_type(self, dim, index, params, auto_id, _async, + def test_search_after_index_different_metric_type(self, dim, index, params, auto_id, _async, enable_dynamic_field, metric_type): """ target: test search with different metric type @@ -2180,8 +2296,7 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. 
get vectors that were inserted into collection original_vectors = [] if enable_dynamic_field: @@ -2201,14 +2316,18 @@ class TestCollectionSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info("test_search_after_index_different_metric_type: Creating index-%s" % index) - default_index = {"index_type": index, "params": params, "metric_type": metric_type} + log.info( + "test_search_after_index_different_metric_type: Creating index-%s" % index) + default_index = {"index_type": index, + "params": params, "metric_type": metric_type} collection_w.create_index("float_vector", default_index) - log.info("test_search_after_index_different_metric_type: Created index-%s" % index) + log.info( + "test_search_after_index_different_metric_type: Created index-%s" % index) collection_w.load() # 4. search search_params = cf.gen_search_param(index, metric_type) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) limit = default_limit @@ -2247,8 +2366,7 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. get vectors that were inserted into collection original_vectors = [] if enable_dynamic_field: @@ -2267,14 +2385,17 @@ class TestCollectionSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info("test_search_after_release_recreate_index: Creating index-%s" % index) - default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} + log.info( + "test_search_after_release_recreate_index: Creating index-%s" % index) + default_index = {"index_type": index, + "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) log.info("test_search_after_release_recreate_index: Created index-%s" % index) collection_w.load() # 4. search search_params = cf.gen_search_param(index, "COSINE") - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2283,7 +2404,8 @@ class TestCollectionSearch(TestcaseBase): # 5. re-create index collection_w.release() collection_w.drop_index() - default_index = {"index_type": index, "params": params, "metric_type": metric_type} + default_index = {"index_type": index, + "params": params, "metric_type": metric_type} collection_w.create_index("float_vector", default_index) collection_w.load() for search_param in search_params: @@ -2315,8 +2437,7 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] # 2.
create different index if params.get("m"): if (dim % params["m"]) != 0: @@ -2324,14 +2445,18 @@ class TestCollectionSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info("test_search_after_index_different_metric_type: Creating index-%s" % index) - default_index = {"index_type": index, "params": params, "metric_type": "IP"} + log.info( + "test_search_after_index_different_metric_type: Creating index-%s" % index) + default_index = {"index_type": index, + "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) - log.info("test_search_after_index_different_metric_type: Created index-%s" % index) + log.info( + "test_search_after_index_different_metric_type: Created index-%s" % index) collection_w.load() # 3. search search_params = cf.gen_search_param(index, "IP") - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2354,12 +2479,12 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, nb, auto_id=auto_id, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search for multiple times vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] for i in range(search_num): - log.info("test_search_collection_multiple_times: searching round %d" % (i + 1)) + log.info( + "test_search_collection_multiple_times: searching round %d" % (i + 1)) collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -2381,13 +2506,14 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb, auto_id=auto_id, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. 
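# For the IP searches above, similarity is the raw inner product and larger
# means closer, unlike L2; a plain-Python check of that ordering:
def ip(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 2.0]
assert ip(q, [2.0, 2.0]) > ip(q, [1.0, 0.5])  # 6.0 beats 2.0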
search - log.info("test_search_sync_async_multiple_times: searching collection %s" % collection_w.name) + log.info("test_search_sync_async_multiple_times: searching collection %s" % + collection_w.name) vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] for i in range(search_num): - log.info("test_search_sync_async_multiple_times: searching round %d" % (i + 1)) + log.info( + "test_search_sync_async_multiple_times: searching round %d" % (i + 1)) for _async in [False, True]: collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, @@ -2414,15 +2540,19 @@ class TestCollectionSearch(TestcaseBase): """ vec_fields = [cf.gen_float_vec_field(name="test_vector1")] schema = cf.gen_schema_multi_vector_fields(vec_fields) - collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_dataframe_multi_vec_fields(vec_fields=vec_fields) collection_w.insert(df) assert collection_w.num_entities == ct.default_nb - _index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} - res, ch = collection_w.create_index(field_name="test_vector1", index_params=_index) + _index = {"index_type": "IVF_FLAT", "params": { + "nlist": 128}, "metric_type": "L2"} + res, ch = collection_w.create_index( + field_name="test_vector1", index_params=_index) assert ch is True collection_w.load() - vectors = [[random.random() for _ in range(default_dim)] for _ in range(2)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(2)] search_params = {"metric_type": "L2", "params": {"nprobe": 16}} res_1, _ = collection_w.search(data=vectors, anns_field="test_vector1", param=search_params, limit=1) @@ -2439,15 +2569,16 @@ class TestCollectionSearch(TestcaseBase): partition_num=1, auto_id=auto_id, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. create index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search in one partition - log.info("test_search_index_one_partition: searching (1000 entities) through one partition") + log.info( + "test_search_index_one_partition: searching (1000 entities) through one partition") limit = 1000 par = collection_w.partitions if limit > par[1].num_entities: @@ -2479,11 +2610,13 @@ class TestCollectionSearch(TestcaseBase): is_index=False)[0:4] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search through partitions - log.info("test_search_index_partitions: searching (1000 entities) through partitions") + log.info( + "test_search_index_partitions: searching (1000 entities) through partitions") par = collection_w.partitions log.info("test_search_index_partitions: partitions: %s" % par) search_params = {"metric_type": "L2", "params": {"nprobe": 64}} @@ -2512,12 +2645,12 @@ class TestCollectionSearch(TestcaseBase): auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create index nlist = 128 - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": nlist}, "metric_type": "COSINE"} + default_index = {"index_type": "IVF_FLAT", "params": { + "nlist": nlist}, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search through partitions @@ -2552,11 +2685,13 @@ class TestCollectionSearch(TestcaseBase): vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create empty partition partition_name = "search_partition_empty" - collection_w.create_partition(partition_name=partition_name, description="search partition empty") + collection_w.create_partition( + partition_name=partition_name, description="search partition empty") par = collection_w.partitions log.info("test_search_index_partition_empty: partitions: %s" % par) # 3. create index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "COSINE"} + default_index = {"index_type": "IVF_FLAT", "params": { + "nlist": 128}, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() # 4. search the empty partition @@ -2588,7 +2723,8 @@ class TestCollectionSearch(TestcaseBase): is_index=False, is_flush=is_flush)[0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance @@ -2608,7 +2744,8 @@ class TestCollectionSearch(TestcaseBase): if _async: res.done() res = res.result() - assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon + assert abs(res[0].distances[0] - + min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"]) @@ -2626,7 +2763,8 @@ class TestCollectionSearch(TestcaseBase): is_index=False, is_flush=is_flush)[0:4] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "HAMMING"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "HAMMING"} collection_w.create_index("binary_vector", default_index) # 3. compute the distance collection_w.load() @@ -2646,7 +2784,8 @@ class TestCollectionSearch(TestcaseBase): if _async: res.done() res = res.result() - assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon + assert abs(res[0].distances[0] - + min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("tanimoto obsolete") @@ -2666,7 +2805,8 @@ class TestCollectionSearch(TestcaseBase): is_flush=is_flush)[0:4] log.info("auto_id= %s, _async= %s" % (auto_id, _async)) # 2. 
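# The binary-metric tests above assert the server's top-1 distance against a
# locally computed one; the two distances involved, in plain Python over 0/1
# bit lists:
def jaccard(a, b):
    inter = sum(x & y for x, y in zip(a, b))
    union = sum(x | y for x, y in zip(a, b))
    return 1.0 - inter / union  # JACCARD: 1 - |A ∩ B| / |A ∪ B|

def hamming(a, b):
    return sum(x ^ y for x, y in zip(a, b))  # HAMMING: differing bit count

q = [1, 0, 1, 1, 0, 0, 1, 0]
c = [1, 1, 1, 0, 0, 0, 1, 1]
assert jaccard(q, c) == 0.5  # 3 shared set bits out of 6 in the union
assert hamming(q, c) == 3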
create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "TANIMOTO"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "TANIMOTO"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance @@ -2686,7 +2826,8 @@ class TestCollectionSearch(TestcaseBase): if _async: res.done() res = res.result() - assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon + assert abs(res[0].distances[0] - + min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("substructure obsolete") @@ -2708,13 +2849,15 @@ class TestCollectionSearch(TestcaseBase): = self.init_collection_general(prefix, True, default_nb, is_binary=True, auto_id=auto_id, dim=dim, is_index=False, is_flush=is_flush)[0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "SUBSTRUCTURE"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "SUBSTRUCTURE"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. generate search vectors _, binary_vectors = cf.gen_binary_vectors(nq, dim) # 4. search and compare the distance - search_params = {"metric_type": "SUBSTRUCTURE", "params": {"nprobe": 10}} + search_params = {"metric_type": "SUBSTRUCTURE", + "params": {"nprobe": 10}} res = collection_w.search(binary_vectors[:nq], "binary_vector", search_params, default_limit, "int64 >= 0", _async=_async)[0] @@ -2744,13 +2887,15 @@ class TestCollectionSearch(TestcaseBase): = self.init_collection_general(prefix, True, default_nb, is_binary=True, auto_id=auto_id, dim=dim, is_index=False, is_flush=is_flush)[0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "SUPERSTRUCTURE"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "SUPERSTRUCTURE"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. generate search vectors _, binary_vectors = cf.gen_binary_vectors(nq, dim) # 4. search and compare the distance - search_params = {"metric_type": "SUPERSTRUCTURE", "params": {"nprobe": 10}} + search_params = {"metric_type": "SUPERSTRUCTURE", + "params": {"nprobe": 10}} res = collection_w.search(binary_vectors[:nq], "binary_vector", search_params, default_limit, "int64 >= 0", _async=_async)[0] @@ -2768,15 +2913,19 @@ class TestCollectionSearch(TestcaseBase): expected: search successfully with limit(topK) """ # 1. initialize a collection without data - collection_w = self.init_collection_general(prefix, is_binary=True, auto_id=auto_id, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, is_binary=True, auto_id=auto_id, is_index=False)[0] # 2. insert data - insert_ids = cf.insert_data(collection_w, default_nb, is_binary=True, auto_id=auto_id)[3] + insert_ids = cf.insert_data( + collection_w, default_nb, is_binary=True, auto_id=auto_id)[3] # 3. load data - index_params = {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": metrics} + index_params = {"index_type": "BIN_FLAT", "params": { + "nlist": 128}, "metric_type": metrics} collection_w.create_index("binary_vector", index_params) collection_w.load() # 4. 
search - log.info("test_search_binary_without_flush: searching collection %s" % collection_w.name) + log.info("test_search_binary_without_flush: searching collection %s" % + collection_w.name) binary_vectors = cf.gen_binary_vectors(default_nq, default_dim)[1] search_params = {"metric_type": metrics, "params": {"nprobe": 10}} collection_w.search(binary_vectors[:default_nq], "binary_vector", @@ -2800,8 +2949,7 @@ class TestCollectionSearch(TestcaseBase): collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -2818,13 +2966,16 @@ class TestCollectionSearch(TestcaseBase): filter_ids.append(_id) # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with expression - log.info("test_search_with_expression: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info( + "test_search_with_expression: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, _async=_async, @@ -2856,11 +3007,11 @@ class TestCollectionSearch(TestcaseBase): is_all_data_type=True, auto_id=auto_id, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() @@ -2881,8 +3032,10 @@ class TestCollectionSearch(TestcaseBase): # 4. 
search with different expressions expression = f"{default_bool_field_name} == {bool_type}" - log.info("test_search_with_expression_bool: searching with bool expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info( + "test_search_with_expression_bool: searching with bool expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, @@ -2916,8 +3069,7 @@ class TestCollectionSearch(TestcaseBase): auto_id=True, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -2925,20 +3077,25 @@ class TestCollectionSearch(TestcaseBase): filter_ids = [] for i, _id in enumerate(insert_ids): if enable_dynamic_field: - exec(f"{default_float_field_name} = _vectors[i][f'{default_float_field_name}']") + exec( + f"{default_float_field_name} = _vectors[i][f'{default_float_field_name}']") else: - exec(f"{default_float_field_name} = _vectors.{default_float_field_name}[i]") + exec( + f"{default_float_field_name} = _vectors.{default_float_field_name}[i]") if not expression or eval(expression): filter_ids.append(_id) # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with different expressions - log.info("test_search_with_expression_auto_id: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info( + "test_search_with_expression_auto_id: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, _async=_async, @@ -2968,10 +3125,10 @@ class TestCollectionSearch(TestcaseBase): is_all_data_type=True, auto_id=auto_id, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. 
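# The expression tests above derive the expected hits by re-evaluating the
# same boolean filter over the rows they inserted; the pattern, reduced to
# plain Python with a hypothetical expression (Milvus && / || swapped for
# Python and / or before eval):
rows = [{"id": i, "int64": i, "float": float(i)} for i in range(10)]
expression = "int64 > 3 && float <= 8"
py_expr = expression.replace("&&", "and").replace("||", "or")
filter_ids = [row["id"] for row in rows
              if eval(py_expr, {}, {"int64": row["int64"], "float": row["float"]})]
assert filter_ids == [4, 5, 6, 7, 8]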
search - log.info("test_search_expression_all_data_type: Searching collection %s" % collection_w.name) + log.info("test_search_expression_all_data_type: Searching collection %s" % + collection_w.name) vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] search_exp = "int64 >= 0 && int32 >= 0 && int16 >= 0 " \ "&& int8 >= 0 && float >= 0 && double >= 0" @@ -2991,7 +3148,7 @@ class TestCollectionSearch(TestcaseBase): res = res.result() assert len(res[0][0].entity._row_data) != 0 assert (default_int64_field_name and default_float_field_name and default_bool_field_name) \ - in res[0][0].entity._row_data + in res[0][0].entity._row_data @pytest.mark.tags(CaseLabel.L1) @pytest.mark.skip(reason="issue #23646") @@ -3007,17 +3164,20 @@ class TestCollectionSearch(TestcaseBase): offset = 2 ** (num - 1) default_schema = cf.gen_collection_schema_all_datatype() collection_w = self.init_collection_wrap(schema=default_schema) - collection_w = cf.insert_data(collection_w, is_all_data_type=True, insert_offset=offset-1000)[0] + collection_w = cf.insert_data( + collection_w, is_all_data_type=True, insert_offset=offset-1000)[0] # 2. create index and load collection_w.create_index(field_name, default_index_params) collection_w.load() # 3. search - log.info("test_search_expression_different_data_type: Searching collection %s" % collection_w.name) + log.info("test_search_expression_different_data_type: Searching collection %s" % + collection_w.name) expression = f"{field} >= {offset}" res = collection_w.search(vectors, default_search_field, default_search_params, - default_limit, expression, output_fields=[field], + default_limit, expression, output_fields=[ + field], check_task=CheckTasks.check_search_results, check_items={"nq": default_nq, "limit": default_limit})[0] @@ -3038,8 +3198,10 @@ class TestCollectionSearch(TestcaseBase): dim = 1 fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), cf.gen_float_vec_field(dim=dim)] - schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") - collection_w = self.init_collection_wrap(name=cf.gen_unique_str("comparison"), schema=schema) + schema = cf.gen_collection_schema( + fields=fields, primary_field="int64_1") + collection_w = self.init_collection_wrap( + name=cf.gen_unique_str("comparison"), schema=schema) # 2. inset data values = pd.Series(data=[i for i in range(0, nb)]) @@ -3054,10 +3216,12 @@ class TestCollectionSearch(TestcaseBase): filter_ids.extend(_id) # 3. search with expression - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() expression = "int64_1 <= int64_2" - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] res = collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, expression, _async=_async, @@ -3086,7 +3250,8 @@ class TestCollectionSearch(TestcaseBase): auto_id=auto_id, dim=dim)[0:4] # 2. 
search - log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_fields_empty: Searching collection %s" % + collection_w.name) vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, @@ -3109,10 +3274,10 @@ class TestCollectionSearch(TestcaseBase): # 1. initialize with data collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search - log.info("test_search_with_output_field: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_field: Searching collection %s" % + collection_w.name) res = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, @@ -3135,10 +3300,10 @@ class TestCollectionSearch(TestcaseBase): # 1. initialize with data collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search - log.info("test_search_with_output_field: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_field: Searching collection %s" % + collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -3163,7 +3328,8 @@ class TestCollectionSearch(TestcaseBase): auto_id=auto_id, dim=dim)[0:4] # 2. search - log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_fields: Searching collection %s" % + collection_w.name) vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] output_fields = [default_int64_field_name, default_float_field_name] collection_w.search(vectors[:nq], default_search_field, @@ -3199,10 +3365,12 @@ class TestCollectionSearch(TestcaseBase): if metrics == "COSINE": pytest.skip("COSINE does not support output vector now") # 1. create a collection and insert data - collection_w, _vectors = self.init_collection_general(prefix, True, is_index=False)[:2] + collection_w, _vectors = self.init_collection_general( + prefix, True, is_index=False)[:2] # 2. create index and load - default_index = {"index_type": index, "params": params, "metric_type": metrics} + default_index = {"index_type": index, + "params": params, "metric_type": metrics} collection_w.create_index(field_name, default_index) collection_w.load() @@ -3237,7 +3405,8 @@ class TestCollectionSearch(TestcaseBase): expected: search success """ # 1. create a collection and insert data - collection_w = self.init_collection_general(prefix, is_binary=True, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, is_binary=True, is_index=False)[0] data = cf.gen_default_binary_dataframe_data()[0] collection_w.insert(data) @@ -3255,7 +3424,8 @@ class TestCollectionSearch(TestcaseBase): output_fields=[binary_field_name])[0] # 4. 
check the result vectors should be equal to the inserted - assert res[0][0].entity.binary_vector == [data[binary_field_name][res[0][0].id]] + assert res[0][0].entity.binary_vector == [ + data[binary_field_name][res[0][0].id]] @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("dim", [32, 128, 768]) @@ -3269,7 +3439,8 @@ class TestCollectionSearch(TestcaseBase): expected: search success """ # 1. create a collection and insert data - collection_w, _vectors = self.init_collection_general(prefix, True, dim=dim)[:2] + collection_w, _vectors = self.init_collection_general(prefix, True, dim=dim)[ + :2] # 2. search with output field vector vectors = cf.gen_vectors(default_nq, dim=dim) @@ -3296,7 +3467,8 @@ class TestCollectionSearch(TestcaseBase): enable_dynamic_field=enable_dynamic_field)[:2] # 2. search with output field vector - output_fields = [default_float_field_name, default_string_field_name, default_search_field] + output_fields = [default_float_field_name, + default_string_field_name, default_search_field] original_entities = [] if enable_dynamic_field: entities = [] @@ -3327,10 +3499,12 @@ class TestCollectionSearch(TestcaseBase): expected: search success """ # 1. initialize a collection - collection_w = self.init_collection_general(prefix, True, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general( + prefix, True, enable_dynamic_field=enable_dynamic_field)[0] # 2. search with output field vector - output_fields = [default_int64_field_name, default_string_field_name, default_search_field] + output_fields = [default_int64_field_name, + default_string_field_name, default_search_field] collection_w.search(vectors[:1], default_search_field, default_search_params, default_limit, default_search_exp, output_fields=output_fields, @@ -3382,8 +3556,10 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] # 2. search - log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name) - output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields) + log.info("test_search_with_output_field_wildcard: Searching collection %s" % + collection_w.name) + output_fields = cf.get_wildcard_output_field_names( + collection_w, wildcard_output_fields) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -3404,9 +3580,11 @@ class TestCollectionSearch(TestcaseBase): expected: search success """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general( + prefix, True, auto_id=auto_id)[0:4] # 2. search - log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name) + log.info("test_search_with_output_field_wildcard: Searching collection %s" % + collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -3431,7 +3609,8 @@ class TestCollectionSearch(TestcaseBase): auto_id=auto_id, dim=dim)[0:4] # 2. 
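# The output-field tests above only vary what is requested back with each hit;
# a sketch reusing the hypothetical "hnsw_demo" collection from the HNSW
# snippet ("*" is the wildcard those tests expand, and vector fields can be
# requested by name, as the binary_vector assertion above does):
res = coll.search([[random.random() for _ in range(128)]], "float_vector",
                  {"metric_type": "L2", "params": {"ef": 64}}, limit=3,
                  output_fields=["*"])
for hit in res[0]:
    print(hit.id, hit.distance, hit.entity)  # entity carries the requested fields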
search - vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(nq)] log.info("test_search_multi_collections: searching %s entities (nq = %s) from collection %s" % (default_limit, nq, collection_w.name)) collection_w.search(vectors[:nq], default_search_field, @@ -3456,11 +3635,11 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb, auto_id=auto_id, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] def search(collection_w): - vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -3471,7 +3650,8 @@ class TestCollectionSearch(TestcaseBase): "_async": _async}) # 2. search with multi-processes - log.info("test_search_concurrent_multi_threads: searching with %s processes" % threads_num) + log.info( + "test_search_concurrent_multi_threads: searching with %s processes" % threads_num) for i in range(threads_num): t = threading.Thread(target=search, args=(collection_w,)) threads.append(t) @@ -3490,15 +3670,18 @@ class TestCollectionSearch(TestcaseBase): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() def do_insert(): df = cf.gen_default_dataframe_data(10000) for i in range(11): collection_w.insert(df) - log.info(f'Collection num entities is : {collection_w.num_entities}') + log.info( + f'Collection num entities is : {collection_w.num_entities}') def do_search(): while True: @@ -3512,7 +3695,8 @@ class TestCollectionSearch(TestcaseBase): timeout=30) p_insert = multiprocessing.Process(target=do_insert, args=()) - p_search = multiprocessing.Process(target=do_search, args=(), daemon=True) + p_search = multiprocessing.Process( + target=do_search, args=(), daemon=True) p_insert.start() p_search.start() @@ -3535,7 +3719,8 @@ class TestCollectionSearch(TestcaseBase): collection_w = self.init_collection_general(prefix, True, nb=tmp_nb, enable_dynamic_field=enable_dynamic_field)[0] # 2. 
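# The concurrency tests above fan one search out across worker threads; the
# skeleton with the search call stubbed out:
import threading

def search_once(i, results):
    results[i] = "ok"  # stand-in for collection_w.search(...)

results = {}
threads = [threading.Thread(target=search_once, args=(i, results))
           for i in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(results) == 10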
search - log.info("test_search_round_decimal: Searching collection %s" % collection_w.name) + log.info("test_search_round_decimal: Searching collection %s" % + collection_w.name) res, _ = collection_w.search(vectors[:tmp_nq], default_search_field, default_search_params, tmp_limit) @@ -3549,7 +3734,8 @@ class TestCollectionSearch(TestcaseBase): dis_actual = res_round[0][i].distance # log.debug(f'actual: {dis_actual}, expect: {dis_expect}') # abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) - assert math.isclose(dis_actual, dis_expect, rel_tol=0, abs_tol=abs_tol) + assert math.isclose(dis_actual, dis_expect, + rel_tol=0, abs_tol=abs_tol) @pytest.mark.tags(CaseLabel.L1) def test_search_with_expression_large(self, dim, enable_dynamic_field): @@ -3563,18 +3749,19 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, nb, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field, + enable_dynamic_field=enable_dynamic_field, with_json=False)[0:4] # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with expression expression = f"0 < {default_int64_field_name} < 5001" - log.info("test_search_with_expression: searching with expression: %s" % expression) + log.info( + "test_search_with_expression: searching with expression: %s" % expression) nums = 5000 vectors = [[random.random() for _ in range(dim)] for _ in range(nums)] @@ -3582,9 +3769,9 @@ class TestCollectionSearch(TestcaseBase): default_search_params, default_limit, expression, check_task=CheckTasks.check_search_results, check_items={ - "nq": nums, - "ids": insert_ids, - "limit": default_limit, + "nq": nums, + "ids": insert_ids, + "limit": default_limit, }) @pytest.mark.tags(CaseLabel.L1) @@ -3603,7 +3790,8 @@ class TestCollectionSearch(TestcaseBase): with_json=False)[0:4] # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() @@ -3615,9 +3803,9 @@ class TestCollectionSearch(TestcaseBase): default_search_params, default_limit, expression, check_task=CheckTasks.check_search_results, check_items={ - "nq": nums, - "ids": insert_ids, - "limit": default_limit, + "nq": nums, + "ids": insert_ids, + "limit": default_limit, }) @pytest.mark.tags(CaseLabel.L1) @@ -3634,8 +3822,7 @@ class TestCollectionSearch(TestcaseBase): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, nb_old, auto_id=auto_id, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. 
search for original data after load vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, @@ -3649,7 +3836,8 @@ class TestCollectionSearch(TestcaseBase): }) kwargs = {} - consistency_level = kwargs.get("consistency_level", CONSISTENCY_BOUNDED) + consistency_level = kwargs.get( + "consistency_level", CONSISTENCY_BOUNDED) kwargs.update({"consistency_level": consistency_level}) nb_new = 400 @@ -3743,7 +3931,8 @@ class TestCollectionSearch(TestcaseBase): enable_dynamic_field=enable_dynamic_field) insert_ids.extend(insert_ids_new) kwargs = {} - consistency_level = kwargs.get("consistency_level", CONSISTENCY_EVENTUALLY) + consistency_level = kwargs.get( + "consistency_level", CONSISTENCY_EVENTUALLY) kwargs.update({"consistency_level": consistency_level}) collection_w.search(vectors[:nq], default_search_field, default_search_params, limit, @@ -3777,7 +3966,8 @@ class TestCollectionSearch(TestcaseBase): "_async": _async}) kwargs = {} - consistency_level = kwargs.get("consistency_level", CONSISTENCY_SESSION) + consistency_level = kwargs.get( + "consistency_level", CONSISTENCY_SESSION) kwargs.update({"consistency_level": consistency_level}) nb_new = 400 @@ -3813,7 +4003,8 @@ class TestCollectionSearch(TestcaseBase): collection_w.insert(data) # 3. search with param ignore_growing=True - search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}, "ignore_growing": True} + search_params = {"metric_type": "COSINE", "params": { + "nprobe": 10}, "ignore_growing": True} vector = [[random.random() for _ in range(dim)] for _ in range(nq)] res = collection_w.search(vector[:nq], default_search_field, search_params, default_limit, default_search_exp, _async=_async, @@ -3879,7 +4070,8 @@ class TestCollectionSearch(TestcaseBase): self._connect() fields = [cf.gen_int64_field(), cf.gen_int64_field(field_name1), cf.gen_float_vec_field(field_name2, dim=default_dim)] - schema = cf.gen_collection_schema(fields=fields, primary_field=default_int64_field_name) + schema = cf.gen_collection_schema( + fields=fields, primary_field=default_int64_field_name) collection_w = self.init_collection_wrap(name=collection_name, schema=schema, check_task=CheckTasks.check_collection_property, check_items={"name": collection_name, "schema": schema}) @@ -3889,7 +4081,8 @@ class TestCollectionSearch(TestcaseBase): dataframe = pd.DataFrame({default_int64_field_name: int_values, field_name1: int_values, field_name2: float_vec_values}) collection_w.insert(dataframe) - collection_w.create_index(field_name2, index_params=ct.default_flat_index) + collection_w.create_index( + field_name2, index_params=ct.default_flat_index) collection_w.load() collection_w.search(vectors[:default_nq], field_name2, default_search_params, default_limit, _async=_async, @@ -3923,7 +4116,8 @@ class TestCollectionSearch(TestcaseBase): collection_w.load() vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, - default_limit, default_search_exp, [partition_name], + default_limit, default_search_exp, [ + partition_name], check_task=CheckTasks.check_search_results, check_items={"nq": nq, "ids": insert_ids, @@ -3953,11 +4147,13 @@ class TestCollectionSearch(TestcaseBase): collection_w.create_partition(partition_name) insert_ids = cf.insert_data(collection_w, nb, auto_id=auto_id, dim=dim, enable_dynamic_field=enable_dynamic_field)[3] - collection_w.create_index(default_search_field, default_index_params, 
index_name=index_name) + collection_w.create_index( + default_search_field, default_index_params, index_name=index_name) collection_w.load() vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, - default_limit, default_search_exp, [partition_name], + default_limit, default_search_exp, [ + partition_name], check_task=CheckTasks.check_search_results, check_items={"nq": nq, "ids": insert_ids, @@ -3976,8 +4172,10 @@ class TestCollectionSearch(TestcaseBase): nq = 5 upsert_nb = 1000 collection_w = self.init_collection_general(prefix, True)[0] - vectors = [[random.random() for _ in range(default_dim)] for _ in range(nq)] - res1 = collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit)[0] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(nq)] + res1 = collection_w.search( + vectors[:nq], default_search_field, default_search_params, default_limit)[0] def do_upsert(): data = cf.gen_default_data_for_upsert(upsert_nb)[0] @@ -3985,9 +4183,11 @@ class TestCollectionSearch(TestcaseBase): t = threading.Thread(target=do_upsert, args=()) t.start() - res2 = collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit)[0] + res2 = collection_w.search( + vectors[:nq], default_search_field, default_search_params, default_limit)[0] t.join() - assert [res1[i].ids for i in range(nq)] == [res2[i].ids for i in range(nq)] + assert [res1[i].ids for i in range(nq)] == [ + res2[i].ids for i in range(nq)] @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("not support default_value now") @@ -4073,7 +4273,8 @@ class TestSearchBase(TestcaseBase): """ top_k = 16385 # max top k is 16384 nq = get_nq - collection_w, data, _, insert_ids = self.init_collection_general(prefix, insert_data=True, nb=nq)[0:4] + collection_w, data, _, insert_ids = self.init_collection_general( + prefix, insert_data=True, nb=nq)[0:4] collection_w.load() if top_k <= max_top_k: res, _ = collection_w.search(vectors[:nq], default_search_field, default_search_params, @@ -4107,7 +4308,8 @@ class TestSearchBase(TestcaseBase): vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create partition partition_name = "search_partition_empty" - collection_w.create_partition(partition_name=partition_name, description="search partition empty") + collection_w.create_partition( + partition_name=partition_name, description="search partition empty") par = collection_w.partitions # collection_w.load() # 3. 
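# The consistency-level variants above differ only in the kwarg merged into
# the search call; the defaulting pattern in isolation (string level names as
# pymilvus accepts them):
def with_consistency(kwargs, default="Bounded"):
    kwargs.setdefault("consistency_level", default)
    return kwargs

assert with_consistency({}) == {"consistency_level": "Bounded"}
assert with_consistency({}, "Eventually")["consistency_level"] == "Eventually"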
create different index @@ -4117,7 +4319,8 @@ class TestSearchBase(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} + default_index = {"index_type": index, + "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() @@ -4165,7 +4368,8 @@ class TestSearchBase(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} + default_index = {"index_type": index, + "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() res, _ = collection_w.search(vectors[:nq], default_search_field, @@ -4188,7 +4392,8 @@ class TestSearchBase(TestcaseBase): dim=dim, is_index=False)[0:5] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create ip index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "IP"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() search_params = {"metric_type": "IP", "params": {"nprobe": 10}} @@ -4216,7 +4421,8 @@ class TestSearchBase(TestcaseBase): dim=dim, is_index=False)[0:5] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create ip index - default_index = {"index_type": index, "params": params, "metric_type": "IP"} + default_index = {"index_type": index, + "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() search_params = {"metric_type": "IP", "params": {"nprobe": 10}} @@ -4278,11 +4484,13 @@ class TestSearchBase(TestcaseBase): vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create partition partition_name = "search_partition_empty" - collection_w.create_partition(partition_name=partition_name, description="search partition empty") + collection_w.create_partition( + partition_name=partition_name, description="search partition empty") par = collection_w.partitions # collection_w.load() # 3. create different index - default_index = {"index_type": index, "params": params, "metric_type": "IP"} + default_index = {"index_type": index, + "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() @@ -4324,7 +4532,8 @@ class TestSearchBase(TestcaseBase): par_name = collection_w.partitions[0].name # collection_w.load() # 3. 
create different index - default_index = {"index_type": index, "params": params, "metric_type": "IP"} + default_index = {"index_type": index, + "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() @@ -4347,7 +4556,8 @@ class TestSearchBase(TestcaseBase): collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, ct.default_nq)[0:5] - vectors = [[random.random() for _ in range(ct.default_dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(ct.default_dim)] + for _ in range(nq)] collection_w.load() self.connection_wrap.remove_connection(ct.default_alias) @@ -4371,10 +4581,12 @@ class TestSearchBase(TestcaseBase): """ threads_num = 10 threads = [] - collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, ct.default_nb)[0:5] + collection_w, _, _, insert_ids, time_stamp = self.init_collection_general( + prefix, True, ct.default_nb)[0:5] def search(collection_w): - vectors = [[random.random() for _ in range(ct.default_dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(ct.default_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -4385,7 +4597,8 @@ class TestSearchBase(TestcaseBase): "_async": _async}) # 2. search with multi-processes - log.info("test_search_concurrent_multi_threads: searching with %s processes" % threads_num) + log.info( + "test_search_concurrent_multi_threads: searching with %s processes" % threads_num) for i in range(threads_num): t = threading.Thread(target=search, args=(collection_w,)) threads.append(t) @@ -4408,9 +4621,11 @@ class TestSearchBase(TestcaseBase): for i in range(num): collection = gen_unique_str(uid + str(i)) collection_w, _, _, insert_ids, time_stamp = \ - self.init_collection_general(collection, True, ct.default_nb)[0:5] + self.init_collection_general( + collection, True, ct.default_nb)[0:5] assert len(insert_ids) == default_nb - vectors = [[random.random() for _ in range(ct.default_dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(ct.default_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, top_k, default_search_exp, @@ -4430,7 +4645,8 @@ class TestSearchDSL(TestcaseBase): """ collection_w, _, _, insert_ids, time_stamp = \ self.init_collection_general(prefix, True, ct.default_nb)[0:5] - vectors = [[random.random() for _ in range(ct.default_dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(ct.default_dim)] + for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, ct.default_top_k, default_search_exp, @@ -4487,8 +4703,10 @@ class TestSearchString(TestcaseBase): self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. 
search - log.info("test_search_string_field_not_primary: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info("test_search_string_field_not_primary: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, @@ -4515,8 +4733,10 @@ class TestSearchString(TestcaseBase): self.init_collection_general(prefix, True, dim=dim, primary_field=ct.default_string_field_name, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search - log.info("test_search_string_field_is_primary_true: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info("test_search_string_field_is_primary_true: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, @@ -4545,9 +4765,12 @@ class TestSearchString(TestcaseBase): collection_w.create_index(field_name, {"metric_type": "L2"}) collection_w.load() # 2. search - log.info("test_search_string_field_is_primary_true: searching collection %s" % collection_w.name) - range_search_params = {"metric_type": "L2", "params": {"radius": 1000, "range_filter": 0}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info("test_search_string_field_is_primary_true: searching collection %s" % + collection_w.name) + range_search_params = {"metric_type": "L2", + "params": {"radius": 1000, "range_filter": 0}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, @@ -4574,8 +4797,10 @@ class TestSearchString(TestcaseBase): self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search - log.info("test_search_string_mix_expr: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info("test_search_string_mix_expr: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, @@ -4599,10 +4824,13 @@ class TestSearchString(TestcaseBase): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim)[0:4] + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=default_dim)[0:4] # 2. 
search - log.info("test_search_string_with_invalid_expr: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info("test_search_string_with_invalid_expr: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_invaild_string_exp, @@ -4624,8 +4852,7 @@ class TestSearchString(TestcaseBase): collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -4642,13 +4869,16 @@ class TestSearchString(TestcaseBase): filter_ids.append(_id) # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with expression - log.info("test_search_with_expression: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info( + "test_search_with_expression: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, _async=_async, @@ -4683,7 +4913,8 @@ class TestSearchString(TestcaseBase): is_index=False, primary_field=ct.default_string_field_name)[0:4] # 2. create index - default_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": "BIN_IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. search with exception @@ -4699,7 +4930,6 @@ class TestSearchString(TestcaseBase): "limit": 2, "_async": _async}) - @pytest.mark.tags(CaseLabel.L2) def test_search_string_field_binary(self, auto_id, dim, _async): """ @@ -4716,7 +4946,8 @@ class TestSearchString(TestcaseBase): dim=dim, is_index=False)[0:4] # 2. create index - default_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": "BIN_IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 2. search with exception @@ -4731,7 +4962,6 @@ class TestSearchString(TestcaseBase): "limit": 2, "_async": _async}) - @pytest.mark.tags(CaseLabel.L2) def test_search_mix_expr_with_binary(self, dim, auto_id, _async): """ @@ -4743,13 +4973,16 @@ class TestSearchString(TestcaseBase): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_binary=True, is_index=False)[0:4] + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=dim, is_binary=True, is_index=False)[0:4] # 2. 
create index - default_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": "BIN_IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 2. search - log.info("test_search_mix_expr_with_binary: searching collection %s" % collection_w.name) + log.info("test_search_mix_expr_with_binary: searching collection %s" % + collection_w.name) binary_vectors = cf.gen_binary_vectors(3000, dim)[1] search_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}} output_fields = [default_string_field_name, default_float_field_name] @@ -4775,19 +5008,24 @@ class TestSearchString(TestcaseBase): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim, is_index=False)[0:4] - index_param = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}} + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=default_dim, is_index=False)[0:4] + index_param = {"index_type": "IVF_FLAT", + "metric_type": "L2", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param, index_name="a") index_param_two = {} collection_w.create_index("varchar", index_param_two, index_name="b") collection_w.load() # 2. search - log.info("test_search_string_field_not_primary: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info("test_search_string_field_not_primary: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] output_fields = [default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, # search all buckets - {"metric_type": "L2", "params": {"nprobe": 100}}, default_limit, + {"metric_type": "L2", "params": { + "nprobe": 100}}, default_limit, perfix_expr, output_fields=output_fields, _async=_async, @@ -4813,10 +5051,13 @@ class TestSearchString(TestcaseBase): is_index=False)[0:4] # create index - index_params_one = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}} - collection_w.create_index(ct.default_float_vec_field_name, index_params_one, index_name=index_name1) + index_params_one = {"index_type": "IVF_SQ8", + "metric_type": "COSINE", "params": {"nlist": 64}} + collection_w.create_index( + ct.default_float_vec_field_name, index_params_one, index_name=index_name1) index_params_two = {} - collection_w.create_index(ct.default_string_field_name, index_params=index_params_two, index_name=index_name2) + collection_w.create_index( + ct.default_string_field_name, index_params=index_params_two, index_name=index_name2) assert collection_w.has_index(index_name=index_name2) collection_w.release() @@ -4824,8 +5065,10 @@ class TestSearchString(TestcaseBase): # delete entity expr = 'float >= int64' # search with id 0 vectors - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, expr,
@@ -4848,7 +5091,8 @@ class TestSearchString(TestcaseBase): """ # 1. initialize with data collection_w, _, _, _ = \ - self.init_collection_general(prefix, False, primary_field=ct.default_string_field_name)[0:4] + self.init_collection_general( + prefix, False, primary_field=ct.default_string_field_name)[0:4] nb = 3000 data = cf.gen_default_list_data(nb) @@ -4861,8 +5105,10 @@ class TestSearchString(TestcaseBase): limit = 1 # 2. search - log.info("test_search_string_field_is_primary_true: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info("test_search_string_field_is_primary_true: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, limit, @@ -4884,8 +5130,9 @@ class TestSearchString(TestcaseBase): expected: Search successfully """ # 1. initialize with data - collection_w, _, _, _= \ - self.init_collection_general(prefix, False, primary_field=ct.default_int64_field_name, is_index=False)[0:4] + collection_w, _, _, _ = \ + self.init_collection_general( + prefix, False, primary_field=ct.default_int64_field_name, is_index=False)[0:4] nb = 3000 data = cf.gen_default_list_data(nb) @@ -4896,15 +5143,18 @@ class TestSearchString(TestcaseBase): assert collection_w.num_entities == nb # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() search_string_exp = "varchar >= \"\"" # 3. search - log.info("test_search_string_field_not_primary: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info("test_search_string_field_not_primary: searching collection %s" % + collection_w.name) + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, @@ -4958,8 +5208,10 @@ class TestSearchPagination(TestcaseBase): collection_w = self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim, enable_dynamic_field=enable_dynamic_field)[0] # 2. search pagination with offset - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] search_res = collection_w.search(vectors[:default_nq], default_search_field, search_param, limit, default_search_exp, _async=_async, @@ -4994,8 +5246,10 @@ class TestSearchPagination(TestcaseBase): self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim, enable_dynamic_field=enable_dynamic_field)[0:4] # 2.
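One detail worth spelling out from the hunk above: search_string_exp = "varchar >= \"\"" is a match-all predicate, because every string compares greater than or equal to the empty string. The case therefore exercises the string-field filter path without shrinking the result set:

    # every value passes the filter, so only the code path is exercised
    assert all(s >= "" for s in ["", "a", "0", "varchar_1"])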
search - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] output_fields = [default_string_field_name, default_float_field_name] search_res = collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, @@ -5031,9 +5285,11 @@ class TestSearchPagination(TestcaseBase): """ # 1. create a collection collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, is_binary=True, auto_id=auto_id, dim=default_dim)[0:4] + self.init_collection_general( + prefix, True, is_binary=True, auto_id=auto_id, dim=default_dim)[0:4] # 2. search - search_param = {"metric_type": "JACCARD", "params": {"nprobe": 10}, "offset": offset} + search_param = {"metric_type": "JACCARD", + "params": {"nprobe": 10}, "offset": offset} binary_vectors = cf.gen_binary_vectors(default_nq, default_dim)[1] search_res = collection_w.search(binary_vectors[:default_nq], "binary_vector", search_param, default_limit, @@ -5042,12 +5298,14 @@ class TestSearchPagination(TestcaseBase): "ids": insert_ids, "limit": default_limit})[0] # 3. search with offset+limit - search_binary_param = {"metric_type": "JACCARD", "params": {"nprobe": 10}} + search_binary_param = { + "metric_type": "JACCARD", "params": {"nprobe": 10}} res = collection_w.search(binary_vectors[:default_nq], "binary_vector", search_binary_param, default_limit + offset)[0] assert len(search_res[0].ids) == len(res[0].ids[offset:]) - assert sorted(search_res[0].distances, key=numpy.float32) == sorted(res[0].distances[offset:], key=numpy.float32) + assert sorted(search_res[0].distances, key=numpy.float32) == sorted( + res[0].distances[offset:], key=numpy.float32) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("limit", [100, 3000, 10000]) @@ -5063,10 +5321,13 @@ class TestSearchPagination(TestcaseBase): # 1. create a collection topK = 16384 offset = topK - limit - collection_w = self.init_collection_general(prefix, True, nb=20000, auto_id=auto_id, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, True, nb=20000, auto_id=auto_id, dim=default_dim)[0] # 2. 
search - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] search_res = collection_w.search(vectors[:default_nq], default_search_field, search_param, limit, default_search_exp, _async=_async, @@ -5099,8 +5360,7 @@ class TestSearchPagination(TestcaseBase): dim = 8 collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb=nb, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # filter result with expression in collection _vectors = _vectors[0] expression = expression.replace("&&", "and").replace("||", "or") @@ -5121,8 +5381,10 @@ class TestSearchPagination(TestcaseBase): limit = 0 elif len(filter_ids) - offset < default_limit: limit = len(filter_ids) - offset - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, expression, _async=_async, @@ -5159,15 +5421,18 @@ class TestSearchPagination(TestcaseBase): partition_num=1, auto_id=auto_id, is_index=False)[0:4] - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] # 2. create index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search through partitions par = collection_w.partitions limit = 100 - search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset} + search_params = {"metric_type": "L2", + "params": {"nprobe": 10}, "offset": offset} search_res = collection_w.search(vectors[:default_nq], default_search_field, search_params, limit, default_search_exp, [par[0].name, par[1].name], _async=_async, @@ -5202,12 +5467,14 @@ class TestSearchPagination(TestcaseBase): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, partition_num=1, auto_id=auto_id)[0:4] - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.load() # 2. search through partitions par = collection_w.partitions limit = 1000 - search_param = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": offset} + search_param = {"metric_type": "L2", + "params": {"nprobe": 10}, "offset": offset} search_res = collection_w.search(vectors[:default_nq], default_search_field, search_param, limit, default_search_exp, [par[0].name, par[1].name], _async=_async, @@ -5239,7 +5506,8 @@ class TestSearchPagination(TestcaseBase): expected: searched successfully """ # 1. 
create collection - collection_w = self.init_collection_general(prefix, False, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, False, dim=default_dim)[0] # 2. insert data data = cf.gen_default_dataframe_data(dim=default_dim) collection_w.insert(data) @@ -5274,9 +5542,11 @@ class TestSearchPagination(TestcaseBase): expected: search successfully """ # 1. initialize without data - collection_w = self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, True, auto_id=auto_id, dim=default_dim)[0] # 2. search collection without data - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} search_res = collection_w.search([], default_search_field, search_param, default_limit, default_search_exp, _async=_async, check_task=CheckTasks.check_search_results, @@ -5296,10 +5566,13 @@ class TestSearchPagination(TestcaseBase): expected: return an empty list """ # 1. initialize - collection_w = self.init_collection_general(prefix, True, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, True, dim=default_dim)[0] # 2. search - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] res = collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, default_search_exp, @@ -5331,12 +5604,14 @@ class TestSearchPagination(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: res = collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit + offset, default_search_exp, _async=_async)[0] @@ -5370,7 +5645,8 @@ class TestSearchPagination(TestcaseBase): # 1. initialize collection_w = self.init_collection_general(prefix, True)[0] # 2. search with offset in params - search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} + search_params = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} res1 = collection_w.search(vectors[:default_nq], default_search_field, search_params, default_limit)[0] @@ -5402,10 +5678,13 @@ class TestSearchPaginationInvalid(TestcaseBase): expected: raise exception """ # 1. initialize - collection_w = self.init_collection_general(prefix, True, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, True, dim=default_dim)[0] # 2. 
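The negative cases that follow feed offsets outside the accepted window and expect an error. A toy guard capturing the rule being probed (the [0, 16384] window is inferred from the values this class uses, not from a documented server contract):

    def check_offset(offset):
        if not isinstance(offset, int) or not 0 <= offset <= 16384:
            raise ValueError(
                "offset must be an int in [0, 16384], got %r" % (offset,))

    check_offset(10)              # fine
    try:
        check_offset(-1)          # rejected, mirroring the err_res expectations
    except ValueError:
        pass                      # the tests assert this failure instead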
search - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, default_search_exp, @@ -5422,10 +5701,13 @@ class TestSearchPaginationInvalid(TestcaseBase): expected: raise exception """ # 1. initialize - collection_w = self.init_collection_general(prefix, True, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, True, dim=default_dim)[0] # 2. search - search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", + "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, default_search_exp, @@ -5449,7 +5731,7 @@ class TestSearchDiskann(TestcaseBase): def auto_id(self, request): yield request.param - @pytest.fixture(scope="function", params=[False ,True]) + @pytest.fixture(scope="function", params=[False, True]) def _async(self, request): yield request.param @@ -5473,15 +5755,20 @@ class TestSearchDiskann(TestcaseBase): nb=nb, dim=dim, is_index=False, enable_dynamic_field=enable_dynamic_field)[0:4] - + # 2. create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params = {"metric_type": "L2", "params": {"search_list": 30}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = { + "metric_type": "L2", "params": {"search_list": 30}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -5507,16 +5794,22 @@ class TestSearchDiskann(TestcaseBase): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] # 2. 
create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = {"metric_type": "L2", + "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, - default_search_params, limit, + default_search_params, limit, default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, @@ -5537,16 +5830,22 @@ class TestSearchDiskann(TestcaseBase): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] # 2. create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = {"metric_type": "L2", + "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, - default_search_params, limit, + default_search_params, limit, default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, @@ -5567,16 +5866,22 @@ class TestSearchDiskann(TestcaseBase): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] + self.init_collection_general( + prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] # 2. 
create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params ={"metric_type": "L2", "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = {"metric_type": "L2", + "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, - default_search_params, limit, + default_search_params, limit, default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, @@ -5598,13 +5903,18 @@ class TestSearchDiskann(TestcaseBase): primary_field=ct.default_string_field_name, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() search_list = 20 - default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = {"metric_type": "L2", + "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -5629,8 +5939,10 @@ class TestSearchDiskann(TestcaseBase): self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. 
create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() tmp_expr = f'{ct.default_int64_field_name} in {[0]}' @@ -5641,9 +5953,12 @@ class TestSearchDiskann(TestcaseBase): assert del_res.delete_count == half_nb collection_w.delete(tmp_expr) - default_search_params ={"metric_type": "L2", "params": {"search_list": 30}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = { + "metric_type": "L2", "params": {"search_list": 30}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -5670,14 +5985,18 @@ class TestSearchDiskann(TestcaseBase): self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. create index - default_index = {"index_type": "DISKANN", "metric_type": "COSINE", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index, index_name=index_name1) + default_index = {"index_type": "DISKANN", + "metric_type": "COSINE", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index, index_name=index_name1) if not enable_dynamic_field: index_params_one = {} - collection_w.create_index("float", index_params_one, index_name="a") + collection_w.create_index( + "float", index_params_one, index_name="a") index_param_two = {} - collection_w.create_index("varchar", index_param_two, index_name="b") - + collection_w.create_index( + "varchar", index_param_two, index_name="b") + collection_w.load() tmp_expr = f'{ct.default_int64_field_name} in {[0]}' @@ -5688,9 +6007,12 @@ class TestSearchDiskann(TestcaseBase): assert del_res.delete_count == half_nb collection_w.delete(tmp_expr) - default_search_params = {"metric_type": "COSINE", "params": {"search_list": 30}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = { + "metric_type": "COSINE", "params": {"search_list": 30}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, @@ -5702,7 +6024,7 @@ class TestSearchDiskann(TestcaseBase): "limit": default_limit, "_async": _async} ) - + @pytest.mark.tags(CaseLabel.L1) def test_search_with_scalar_field(self, dim, _async, enable_dynamic_field): """ @@ -5714,23 +6036,31 @@ class TestSearchDiskann(TestcaseBase): """ # 1. 
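The delete-then-search cases above issue two deletes: one that removes half the entities and one that re-targets id 0. What the surviving id set looks like, modeled on plain sets (nb and half_nb are illustrative):

    nb, half_nb = 10, 5
    ids = set(range(nb))
    ids -= set(range(half_nb))   # first delete: the lower half, count == half_nb
    ids.discard(0)               # second delete: id 0 again, already gone (no-op)
    assert ids == {5, 6, 7, 8, 9}

Searches after the deletes should only surface ids from this surviving set.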
initialize with data collection_w, _, _, ids = \ - self.init_collection_general(prefix, True, dim=dim, primary_field=ct.default_string_field_name, + self.init_collection_general(prefix, True, dim=dim, primary_field=ct.default_string_field_name, is_index=False, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. create index - default_index = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "IVF_SQ8", + "metric_type": "COSINE", "params": {"nlist": 64}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) index_params = {} if not enable_dynamic_field: - collection_w.create_index(ct.default_float_field_name, index_params=index_params) - collection_w.create_index(ct.default_int64_field_name, index_params=index_params) + collection_w.create_index( + ct.default_float_field_name, index_params=index_params) + collection_w.create_index( + ct.default_int64_field_name, index_params=index_params) else: - collection_w.create_index(ct.default_string_field_name, index_params=index_params) + collection_w.create_index( + ct.default_string_field_name, index_params=index_params) collection_w.load() default_expr = "int64 in [1, 2, 3, 4]" limit = 4 - default_search_params = {"metric_type": "COSINE", "params": {"nprobe": 64}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + default_search_params = { + "metric_type": "COSINE", "params": {"nprobe": 64}} + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] search_res = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, limit, default_expr, output_fields=output_fields, _async=_async, @@ -5756,13 +6086,17 @@ class TestSearchDiskann(TestcaseBase): enable_dynamic_field=enable_dynamic_field)[0:4] # 2. create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() search_params = {"metric_type": "L2", "params": {"search_list": limit}} - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, search_params, limit, default_search_exp, @@ -5790,13 +6124,18 @@ class TestSearchDiskann(TestcaseBase): dim=dim, is_index=False)[0:4] # 2. 
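Why limit = 4 in the scalar-field case above: the boolean filter admits exactly four keys, so the engine cannot return more than four hits regardless of the search parameters. The same bound on plain data:

    candidates = list(range(100))                  # stand-in entity ids
    admitted = [i for i in candidates if i in [1, 2, 3, 4]]
    limit = 4
    assert admitted == [1, 2, 3, 4] and len(admitted) <= limit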
create index - default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} - collection_w.create_index(ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", + "metric_type": "L2", "params": {}} + collection_w.create_index( + ct.default_float_vec_field_name, default_index) collection_w.load() - search_params = {"metric_type": "L2", "params": {"k": 200, "search_list": 201}} - search_vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] - output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] + search_params = {"metric_type": "L2", + "params": {"k": 200, "search_list": 201}} + search_vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] + output_fields = [default_int64_field_name, + default_float_field_name, default_string_field_name] collection_w.search(search_vectors[:default_nq], default_search_field, search_params, default_limit, default_search_exp, @@ -5900,12 +6239,14 @@ class TestCollectionRangeSearch(TestcaseBase): expected: range search successfully as normal search """ # 1. initialize with data - collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb=10)[0:5] + collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general( + prefix, True, nb=10)[0:5] # 2. get vectors that inserted into collection vectors = np.array(_vectors[0]).tolist() vectors = [vectors[i][-1] for i in range(default_nq)] # 3. range search with L2 - range_search_params = {"metric_type": "COSINE", "params": {"range_filter": 1}} + range_search_params = {"metric_type": "COSINE", + "params": {"range_filter": 1}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -5914,7 +6255,8 @@ class TestCollectionRangeSearch(TestcaseBase): "ids": insert_ids, "limit": default_limit}) # 4. range search with IP - range_search_params = {"metric_type": "IP", "params": {"range_filter": 1}} + range_search_params = {"metric_type": "IP", + "params": {"range_filter": 1}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -5930,7 +6272,8 @@ class TestCollectionRangeSearch(TestcaseBase): expected: search successfully with filtered limit(topK) """ # 1. initialize with data - collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb=10, is_index=False)[0:5] + collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general( + prefix, True, nb=10, is_index=False)[0:5] collection_w.create_index(field_name, {"metric_type": "L2"}) collection_w.load() # 2. get vectors that inserted into collection @@ -5962,12 +6305,14 @@ class TestCollectionRangeSearch(TestcaseBase): expected: search successfully as normal search """ # 1. initialize with data - collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb=10)[0:5] + collection_w, _vectors, _, insert_ids, time_stamp = self.init_collection_general( + prefix, True, nb=10)[0:5] # 2. get vectors that inserted into collection vectors = np.array(_vectors[0]).tolist() vectors = [vectors[i][-1] for i in range(default_nq)] # 3. 
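The range-search cases above and below all hang off one predicate: a hit is kept when its score lies between range_filter and radius, with the orientation flipping by metric (for L2, smaller distances are closer; for IP and COSINE, larger scores are closer). A sketch of that predicate under this assumption:

    def in_range(score, metric, radius, range_filter):
        if metric == "L2":                        # distance: range_filter <= d < radius
            return range_filter <= score < radius
        return radius < score <= range_filter     # similarity: radius < s <= range_filter

    assert in_range(500.0, "L2", radius=1000, range_filter=0)
    assert in_range(0.5, "COSINE", radius=0, range_filter=1)

Note that several cases deliberately place radius and range_filter at the top level of the dict instead of inside "params"; those bounds are then ignored and the search behaves as a normal search, which is exactly what the expected results assert.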
range search with L2 - range_search_params = {"metric_type": "COSINE", "radius": 0, "range_filter": 1} + range_search_params = {"metric_type": "COSINE", + "radius": 0, "range_filter": 1} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -5976,7 +6321,8 @@ class TestCollectionRangeSearch(TestcaseBase): "ids": insert_ids, "limit": default_limit}) # 4. range search with IP - range_search_params = {"metric_type": "IP", "radius": 1, "range_filter": 0} + range_search_params = {"metric_type": "IP", + "radius": 1, "range_filter": 0} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, @@ -6005,7 +6351,8 @@ class TestCollectionRangeSearch(TestcaseBase): vectors = np.array(insert_data[0]).tolist() vectors = [vectors[i][-1] for i in range(default_nq)] log.info(vectors) - range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} + range_search_params = {"metric_type": "COSINE", "params": { + "nprobe": 10, "radius": 0, "range_filter": 1000}} search_res = collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, _async=_async, @@ -6031,7 +6378,8 @@ class TestCollectionRangeSearch(TestcaseBase): """ # 1. create a collection, insert data and flush nb = 10 - collection_w = self.init_collection_general(prefix, True, nb, dim=dim, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, True, nb, dim=dim, is_index=False)[0] # 2. insert data and flush again for two segments data = cf.gen_default_dataframe_data(nb=nb, dim=dim, start=nb) @@ -6039,11 +6387,13 @@ class TestCollectionRangeSearch(TestcaseBase): collection_w.flush() # 3. create index and load - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 4. get inserted original data - inserted_vectors = collection_w.query(expr="int64 >= 0", output_fields=[ct.default_float_vec_field_name]) + inserted_vectors = collection_w.query(expr="int64 >= 0", output_fields=[ + ct.default_float_vec_field_name]) original_vectors = [] for single in inserted_vectors[0]: single_vector = single[ct.default_float_vec_field_name] @@ -6060,7 +6410,8 @@ class TestCollectionRangeSearch(TestcaseBase): distances_index_max = map(distances.index, distances_max) # 6. Search - range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1}} + range_search_params = {"metric_type": "COSINE", "params": { + "nprobe": 10, "radius": 0, "range_filter": 1}} collection_w.search(vectors, default_search_field, range_search_params, limit, check_task=CheckTasks.check_search_results, @@ -6078,11 +6429,13 @@ class TestCollectionRangeSearch(TestcaseBase): expected: search successfully with 0 results """ # 1. initialize without data - collection_w = self.init_collection_general(prefix, True, dim=default_dim)[0] + collection_w = self.init_collection_general( + prefix, True, dim=default_dim)[0] # 2. 
search collection without data log.info("test_range_search_with_empty_vectors: Range searching collection %s " "using empty vector" % collection_w.name) - range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 0}} + range_search_params = {"metric_type": "COSINE", "params": { + "nprobe": 10, "radius": 0, "range_filter": 0}} collection_w.search([], default_search_field, range_search_params, default_limit, default_search_exp, _async=_async, check_task=CheckTasks.check_search_results, @@ -6109,7 +6462,8 @@ class TestCollectionRangeSearch(TestcaseBase): dim=dim)[0:4] # 2. search all the partitions before partition deletion vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] - log.info("test_range_search_before_after_delete: searching before deleting partitions") + log.info( + "test_range_search_before_after_delete: searching before deleting partitions") range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} collection_w.search(vectors[:nq], default_search_field, @@ -6129,10 +6483,12 @@ class TestCollectionRangeSearch(TestcaseBase): collection_w.release(par[partition_num].name) collection_w.drop_partition(par[partition_num].name) log.info("test_range_search_before_after_delete: deleted a partition") - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 4. search non-deleted part after delete partitions - log.info("test_range_search_before_after_delete: searching after deleting partitions") + log.info( + "test_range_search_before_after_delete: searching after deleting partitions") range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} collection_w.search(vectors[:nq], default_search_field, @@ -6158,17 +6514,21 @@ class TestCollectionRangeSearch(TestcaseBase): collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, default_nb, 1, auto_id=auto_id, dim=default_dim, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. release collection - log.info("test_range_search_collection_after_release_load: releasing collection %s" % collection_w.name) + log.info("test_range_search_collection_after_release_load: releasing collection %s" % + collection_w.name) collection_w.release() - log.info("test_range_search_collection_after_release_load: released collection %s" % collection_w.name) + log.info("test_range_search_collection_after_release_load: released collection %s" % + collection_w.name) # 3. 
Search the pre-released collection after load - log.info("test_range_search_collection_after_release_load: loading collection %s" % collection_w.name) + log.info("test_range_search_collection_after_release_load: loading collection %s" % + collection_w.name) collection_w.load() - log.info("test_range_search_collection_after_release_load: searching after load") - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + log.info( + "test_range_search_collection_after_release_load: searching after load") + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, @@ -6189,17 +6549,21 @@ class TestCollectionRangeSearch(TestcaseBase): expected: search success with limit(topK) """ # 1. initialize with data - collection_w = self.init_collection_general(prefix, dim=dim, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general( + prefix, dim=dim, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data - insert_ids = cf.insert_data(collection_w, default_nb, dim=dim, enable_dynamic_field=enable_dynamic_field)[3] + insert_ids = cf.insert_data( + collection_w, default_nb, dim=dim, enable_dynamic_field=enable_dynamic_field)[3] # 3. load data - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 4. flush and load collection_w.num_entities collection_w.load() # 5. search - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} collection_w.search(vectors[:default_nq], default_search_field, @@ -6226,8 +6590,7 @@ class TestCollectionRangeSearch(TestcaseBase): nb_old = 500 collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb_old, dim=dim, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. search for original data after load vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] range_search_params = {"metric_type": "COSINE", "params": {"radius": 0, @@ -6271,17 +6634,20 @@ class TestCollectionRangeSearch(TestcaseBase): """ # 1. connect, create collection and insert data self._connect() - collection_w = self.init_collection_general(prefix, False, dim=dim, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, False, dim=dim, is_index=False)[0] dataframe = cf.gen_default_dataframe_data(dim=dim, start=-1500) collection_w.insert(dataframe) # 2. create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "L2", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) # 3. 
load and range search collection_w.load() - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] range_search_params = {"metric_type": "L2", "params": {"radius": 1000, "range_filter": 0}} collection_w.search(vectors[:default_nq], default_search_field, @@ -6304,19 +6670,24 @@ class TestCollectionRangeSearch(TestcaseBase): self._connect() # 1. create collection name = cf.gen_unique_str(prefix) - collection_w = self.init_collection_wrap(name=name, shards_num=shards_num) + collection_w = self.init_collection_wrap( + name=name, shards_num=shards_num) # 2. rename collection new_collection_name = cf.gen_unique_str(prefix + "new") - self.utility_wrap.rename_collection(collection_w.name, new_collection_name) - collection_w = self.init_collection_wrap(name=new_collection_name, shards_num=shards_num) + self.utility_wrap.rename_collection( + collection_w.name, new_collection_name) + collection_w = self.init_collection_wrap( + name=new_collection_name, shards_num=shards_num) # 3. insert dataframe = cf.gen_default_dataframe_data() collection_w.insert(dataframe) # 4. create index and load - collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index( + ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 5. range search - vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] + for _ in range(default_nq)] range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} collection_w.search(vectors[:default_nq], default_search_field, @@ -6341,8 +6712,7 @@ class TestCollectionRangeSearch(TestcaseBase): collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, 5000, partition_num=1, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:5] + enable_dynamic_field=enable_dynamic_field)[0:5] # 2. create index and load if params.get("m"): if (dim % params["m"]) != 0: @@ -6350,12 +6720,14 @@ class TestCollectionRangeSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, "params": params, "metric_type": "L2"} + default_index = {"index_type": index, + "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
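The parameter fixup above (params["m"] = dim // 4 when dim % m != 0) exists because product-quantization indexes require the vector dimension to split evenly into m sub-vectors. The same adjustment standalone:

    dim = 64
    params = {"m": 6}                    # 64 % 6 != 0, so m must be coerced
    if params.get("m") and dim % params["m"] != 0:
        params["m"] = dim // 4           # 16 divides 64 evenly
    assert dim % params["m"] == 0

dim // 4 is a safe fallback only because the suite sticks to dimensions divisible by 4.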
range search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: search_param["params"]["radius"] = 1000 search_param["params"]["range_filter"] = 0 @@ -6391,14 +6763,18 @@ class TestCollectionRangeSearch(TestcaseBase): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info("test_range_search_after_index_different_metric_type: Creating index-%s" % index) - default_index = {"index_type": index, "params": params, "metric_type": "IP"} + log.info( + "test_range_search_after_index_different_metric_type: Creating index-%s" % index) + default_index = {"index_type": index, + "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) - log.info("test_range_search_after_index_different_metric_type: Created index-%s" % index) + log.info( + "test_range_search_after_index_different_metric_type: Created index-%s" % index) collection_w.load() # 3. search search_params = cf.gen_search_param(index, "IP") - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] for search_param in search_params: search_param["params"]["radius"] = 0 search_param["params"]["range_filter"] = 1000 @@ -6427,18 +6803,21 @@ class TestCollectionRangeSearch(TestcaseBase): is_index=False)[0:5] # 2. create index - default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", + "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search in one partition - log.info("test_range_search_index_one_partition: searching (1000 entities) through one partition") + log.info( + "test_range_search_index_one_partition: searching (1000 entities) through one partition") limit = 1000 par = collection_w.partitions if limit > par[1].num_entities: limit_check = par[1].num_entities else: limit_check = limit - range_search_params = {"metric_type": "L2", "params": {"radius": 1000, "range_filter": 0}} + range_search_params = {"metric_type": "L2", + "params": {"radius": 1000, "range_filter": 0}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, limit, default_search_exp, [par[1].name], _async=_async, @@ -6463,9 +6842,10 @@ class TestCollectionRangeSearch(TestcaseBase): dim=dim, is_index=False, is_flush=is_flush)[ - 0:5] + 0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance @@ -6473,7 +6853,8 @@ class TestCollectionRangeSearch(TestcaseBase): distance_0 = cf.jaccard(query_raw_vector[0], binary_raw_vector[0]) distance_1 = cf.jaccard(query_raw_vector[0], binary_raw_vector[1]) # 4. 
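The loops above graft the same range bounds onto every index-specific search param returned by cf.gen_search_param. The same mutation on a plain list of dicts (the param shapes are illustrative):

    search_params = [{"metric_type": "L2", "params": {"nprobe": 10}},
                     {"metric_type": "L2", "params": {"ef": 64}}]
    for sp in search_params:
        sp["params"].update(radius=1000, range_filter=0)
    assert all(sp["params"]["radius"] == 1000 for sp in search_params)

This keeps each index's own tuning knobs (nprobe, ef, and so on) intact while the range window stays constant across the parameter sweep.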
search and compare the distance - search_params = {"metric_type": "JACCARD", "params": {"radius": 1000, "range_filter": 0}} + search_params = {"metric_type": "JACCARD", + "params": {"radius": 1000, "range_filter": 0}} res = collection_w.search(binary_vectors[:nq], "binary_vector", search_params, default_limit, "int64 >= 0", _async=_async, @@ -6485,7 +6866,8 @@ class TestCollectionRangeSearch(TestcaseBase): if _async: res.done() res = res.result() - assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon + assert abs(res[0].distances[0] - + min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"]) @@ -6501,13 +6883,16 @@ class TestCollectionRangeSearch(TestcaseBase): dim=default_dim, is_index=False,)[0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance - query_raw_vector, binary_vectors = cf.gen_binary_vectors(3000, default_dim) + query_raw_vector, binary_vectors = cf.gen_binary_vectors( + 3000, default_dim) # 4. range search - search_params = {"metric_type": "JACCARD", "params": {"radius": -1, "range_filter": -10}} + search_params = {"metric_type": "JACCARD", + "params": {"radius": -1, "range_filter": -10}} collection_w.search(binary_vectors[:default_nq], "binary_vector", search_params, default_limit, check_task=CheckTasks.check_search_results, @@ -6540,7 +6925,8 @@ class TestCollectionRangeSearch(TestcaseBase): is_index=False, is_flush=is_flush)[0:4] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "HAMMING"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "HAMMING"} collection_w.create_index("binary_vector", default_index) # 3. compute the distance collection_w.load() @@ -6548,7 +6934,8 @@ class TestCollectionRangeSearch(TestcaseBase): distance_0 = cf.hamming(query_raw_vector[0], binary_raw_vector[0]) distance_1 = cf.hamming(query_raw_vector[0], binary_raw_vector[1]) # 4. search and compare the distance - search_params = {"metric_type": "HAMMING", "params": {"radius": 1000, "range_filter": 0}} + search_params = {"metric_type": "HAMMING", + "params": {"radius": 1000, "range_filter": 0}} res = collection_w.search(binary_vectors[:nq], "binary_vector", search_params, default_limit, "int64 >= 0", _async=_async, @@ -6560,7 +6947,8 @@ class TestCollectionRangeSearch(TestcaseBase): if _async: res.done() res = res.result() - assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon + assert abs(res[0].distances[0] - + min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT"]) @@ -6576,11 +6964,13 @@ class TestCollectionRangeSearch(TestcaseBase): dim=default_dim, is_index=False,)[0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "HAMMING"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "HAMMING"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance - query_raw_vector, binary_vectors = cf.gen_binary_vectors(3000, default_dim) + query_raw_vector, binary_vectors = cf.gen_binary_vectors( + 3000, default_dim) # 4. 
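The binary-metric assertions above and below compare the best returned distance against the smaller of two exact reference distances. Under the usual definition (an assumption about what cf.jaccard computes: 1 minus intersection over union of the bit sets), that check in plain Python:

    def jaccard(a, b):
        inter = sum(x & y for x, y in zip(a, b))
        union = sum(x | y for x, y in zip(a, b))
        return 1.0 - inter / union

    epsilon = 1e-6
    d0 = jaccard([1, 0, 1, 1], [1, 1, 1, 0])   # 0.5
    d1 = jaccard([1, 0, 1, 1], [1, 0, 1, 1])   # 0.0, an exact match
    top1 = d1                                  # stand-in for res[0].distances[0]
    assert abs(top1 - min(d0, d1)) <= epsilon

The HAMMING variant follows the same pattern with a bit-difference count in place of jaccard.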
range search search_params = {"metric_type": "HAMMING", "params": {"nprobe": 10, "radius": -1, "range_filter": -10}} @@ -6609,7 +6999,8 @@ class TestCollectionRangeSearch(TestcaseBase): is_flush=is_flush)[0:4] log.info("auto_id= %s, _async= %s" % (auto_id, _async)) # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "TANIMOTO"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "TANIMOTO"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance @@ -6645,7 +7036,8 @@ class TestCollectionRangeSearch(TestcaseBase): if _async: res.done() res = res.result() - assert abs(res[0].distances[0] - min(distance_0, distance_1)) <= epsilon + assert abs(res[0].distances[0] - + min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("tanimoto obsolete") @@ -6662,13 +7054,16 @@ class TestCollectionRangeSearch(TestcaseBase): dim=default_dim, is_index=False,)[0:5] # 2. create index - default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "JACCARD"} + default_index = {"index_type": index, "params": { + "nlist": 128}, "metric_type": "JACCARD"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. compute the distance - query_raw_vector, binary_vectors = cf.gen_binary_vectors(3000, default_dim) + query_raw_vector, binary_vectors = cf.gen_binary_vectors( + 3000, default_dim) # 4. range search - search_params = {"metric_type": "JACCARD", "params": {"radius": -1, "range_filter": -10}} + search_params = {"metric_type": "JACCARD", + "params": {"radius": -1, "range_filter": -10}} collection_w.search(binary_vectors[:default_nq], "binary_vector", search_params, default_limit, check_task=CheckTasks.check_search_results, @@ -6684,15 +7079,19 @@ class TestCollectionRangeSearch(TestcaseBase): expected: search successfully with limit(topK) """ # 1. initialize a collection without data - collection_w = self.init_collection_general(prefix, is_binary=True, auto_id=auto_id, is_index=False)[0] + collection_w = self.init_collection_general( + prefix, is_binary=True, auto_id=auto_id, is_index=False)[0] # 2. insert data - insert_ids = cf.insert_data(collection_w, default_nb, is_binary=True, auto_id=auto_id)[3] + insert_ids = cf.insert_data( + collection_w, default_nb, is_binary=True, auto_id=auto_id)[3] # 3. load data - index_params = {"index_type": "BIN_FLAT", "params": {"nlist": 128}, "metric_type": metrics} + index_params = {"index_type": "BIN_FLAT", "params": { + "nlist": 128}, "metric_type": metrics} collection_w.create_index("binary_vector", index_params) collection_w.load() # 4. search - log.info("test_range_search_binary_without_flush: searching collection %s" % collection_w.name) + log.info("test_range_search_binary_without_flush: searching collection %s" % + collection_w.name) binary_vectors = cf.gen_binary_vectors(default_nq, default_dim)[1] search_params = {"metric_type": metrics, "params": {"radius": 1000, "range_filter": 0}} @@ -6717,8 +7116,7 @@ class TestCollectionRangeSearch(TestcaseBase): collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb, dim=dim, is_index=False, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -6735,13 +7133,16 @@ class TestCollectionRangeSearch(TestcaseBase): filter_ids.append(_id) # 2. 
create index - index_param = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", + "metric_type": "L2", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with expression - log.info("test_range_search_with_expression: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + log.info( + "test_range_search_with_expression: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] + for _ in range(default_nq)] range_search_params = {"metric_type": "L2", "params": {"radius": 1000, "range_filter": 0}} search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, @@ -6771,10 +7172,10 @@ class TestCollectionRangeSearch(TestcaseBase): # 1. initialize with data collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id, - enable_dynamic_field= - enable_dynamic_field)[0:4] + enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search - log.info("test_range_search_with_output_field: Searching collection %s" % collection_w.name) + log.info("test_range_search_with_output_field: Searching collection %s" % + collection_w.name) range_search_params = {"metric_type": "COSINE", "params": {"radius": 0, "range_filter": 1000}} res = collection_w.search(vectors[:default_nq], default_search_field, @@ -6807,7 +7208,8 @@ class TestCollectionRangeSearch(TestcaseBase): dim=dim)[0:5] def search(collection_w): - vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] + vectors = [[random.random() for _ in range(dim)] + for _ in range(nq)] range_search_params = {"metric_type": "COSINE", "params": {"radius": 0, "range_filter": 1000}} collection_w.search(vectors[:nq], default_search_field, @@ -6820,7 +7222,8 @@ class TestCollectionRangeSearch(TestcaseBase): "_async": _async}) # 2. search with multi-processes - log.info("test_search_concurrent_multi_threads: searching with %s processes" % threads_num) + log.info( + "test_search_concurrent_multi_threads: searching with %s processes" % threads_num) for i in range(threads_num): t = threading.Thread(target=search, args=(collection_w,)) threads.append(t) @@ -6844,7 +7247,8 @@ class TestCollectionRangeSearch(TestcaseBase): # 1. initialize with data collection_w = self.init_collection_general(prefix, True, nb=tmp_nb)[0] # 2. search - log.info("test_search_round_decimal: Searching collection %s" % collection_w.name) + log.info("test_search_round_decimal: Searching collection %s" % + collection_w.name) range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}} res = collection_w.search(vectors[:tmp_nq], default_search_field, @@ -6860,7 +7264,8 @@ class TestCollectionRangeSearch(TestcaseBase): dis_actual = res_round[0][i].distance # log.debug(f'actual: {dis_actual}, expect: {dis_expect}') # abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) - assert math.isclose(dis_actual, dis_expect, rel_tol=0, abs_tol=abs_tol) + assert math.isclose(dis_actual, dis_expect, + rel_tol=0, abs_tol=abs_tol) @pytest.mark.tags(CaseLabel.L2) def test_range_search_with_expression_large(self, dim): @@ -6876,13 +7281,15 @@ class TestCollectionRangeSearch(TestcaseBase): is_index=False)[0:4] # 2. 
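The tolerance in test_search_round_decimal above deserves one line of arithmetic: when a distance is rounded to r decimal places, the rounded and raw values can differ by at most half of the last kept digit, so abs_tol needs to be at least 0.5 * 10**-r. A worked check (r and the distance are illustrative):

    import math

    r = 3
    abs_tol = 0.5 * 10 ** -r
    dis_actual = 0.123449
    dis_expect = round(dis_actual, r)            # 0.123
    assert math.isclose(dis_actual, dis_expect, rel_tol=0, abs_tol=abs_tol)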
@@ -6876,13 +7281,15 @@ class TestCollectionRangeSearch(TestcaseBase):
                                                     is_index=False)[0:4]
         # 2. create index
-        index_param = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}}
+        index_param = {"index_type": "IVF_FLAT",
+                       "metric_type": "L2", "params": {"nlist": 100}}
         collection_w.create_index("float_vector", index_param)
         collection_w.load()
         # 3. search with expression
         expression = f"0 < {default_int64_field_name} < 5001"
-        log.info("test_search_with_expression: searching with expression: %s" % expression)
+        log.info(
+            "test_search_with_expression: searching with expression: %s" % expression)
         nums = 5000
         vectors = [[random.random() for _ in range(dim)] for _ in range(nums)]
@@ -6926,7 +7333,8 @@ class TestCollectionRangeSearch(TestcaseBase):
                                          })
         kwargs = {}
-        consistency_level = kwargs.get("consistency_level", CONSISTENCY_BOUNDED)
+        consistency_level = kwargs.get(
+            "consistency_level", CONSISTENCY_BOUNDED)
         kwargs.update({"consistency_level": consistency_level})
         nb_new = 400
@@ -7002,7 +7410,8 @@ class TestCollectionRangeSearch(TestcaseBase):
                                                                   dim=dim)[0:4]
         # 2. search for original data after load
         vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
-        range_search_params = {"metric_type": "COSINE", "params": {"nprobe": 10, "radius": 0, "range_filter": 1000}}
+        range_search_params = {"metric_type": "COSINE", "params": {
+            "nprobe": 10, "radius": 0, "range_filter": 1000}}
         collection_w.search(vectors[:nq], default_search_field,
                             range_search_params, limit,
                             default_search_exp, _async=_async,
@@ -7017,7 +7426,8 @@ class TestCollectionRangeSearch(TestcaseBase):
                                                      insert_offset=nb_old)
         insert_ids.extend(insert_ids_new)
         kwargs = {}
-        consistency_level = kwargs.get("consistency_level", CONSISTENCY_EVENTUALLY)
+        consistency_level = kwargs.get(
+            "consistency_level", CONSISTENCY_EVENTUALLY)
         kwargs.update({"consistency_level": consistency_level})
         collection_w.search(vectors[:nq], default_search_field,
                             range_search_params, limit,
@@ -7053,7 +7463,8 @@ class TestCollectionRangeSearch(TestcaseBase):
                                          "_async": _async})
         kwargs = {}
-        consistency_level = kwargs.get("consistency_level", CONSISTENCY_SESSION)
+        consistency_level = kwargs.get(
+            "consistency_level", CONSISTENCY_SESSION)
         kwargs.update({"consistency_level": consistency_level})
         nb_new = 400
@@ -7088,7 +7499,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # delete data
@@ -7124,7 +7536,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # delete data
@@ -7160,7 +7573,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # delete data
@@ -7196,7 +7610,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # delete data
@@ -7232,7 +7647,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # delete data
@@ -7268,7 +7684,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load
@@ -7280,7 +7697,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         partition_w1.release()
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.err_res,
                             check_items={ct.err_code: 1, ct.err_msg: 'not loaded'})
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
@@ -7305,7 +7723,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load
@@ -7342,7 +7761,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load
@@ -7378,7 +7798,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
@@ -7413,7 +7834,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
@@ -7425,7 +7847,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         collection_w.load()
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.check_search_results,
                             check_items={"nq": 1, "limit": 100})
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
@@ -7450,7 +7873,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
@@ -7462,7 +7886,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         collection_w.delete(f"int64 in {delete_ids}")
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.err_res,
                             check_items={ct.err_code: 1, ct.err_msg: 'not loaded'})
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
@@ -7541,7 +7966,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         partition_w1.load()
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 300,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.err_res,
                             check_items={ct.err_code: 1, ct.err_msg: 'not loaded'})
         collection_w.search(vectors[:1], field_name, default_search_params, 300,
@@ -7724,7 +8150,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # flush
@@ -7759,7 +8186,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # flush
@@ -7794,7 +8222,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # flush
@@ -7828,7 +8257,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # flush
@@ -7839,7 +8269,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         partition_w2.drop()
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.err_res,
                             check_items={ct.err_code: 1, ct.err_msg: 'not found'})
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
@@ -7864,7 +8295,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # flush
@@ -7900,7 +8332,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load
@@ -7938,7 +8371,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load
@@ -7949,7 +8383,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         collection_w.release()
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.err_res,
                             check_items={ct.err_code: 1, ct.err_msg: 'not loaded'})
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
@@ -7974,7 +8409,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load
@@ -8009,7 +8445,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
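# Illustrative sketch of the invariant the load/release matrix above keeps
# asserting: naming a partition that is not currently loaded fails with a
# "not loaded" error (err_code 1 in these cases), while restricting the
# search to the loaded partition succeeds. Assumptions (hypothetical, not
# from this patch): a Milvus at localhost:19530 and a collection
# "load_release_demo" with exactly two partitions and a 128-d "float_vector"
# field.
from pymilvus import Collection, MilvusException, connections

connections.connect(host="localhost", port="19530")
coll = Collection("load_release_demo")
p1, p2 = coll.partitions   # unpacking assumes exactly two partitions
p1.load()                  # only the first partition is in memory
p2.release()

try:
    coll.search(data=[[0.0] * 128], anns_field="float_vector",
                param={"metric_type": "L2", "params": {"nprobe": 10}},
                limit=10, partition_names=[p1.name, p2.name])
except MilvusException as e:
    print(e.code, e.message)   # expected: a "not loaded" error

# Searching only the loaded partition returns results normally.
coll.search(data=[[0.0] * 128], anns_field="float_vector",
            param={"metric_type": "L2", "params": {"nprobe": 10}},
            limit=10, partition_names=[p1.name])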
@@ -8044,7 +8481,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
@@ -8055,7 +8493,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         collection_w.flush()
         # search on collection, partition1, partition2
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
-                            partition_names=[partition_w1.name, partition_w2.name],
+                            partition_names=[
+                                partition_w1.name, partition_w2.name],
                             check_task=CheckTasks.err_res,
                             check_items={ct.err_code: 1, ct.err_msg: 'not loaded'})
         collection_w.search(vectors[:1], field_name, default_search_params, 200,
@@ -8080,7 +8519,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
@@ -8114,7 +8554,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # insert data
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load && release
@@ -8146,7 +8587,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # init the collection
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load and release
@@ -8176,7 +8618,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # init the collection
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load and release
@@ -8200,7 +8643,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # init the collection
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load and release
@@ -8221,7 +8665,8 @@ class TestCollectionLoadOperation(TestcaseBase):
         expected: No exception
         """
         # init the collection
-        collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, 200, partition_num=1, is_index=False)[0]
         partition_w1, partition_w2 = collection_w.partitions
         collection_w.create_index(default_search_field, default_index_params)
         # load and release
@@ -8284,9 +8729,11 @@ class TestCollectionSearchJSON(TestcaseBase):
         # 1. initialize with data
         nq = 1
         dim = 128
-        collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, dim=dim)[0:5]
+        collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(
+            prefix, True, dim=dim)[0:5]
         # 2. search before insert time_stamp
-        log.info("test_search_json_expression_object: searching collection %s" % collection_w.name)
+        log.info("test_search_json_expression_object: searching collection %s" %
+                 collection_w.name)
         vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
         # 3. search after insert time_stamp
         json_search_exp = "json_field > 0"
@@ -8332,7 +8779,8 @@ class TestCollectionSearchJSON(TestcaseBase):
         expected: search successfully
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]
         # 2. insert data
         array = []
@@ -8349,8 +8797,10 @@ class TestCollectionSearchJSON(TestcaseBase):
         # 2. search
         collection_w.load()
-        log.info("test_search_with_output_field_json_contains: Searching collection %s" % collection_w.name)
-        expressions = ["json_contains(json_field['list'], 100)", "JSON_CONTAINS(json_field['list'], 100)"]
+        log.info("test_search_with_output_field_json_contains: Searching collection %s" %
+                 collection_w.name)
+        expressions = [
+            "json_contains(json_field['list'], 100)", "JSON_CONTAINS(json_field['list'], 100)"]
         for expression in expressions:
             collection_w.search(vectors[:default_nq], default_search_field,
                                 default_search_params, default_limit, expression,
@@ -8366,7 +8816,8 @@ class TestCollectionSearchJSON(TestcaseBase):
         expected: search successfully
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, auto_id=auto_id, enable_dynamic_field=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, auto_id=auto_id, enable_dynamic_field=True)[0]
         # 2. insert data
         limit = 100
@@ -8384,8 +8835,10 @@ class TestCollectionSearchJSON(TestcaseBase):
         # 2. search
         collection_w.load()
-        log.info("test_search_with_output_field_json_contains: Searching collection %s" % collection_w.name)
-        expressions = ["json_contains(json_field, 100)", "JSON_CONTAINS(json_field, 100)"]
+        log.info("test_search_with_output_field_json_contains: Searching collection %s" %
+                 collection_w.name)
+        expressions = [
+            "json_contains(json_field, 100)", "JSON_CONTAINS(json_field, 100)"]
         for expression in expressions:
             collection_w.search(vectors[:default_nq], default_search_field,
                                 default_search_params, limit, expression,
@@ -8401,7 +8854,8 @@ class TestCollectionSearchJSON(TestcaseBase):
         expected: search successfully
         """
         # 1. initialize with data
-        collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]
+        collection_w = self.init_collection_general(
+            prefix, enable_dynamic_field=enable_dynamic_field)[0]
         # 2. insert data
         limit = 100
@@ -8419,7 +8873,8 @@ class TestCollectionSearchJSON(TestcaseBase):
         # 2. search
         collection_w.load()
-        log.info("test_search_with_output_field_json_contains: Searching collection %s" % collection_w.name)
+        log.info("test_search_with_output_field_json_contains: Searching collection %s" %
+                 collection_w.name)
         tar = 1000
         expressions = [f"json_contains(json_field['list'], '{tar}') && int64 > {tar - limit // 2}",
                        f"JSON_CONTAINS(json_field['list'], '{tar}') && int64 > {tar - limit // 2}"]
@@ -8444,7 +8899,8 @@ class TestSearchIterator(TestcaseBase):
         """
         # 1. initialize with data
         dim = 128
-        collection_w = self.init_collection_general(prefix, True, dim=dim, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, dim=dim, is_index=False)[0]
         collection_w.create_index(field_name, {"metric_type": "L2"})
         collection_w.load()
         # 2. search iterator
@@ -8468,7 +8924,8 @@ class TestSearchIterator(TestcaseBase):
         """
         # 1. initialize with data
         batch_size = 200
-        collection_w = self.init_collection_general(prefix, True, is_binary=True)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, is_binary=True)[0]
         # 2. search iterator
         _, binary_vectors = cf.gen_binary_vectors(2, ct.default_dim)
         collection_w.search_iterator(binary_vectors[:1], binary_field_name,
@@ -8488,7 +8945,8 @@ class TestSearchIterator(TestcaseBase):
         # 1. initialize with data
         batch_size = 100
         dim = 128
-        collection_w = self.init_collection_general(prefix, True, dim=dim, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, dim=dim, is_index=False)[0]
         collection_w.create_index(field_name, {"metric_type": metrics})
         collection_w.load()
         # 2. search iterator
@@ -8508,7 +8966,8 @@ class TestSearchIterator(TestcaseBase):
         """
         # 1. initialize with data
         batch_size = 100
-        collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, is_index=False)[0]
         collection_w.create_index(field_name, {"metric_type": "L2"})
         collection_w.load()
         # 2. search iterator
@@ -8528,7 +8987,8 @@ class TestSearchIterator(TestcaseBase):
         """
         # 1. initialize with data
         batch_size = 100
-        collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, is_index=False)[0]
         collection_w.create_index(field_name, {"metric_type": metrics})
         collection_w.load()
         # 2. search iterator
@@ -8547,7 +9007,8 @@ class TestSearchIterator(TestcaseBase):
         """
         # 1. initialize with data
         batch_size = 100
-        collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, is_index=False)[0]
         collection_w.create_index(field_name, {"metric_type": "L2"})
         collection_w.load()
         # 2. search iterator
@@ -8571,8 +9032,10 @@ class TestSearchIterator(TestcaseBase):
         """
         # 1. initialize with data
         batch_size = 100
-        collection_w = self.init_collection_general(prefix, True, is_index=False)[0]
-        default_index = {"index_type": index, "params": params, "metric_type": metrics}
+        collection_w = self.init_collection_general(
+            prefix, True, is_index=False)[0]
+        default_index = {"index_type": index,
+                         "params": params, "metric_type": metrics}
         collection_w.create_index(field_name, default_index)
         collection_w.load()
         # 2. search iterator
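# Illustrative sketch of the paging loop behind the search_iterator cases
# above: next() returns at most batch_size hits per call, and an empty page
# signals exhaustion. Assumptions (hypothetical, not from this patch): a
# Milvus at localhost:19530 and a loaded collection "iterator_demo" with a
# 128-d "float_vector" field.
from pymilvus import Collection, connections

connections.connect(host="localhost", port="19530")
coll = Collection("iterator_demo")
coll.load()

iterator = coll.search_iterator(
    data=[[0.0] * 128], anns_field="float_vector",
    param={"metric_type": "L2", "params": {"nprobe": 10}},
    batch_size=100, limit=1000)
total = 0
while True:
    page = iterator.next()
    if len(page) == 0:
        iterator.close()   # release the server-side cursor
        break
    total += len(page)
print("retrieved %d hits" % total)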
@@ -8609,7 +9072,8 @@ class TestSearchIterator(TestcaseBase):
         # 1. initialize with data
         batch_size = 100
         dim = 128
-        collection_w = self.init_collection_general(prefix, True, dim=dim, is_index=False)[0]
+        collection_w = self.init_collection_general(
+            prefix, True, dim=dim, is_index=False)[0]
         collection_w.create_index(field_name, {"metric_type": "L2"})
         collection_w.load()
         # 2. search iterator
diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py
index e128030280..889b1ef20b 100644
--- a/tests/python_client/testcases/test_utility.py
+++ b/tests/python_client/testcases/test_utility.py
@@ -551,9 +551,8 @@ class TestUtilityParams(TestcaseBase):
         new_collection_name = cf.gen_unique_str(prefix)
         self.utility_wrap.rename_collection(old_collection_name, new_collection_name,
                                             check_task=CheckTasks.err_res,
-                                            check_items={"err_code": 1,
-                                                         "err_msg": "collection {} was not "
-                                                                    "loaded into memory)".format(collection_w.name)})
+                                            check_items={"err_code": 4,
+                                                         "err_msg": "collection not found"})
 
     @pytest.mark.tags(CaseLabel.L2)
     def test_rename_collection_new_invalid_type(self, get_invalid_type_collection_name):
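# Illustrative sketch of the expectation the test_utility.py hunk now
# encodes: renaming a collection that cannot be found surfaces error code 4
# with a "collection not found" reason, instead of the old generic code 1
# message. Assumptions (hypothetical, not from this patch): a Milvus at
# localhost:19530; the collection name is deliberately nonexistent.
from pymilvus import MilvusException, connections, utility

connections.connect(host="localhost", port="19530")
try:
    utility.rename_collection("nonexistent_collection", "new_name")
except MilvusException as e:
    assert e.code == 4
    assert "collection not found" in e.message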