feat: [Sparse Float Vector] added some integration tests (#31062)

add some integration tests for sparse float vector support

https://github.com/milvus-io/milvus/issues/29419

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
This commit is contained in:
Buqian Zheng 2024-04-10 19:57:18 +08:00 committed by GitHub
parent 25a1c9ecf0
commit 2fdf1a6e76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 809 additions and 47 deletions

View File

@ -1160,7 +1160,7 @@ func fillFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgstr
} }
if len(insertMsg.FieldsData) != requiredFieldsNum { if len(insertMsg.FieldsData) != requiredFieldsNum {
log.Warn("the number of fields is less than needed", log.Warn("the number of fields is not the same as needed",
zap.Int("fieldNum", len(insertMsg.FieldsData)), zap.Int("fieldNum", len(insertMsg.FieldsData)),
zap.Int("requiredFieldNum", requiredFieldsNum), zap.Int("requiredFieldNum", requiredFieldsNum),
zap.String("collection", schema.GetName())) zap.String("collection", schema.GetName()))

View File

@ -108,6 +108,8 @@ const (
DimKey = "dim" DimKey = "dim"
MaxLengthKey = "max_length" MaxLengthKey = "max_length"
MaxCapacityKey = "max_capacity" MaxCapacityKey = "max_capacity"
DropRatioBuildKey = "drop_ratio_build"
) )
// Collection properties key // Collection properties key

View File

@ -86,19 +86,23 @@ func (s *TestGetVectorSuite) run() {
IndexParams: nil, IndexParams: nil,
AutoID: false, AutoID: false,
} }
typeParams := []*commonpb.KeyValuePair{}
if !typeutil.IsSparseFloatVectorType(s.vecType) {
typeParams = []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: fmt.Sprintf("%d", dim),
},
}
}
fVec := &schemapb.FieldSchema{ fVec := &schemapb.FieldSchema{
FieldID: 101, FieldID: 101,
Name: vecFieldName, Name: vecFieldName,
IsPrimaryKey: false, IsPrimaryKey: false,
Description: "", Description: "",
DataType: s.vecType, DataType: s.vecType,
TypeParams: []*commonpb.KeyValuePair{ TypeParams: typeParams,
{ IndexParams: nil,
Key: common.DimKey,
Value: fmt.Sprintf("%d", dim),
},
},
IndexParams: nil,
} }
schema := integration.ConstructSchema(collection, dim, false, pk, fVec) schema := integration.ConstructSchema(collection, dim, false, pk, fVec)
marshaledSchema, err := proto.Marshal(schema) marshaledSchema, err := proto.Marshal(schema)
@ -126,6 +130,8 @@ func (s *TestGetVectorSuite) run() {
vecFieldData = integration.NewFloat16VectorFieldData(vecFieldName, NB, dim) vecFieldData = integration.NewFloat16VectorFieldData(vecFieldName, NB, dim)
// } else if s.vecType == schemapb.DataType_BFloat16Vector { // } else if s.vecType == schemapb.DataType_BFloat16Vector {
// vecFieldData = integration.NewBFloat16VectorFieldData(vecFieldName, NB, dim) // vecFieldData = integration.NewBFloat16VectorFieldData(vecFieldName, NB, dim)
} else if typeutil.IsSparseFloatVectorType(s.vecType) {
vecFieldData = integration.NewSparseFloatVectorFieldData(vecFieldName, NB)
} else { } else {
vecFieldData = integration.NewBinaryVectorFieldData(vecFieldName, NB, dim) vecFieldData = integration.NewBinaryVectorFieldData(vecFieldName, NB, dim)
} }
@ -193,7 +199,7 @@ func (s *TestGetVectorSuite) run() {
searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq) searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) s.Require().Equal(commonpb.ErrorCode_Success, searchResp.GetStatus().GetErrorCode())
result := searchResp.GetResults() result := searchResp.GetResults()
if s.pkType == schemapb.DataType_Int64 { if s.pkType == schemapb.DataType_Int64 {
@ -253,6 +259,21 @@ func (s *TestGetVectorSuite) run() {
// } // }
// } // }
} else if s.vecType == schemapb.DataType_BFloat16Vector { } else if s.vecType == schemapb.DataType_BFloat16Vector {
} else if s.vecType == schemapb.DataType_SparseFloatVector {
s.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetSparseFloatVector().GetContents(), nq*topk)
rawData := vecFieldData.GetVectors().GetSparseFloatVector().GetContents()
resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetSparseFloatVector().GetContents()
if s.pkType == schemapb.DataType_Int64 {
for i, id := range result.GetIds().GetIntId().GetData() {
s.Require().Equal(rawData[id], resData[i])
}
} else {
for i, idStr := range result.GetIds().GetStrId().GetData() {
id, err := strconv.Atoi(idStr)
s.Require().NoError(err)
s.Require().Equal(rawData[id], resData[i])
}
}
} else { } else {
s.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8) s.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8)
rawData := vecFieldData.GetVectors().GetBinaryVector() rawData := vecFieldData.GetVectors().GetBinaryVector()
@ -430,6 +451,46 @@ func (s *TestGetVectorSuite) TestGetVector_With_DB_Name() {
s.run() s.run()
} }
// TestGetVector_Sparse_SPARSE_INVERTED_INDEX exercises vector retrieval on a
// sparse float vector field indexed with SPARSE_INVERTED_INDEX, using the IP
// metric and an int64 primary key.
func (s *TestGetVectorSuite) TestGetVector_Sparse_SPARSE_INVERTED_INDEX() {
	s.indexType = integration.IndexSparseInvertedIndex
	s.metricType = metric.IP
	s.pkType = schemapb.DataType_Int64
	s.vecType = schemapb.DataType_SparseFloatVector
	s.nq = 10
	s.topK = 10
	s.run()
}
// TestGetVector_Sparse_SPARSE_INVERTED_INDEX_StrPK is the VarChar-primary-key
// variant of the SPARSE_INVERTED_INDEX retrieval test (IP metric).
func (s *TestGetVectorSuite) TestGetVector_Sparse_SPARSE_INVERTED_INDEX_StrPK() {
	s.indexType = integration.IndexSparseInvertedIndex
	s.metricType = metric.IP
	s.pkType = schemapb.DataType_VarChar
	s.vecType = schemapb.DataType_SparseFloatVector
	s.nq = 10
	s.topK = 10
	s.run()
}
// TestGetVector_Sparse_SPARSE_WAND exercises vector retrieval on a sparse
// float vector field indexed with SPARSE_WAND, using the IP metric and an
// int64 primary key.
func (s *TestGetVectorSuite) TestGetVector_Sparse_SPARSE_WAND() {
	s.indexType = integration.IndexSparseWand
	s.metricType = metric.IP
	s.pkType = schemapb.DataType_Int64
	s.vecType = schemapb.DataType_SparseFloatVector
	s.nq = 10
	s.topK = 10
	s.run()
}
// TestGetVector_Sparse_SPARSE_WAND_StrPK is the VarChar-primary-key variant
// of the SPARSE_WAND retrieval test (IP metric).
func (s *TestGetVectorSuite) TestGetVector_Sparse_SPARSE_WAND_StrPK() {
	s.indexType = integration.IndexSparseWand
	s.metricType = metric.IP
	s.pkType = schemapb.DataType_VarChar
	s.vecType = schemapb.DataType_SparseFloatVector
	s.nq = 10
	s.topK = 10
	s.run()
}
//func (s *TestGetVectorSuite) TestGetVector_DISKANN_L2() { //func (s *TestGetVectorSuite) TestGetVector_DISKANN_L2() {
// s.nq = 10 // s.nq = 10
// s.topK = 10 // s.topK = 10

View File

@ -40,9 +40,13 @@ import (
type HelloMilvusSuite struct { type HelloMilvusSuite struct {
integration.MiniClusterSuite integration.MiniClusterSuite
indexType string
metricType string
vecType schemapb.DataType
} }
func (s *HelloMilvusSuite) TestHelloMilvus() { func (s *HelloMilvusSuite) run() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
c := s.Cluster c := s.Cluster
@ -55,7 +59,7 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
collectionName := "TestHelloMilvus" + funcutil.GenRandomStr() collectionName := "TestHelloMilvus" + funcutil.GenRandomStr()
schema := integration.ConstructSchema(collectionName, dim, true) schema := integration.ConstructSchemaOfVecDataType(collectionName, dim, true, s.vecType)
marshaledSchema, err := proto.Marshal(schema) marshaledSchema, err := proto.Marshal(schema)
s.NoError(err) s.NoError(err)
@ -77,7 +81,12 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
s.Equal(showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) s.Equal(showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) var fVecColumn *schemapb.FieldData
if s.vecType == schemapb.DataType_SparseFloatVector {
fVecColumn = integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
} else {
fVecColumn = integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
}
hashKeys := integration.GenerateHashKeys(rowNum) hashKeys := integration.GenerateHashKeys(rowNum)
insertCheckReport := func() { insertCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
@ -131,9 +140,9 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
// create index // create index
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName, CollectionName: collectionName,
FieldName: integration.FloatVecField, FieldName: fVecColumn.FieldName,
IndexName: "_default", IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2), ExtraParams: integration.ConstructIndexParam(dim, s.indexType, s.metricType),
}) })
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success { if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason())) log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
@ -141,7 +150,7 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
s.NoError(err) s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode()) s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField) s.WaitForIndexBuilt(ctx, collectionName, fVecColumn.FieldName)
// load // load
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
@ -161,9 +170,9 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
topk := 10 topk := 10
roundDecimal := -1 roundDecimal := -1
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2) params := integration.GetSearchParams(s.indexType, s.metricType)
searchReq := integration.ConstructSearchRequest("", collectionName, expr, searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal) fVecColumn.FieldName, s.vecType, nil, s.metricType, params, nq, dim, topk, roundDecimal)
searchCheckReport := func() { searchCheckReport := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
@ -266,6 +275,27 @@ func (s *HelloMilvusSuite) TestHelloMilvus() {
log.Info("TestHelloMilvus succeed") log.Info("TestHelloMilvus succeed")
} }
// TestHelloMilvus_basic runs the end-to-end hello-milvus flow with a dense
// float vector field, IVF_FLAT index and L2 metric.
func (s *HelloMilvusSuite) TestHelloMilvus_basic() {
	s.vecType = schemapb.DataType_FloatVector
	s.indexType = integration.IndexFaissIvfFlat
	s.metricType = metric.L2
	s.run()
}
// TestHelloMilvus_sparse_basic runs the end-to-end flow with a sparse float
// vector field, SPARSE_INVERTED_INDEX and the IP metric.
func (s *HelloMilvusSuite) TestHelloMilvus_sparse_basic() {
	s.vecType = schemapb.DataType_SparseFloatVector
	s.indexType = integration.IndexSparseInvertedIndex
	s.metricType = metric.IP
	s.run()
}
// TestHelloMilvus_sparse_wand_basic runs the end-to-end flow with a sparse
// float vector field, SPARSE_WAND index and the IP metric.
func (s *HelloMilvusSuite) TestHelloMilvus_sparse_wand_basic() {
	s.vecType = schemapb.DataType_SparseFloatVector
	s.indexType = integration.IndexSparseWand
	s.metricType = metric.IP
	s.run()
}
func TestHelloMilvus(t *testing.T) { func TestHelloMilvus(t *testing.T) {
suite.Run(t, new(HelloMilvusSuite)) suite.Run(t, new(HelloMilvusSuite))
} }

View File

@ -42,6 +42,7 @@ func (s *HybridSearchSuite) TestHybridSearch() {
&schemapb.FieldSchema{Name: integration.Int64Field, DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, &schemapb.FieldSchema{Name: integration.Int64Field, DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
&schemapb.FieldSchema{Name: integration.FloatVecField, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}}, &schemapb.FieldSchema{Name: integration.FloatVecField, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
&schemapb.FieldSchema{Name: integration.BinVecField, DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}}, &schemapb.FieldSchema{Name: integration.BinVecField, DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
&schemapb.FieldSchema{Name: integration.SparseFloatVecField, DataType: schemapb.DataType_SparseFloatVector},
) )
marshaledSchema, err := proto.Marshal(schema) marshaledSchema, err := proto.Marshal(schema)
s.NoError(err) s.NoError(err)
@ -67,11 +68,12 @@ func (s *HybridSearchSuite) TestHybridSearch() {
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
bVecColumn := integration.NewBinaryVectorFieldData(integration.BinVecField, rowNum, dim) bVecColumn := integration.NewBinaryVectorFieldData(integration.BinVecField, rowNum, dim)
sparseVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
hashKeys := integration.GenerateHashKeys(rowNum) hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{ insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName, DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn}, FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, sparseVecColumn},
HashKeys: hashKeys, HashKeys: hashKeys,
NumRows: uint32(rowNum), NumRows: uint32(rowNum),
}) })
@ -143,6 +145,28 @@ func (s *HybridSearchSuite) TestHybridSearch() {
} }
s.WaitForIndexBuiltWithIndexName(ctx, collectionName, integration.BinVecField, "_default_binary") s.WaitForIndexBuiltWithIndexName(ctx, collectionName, integration.BinVecField, "_default_binary")
// load with index on partial vector fields
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
s.Error(merr.Error(loadStatus))
// create index for sparse float vector
createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.SparseFloatVecField,
IndexName: "_default_sparse",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexSparseInvertedIndex, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
if err != nil {
log.Warn("createIndexStatus fail reason", zap.Error(err))
}
s.WaitForIndexBuiltWithIndexName(ctx, collectionName, integration.SparseFloatVecField, "_default_sparse")
// load with index on all vector fields // load with index on all vector fields
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName, DbName: dbName,
@ -163,18 +187,21 @@ func (s *HybridSearchSuite) TestHybridSearch() {
fParams := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2) fParams := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
bParams := integration.GetSearchParams(integration.IndexFaissBinIvfFlat, metric.L2) bParams := integration.GetSearchParams(integration.IndexFaissBinIvfFlat, metric.L2)
sParams := integration.GetSearchParams(integration.IndexSparseInvertedIndex, metric.IP)
fSearchReq := integration.ConstructSearchRequest("", collectionName, expr, fSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, fParams, nq, dim, topk, roundDecimal) integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, fParams, nq, dim, topk, roundDecimal)
bSearchReq := integration.ConstructSearchRequest("", collectionName, expr, bSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.BinVecField, schemapb.DataType_BinaryVector, nil, metric.JACCARD, bParams, nq, dim, topk, roundDecimal) integration.BinVecField, schemapb.DataType_BinaryVector, nil, metric.JACCARD, bParams, nq, dim, topk, roundDecimal)
sSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.SparseFloatVecField, schemapb.DataType_SparseFloatVector, nil, metric.IP, sParams, nq, dim, topk, roundDecimal)
hSearchReq := &milvuspb.HybridSearchRequest{ hSearchReq := &milvuspb.HybridSearchRequest{
Base: nil, Base: nil,
DbName: dbName, DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
PartitionNames: nil, PartitionNames: nil,
Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq}, Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq, sSearchReq},
OutputFields: []string{integration.FloatVecField, integration.BinVecField}, OutputFields: []string{integration.FloatVecField, integration.BinVecField},
} }
@ -196,7 +223,7 @@ func (s *HybridSearchSuite) TestHybridSearch() {
// weighted rank hybrid search // weighted rank hybrid search
weightsParams := make(map[string][]float64) weightsParams := make(map[string][]float64)
weightsParams[proxy.WeightsParamsKey] = []float64{0.5, 0.2} weightsParams[proxy.WeightsParamsKey] = []float64{0.5, 0.2, 0.1}
b, err = json.Marshal(weightsParams) b, err = json.Marshal(weightsParams)
s.NoError(err) s.NoError(err)
@ -206,8 +233,8 @@ func (s *HybridSearchSuite) TestHybridSearch() {
DbName: dbName, DbName: dbName,
CollectionName: collectionName, CollectionName: collectionName,
PartitionNames: nil, PartitionNames: nil,
Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq}, Requests: []*milvuspb.SearchRequest{fSearchReq, bSearchReq, sSearchReq},
OutputFields: []string{integration.FloatVecField, integration.BinVecField}, OutputFields: []string{integration.FloatVecField, integration.BinVecField, integration.SparseFloatVecField},
} }
hSearchReq.RankParams = []*commonpb.KeyValuePair{ hSearchReq.RankParams = []*commonpb.KeyValuePair{
{Key: proxy.RankTypeKey, Value: "weighted"}, {Key: proxy.RankTypeKey, Value: "weighted"},

View File

@ -19,9 +19,13 @@ import (
type GetIndexStatisticsSuite struct { type GetIndexStatisticsSuite struct {
integration.MiniClusterSuite integration.MiniClusterSuite
indexType string
metricType string
vecType schemapb.DataType
} }
func (s *GetIndexStatisticsSuite) TestGetIndexStatistics() { func (s *GetIndexStatisticsSuite) run() {
c := s.Cluster c := s.Cluster
ctx, cancel := context.WithCancel(c.GetContext()) ctx, cancel := context.WithCancel(c.GetContext())
defer cancel() defer cancel()
@ -153,6 +157,13 @@ func (s *GetIndexStatisticsSuite) TestGetIndexStatistics() {
log.Info("TestGetIndexStatistics succeed") log.Info("TestGetIndexStatistics succeed")
} }
// TestGetIndexStatistics_float runs the index-statistics checks against a
// dense float vector field indexed with IVF_FLAT under the L2 metric.
func (s *GetIndexStatisticsSuite) TestGetIndexStatistics_float() {
	s.vecType = schemapb.DataType_FloatVector
	s.indexType = integration.IndexFaissIvfFlat
	s.metricType = metric.L2
	s.run()
}
func TestGetIndexStat(t *testing.T) { func TestGetIndexStat(t *testing.T) {
suite.Run(t, new(GetIndexStatisticsSuite)) suite.Run(t, new(GetIndexStatisticsSuite))
} }

View File

@ -38,6 +38,7 @@ type InsertSuite struct {
integration.MiniClusterSuite integration.MiniClusterSuite
} }
// insert request with duplicate field data should fail
func (s *InsertSuite) TestInsert() { func (s *InsertSuite) TestInsert() {
c := s.Cluster c := s.Cluster
ctx, cancel := context.WithCancel(c.GetContext()) ctx, cancel := context.WithCancel(c.GetContext())

View File

@ -0,0 +1,549 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sparse_test
import (
"context"
"encoding/binary"
"fmt"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/testutils"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/milvus-io/milvus/tests/integration"
)
type SparseTestSuite struct {
integration.MiniClusterSuite
}
// createCollection creates a collection named "TestSparse"+<random suffix>
// with an auto-ID int64 primary key and a sparse float vector field, then
// verifies the collection is visible via ShowCollections and returns its
// name. The sparse vector field deliberately carries no "dim" type param:
// TestSparse_should_not_speficy_dim shows that specifying one is rejected.
func (s *SparseTestSuite) createCollection(ctx context.Context, c *integration.MiniClusterV2, dbName string) string {
	collectionName := "TestSparse" + funcutil.GenRandomStr()

	pk := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         integration.Int64Field,
		IsPrimaryKey: true,
		Description:  "",
		DataType:     schemapb.DataType_Int64,
		TypeParams:   nil,
		IndexParams:  nil,
		AutoID:       true,
	}
	fVec := &schemapb.FieldSchema{
		FieldID:      101,
		Name:         integration.SparseFloatVecField,
		IsPrimaryKey: false,
		Description:  "",
		DataType:     schemapb.DataType_SparseFloatVector,
		TypeParams:   nil,
		IndexParams:  nil,
	}
	schema := &schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{pk, fVec},
	}
	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	s.NoError(err)
	// testify convention: expected value first, actual second.
	s.Equal(commonpb.ErrorCode_Success, createCollectionStatus.GetErrorCode())
	log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))

	showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, showCollectionsResp.GetStatus().GetErrorCode())
	log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))

	return collectionName
}
// TestSparse_should_not_speficy_dim verifies that creating a collection whose
// sparse float vector field declares a "dim" type param is rejected — the
// create call is expected to return a non-success status.
// NOTE(review): "speficy" in the method name is a typo for "specify"; the
// name is kept unchanged so existing -run filters keep matching.
func (s *SparseTestSuite) TestSparse_should_not_speficy_dim() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	c := s.Cluster

	// rowNum was declared here but never used; dropped.
	const dbName = ""

	collectionName := "TestSparse" + funcutil.GenRandomStr()

	pk := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         integration.Int64Field,
		IsPrimaryKey: true,
		Description:  "",
		DataType:     schemapb.DataType_Int64,
		TypeParams:   nil,
		IndexParams:  nil,
		AutoID:       true,
	}
	// Invalid on purpose: a sparse float vector field carrying a dim param.
	fVec := &schemapb.FieldSchema{
		FieldID:      101,
		Name:         integration.SparseFloatVecField,
		IsPrimaryKey: false,
		Description:  "",
		DataType:     schemapb.DataType_SparseFloatVector,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key:   common.DimKey,
				Value: "10",
			},
		},
		IndexParams: nil,
	}
	schema := &schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{pk, fVec},
	}
	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	s.NoError(err)
	// Creation must fail because of the dim param on the sparse field.
	s.NotEqual(commonpb.ErrorCode_Success, createCollectionStatus.GetErrorCode())
}
// TestSparse_invalid_insert checks that the proxy rejects malformed sparse
// float vector rows on insert. After one valid insert, the same column data
// is mutated into several invalid shapes, each of which must make the insert
// return a non-success status:
//   - a row with a negative column index,
//   - a row whose indices and values lengths differ,
//   - an empty row,
//   - a row whose column indices are not sorted ascending.
func (s *SparseTestSuite) TestSparse_invalid_insert() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	c := s.Cluster

	const (
		dbName = ""
		rowNum = 3000
	)

	collectionName := s.createCollection(ctx, c, dbName)

	// valid insert
	fVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	hashKeys := integration.GenerateHashKeys(rowNum)
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	// testify convention: expected value first, actual second.
	s.Equal(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())

	// Mutate the first row of the already-inserted column in place for each
	// negative case below.
	sparseVecs := fVecColumn.Field.(*schemapb.FieldData_Vectors).Vectors.GetSparseFloatVector()

	// negative column index is not allowed
	oldIdx := typeutil.SparseFloatRowIndexAt(sparseVecs.Contents[0], 0)
	var newIdx int32 = -10
	binary.LittleEndian.PutUint32(sparseVecs.Contents[0][0:], uint32(newIdx))
	insertResult, err = c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())
	// restore the original index before the next case
	binary.LittleEndian.PutUint32(sparseVecs.Contents[0][0:], oldIdx)

	// of each row, length of indices and data must equal
	sparseVecs.Contents[0] = append(sparseVecs.Contents[0], make([]byte, 4)...)
	insertResult, err = c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())
	sparseVecs.Contents[0] = sparseVecs.Contents[0][:len(sparseVecs.Contents[0])-4]

	// empty row is not allowed
	sparseVecs.Contents[0] = []byte{}
	insertResult, err = c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())

	// unsorted column index is not allowed
	sparseVecs.Contents[0] = make([]byte, 16)
	testutils.SparseFloatRowSetAt(sparseVecs.Contents[0], 0, 20, 0.1)
	testutils.SparseFloatRowSetAt(sparseVecs.Contents[0], 1, 10, 0.2)
	insertResult, err = c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())
}
// TestSparse_invalid_index_build inserts valid sparse vector data, flushes,
// and then verifies CreateIndex rejects invalid index configurations on a
// sparse float vector field:
//   - a dense-vector index type (IVF_PQ),
//   - a nonexistent index type,
//   - the L2 metric (only IP is expected to be accepted here),
//   - drop_ratio_build values outside the accepted range ("-0.1", "1.1").
//
// Fix: the two drop_ratio_build cases previously used metric.L2, which this
// very test expects to be rejected on its own — so the invalid drop ratio was
// never the thing being validated. They now use the valid IP metric so that
// only drop_ratio_build is wrong.
func (s *SparseTestSuite) TestSparse_invalid_index_build() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	c := s.Cluster

	const (
		dbName = ""
		rowNum = 3000
	)

	collectionName := s.createCollection(ctx, c, dbName)

	// valid insert
	fVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	hashKeys := integration.GenerateHashKeys(rowNum)
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	// testify convention: expected value first, actual second.
	s.Equal(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())

	// flush
	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	ids := segmentIDs.GetData()
	s.Require().NotEmpty(segmentIDs)
	s.Require().True(has)
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	for _, segment := range segments {
		log.Info("ShowSegments result", zap.String("segment", segment.String()))
	}
	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

	// unsupported index type
	indexParams := []*commonpb.KeyValuePair{
		{
			Key:   common.IndexTypeKey,
			Value: integration.IndexFaissIvfPQ,
		},
		{
			Key:   common.MetricTypeKey,
			Value: metric.IP,
		},
	}
	createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

	// nonexist index
	indexParams = []*commonpb.KeyValuePair{
		{
			Key:   common.IndexTypeKey,
			Value: "INDEX_WHAT",
		},
		{
			Key:   common.MetricTypeKey,
			Value: metric.IP,
		},
	}
	createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

	// incorrect metric type
	indexParams = []*commonpb.KeyValuePair{
		{
			Key:   common.IndexTypeKey,
			Value: integration.IndexSparseInvertedIndex,
		},
		{
			Key:   common.MetricTypeKey,
			Value: metric.L2,
		},
	}
	createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

	// incorrect drop ratio build (negative) — metric is the valid IP so that
	// drop_ratio_build is the only invalid parameter here
	indexParams = []*commonpb.KeyValuePair{
		{
			Key:   common.IndexTypeKey,
			Value: integration.IndexSparseInvertedIndex,
		},
		{
			Key:   common.MetricTypeKey,
			Value: metric.IP,
		},
		{
			Key:   common.DropRatioBuildKey,
			Value: "-0.1",
		},
	}
	createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

	// incorrect drop ratio build (too large) — again with the valid IP metric
	indexParams = []*commonpb.KeyValuePair{
		{
			Key:   common.IndexTypeKey,
			Value: integration.IndexSparseInvertedIndex,
		},
		{
			Key:   common.MetricTypeKey,
			Value: metric.IP,
		},
		{
			Key:   common.DropRatioBuildKey,
			Value: "1.1",
		},
	}
	createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
}
// TestSparse_invalid_search_request verifies that search requests carrying
// malformed sparse float vector queries are rejected by the server rather than
// succeeding. It first builds a healthy collection (insert, flush, index,
// load), then issues searches whose placeholder payload violates the sparse
// row format in four ways: negative column index, mismatched index/value
// lengths, an empty row, and unordered column indices within a row.
func (s *SparseTestSuite) TestSparse_invalid_search_request() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	c := s.Cluster

	const (
		dbName = ""
		rowNum = 3000
	)

	collectionName := s.createCollection(ctx, c, dbName)

	// valid insert
	fVecColumn := integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	hashKeys := integration.GenerateHashKeys(rowNum)
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	// expected value first, keeping testify argument order consistent with
	// the other assertions in this suite.
	s.Equal(commonpb.ErrorCode_Success, insertResult.GetStatus().GetErrorCode())

	// flush
	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	// check the map lookup succeeded before making use of segmentIDs
	s.Require().True(has)
	s.Require().NotEmpty(segmentIDs)
	ids := segmentIDs.GetData()
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	for _, segment := range segments {
		log.Info("ShowSegments result", zap.String("segment", segment.String()))
	}
	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

	// create index
	indexType := integration.IndexSparseInvertedIndex
	metricType := metric.IP
	indexParams := []*commonpb.KeyValuePair{
		{
			Key:   common.MetricTypeKey,
			Value: metricType,
		},
		{
			Key:   common.IndexTypeKey,
			Value: indexType,
		},
		{
			Key:   common.DropRatioBuildKey,
			Value: "0.1",
		},
	}

	createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.SparseFloatVecField,
		IndexName:      "_default",
		ExtraParams:    indexParams,
	})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())
	s.WaitForIndexBuilt(ctx, collectionName, integration.SparseFloatVecField)

	// load
	loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
	})
	s.NoError(err)
	if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
	}
	s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
	s.WaitForLoad(ctx, collectionName)

	// search
	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
	nq := 10
	topk := 10
	roundDecimal := -1

	params := integration.GetSearchParams(indexType, metricType)
	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
		integration.SparseFloatVecField, schemapb.DataType_SparseFloatVector, nil, metricType, params, nq, 0, topk, roundDecimal)

	// replaceQuery re-marshals vecs into the request's placeholder group so
	// each sub-case below can mutate the query payload in place and re-search.
	replaceQuery := func(vecs *schemapb.SparseFloatArray) {
		values := make([][]byte, 0, 1)
		bs, err := proto.Marshal(vecs)
		if err != nil {
			panic(err)
		}
		values = append(values, bs)

		plg := &commonpb.PlaceholderGroup{
			Placeholders: []*commonpb.PlaceholderValue{
				{
					Tag:    "$0",
					Type:   commonpb.PlaceholderType_SparseFloatVector,
					Values: values,
				},
			},
		}
		plgBs, err := proto.Marshal(plg)
		if err != nil {
			panic(err)
		}
		searchReq.PlaceholderGroup = plgBs
	}

	sparseVecs := integration.GenerateSparseFloatArray(nq)

	// negative column index
	oldIdx := typeutil.SparseFloatRowIndexAt(sparseVecs.Contents[0], 0)
	var newIdx int32 = -10
	binary.LittleEndian.PutUint32(sparseVecs.Contents[0][0:], uint32(newIdx))
	replaceQuery(sparseVecs)
	searchResult, err := c.Proxy.Search(ctx, searchReq)
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
	// restore the original index so the next sub-case starts from valid data
	binary.LittleEndian.PutUint32(sparseVecs.Contents[0][0:], oldIdx)

	// of each row, length of indices and data must equal
	sparseVecs.Contents[0] = append(sparseVecs.Contents[0], make([]byte, 4)...)
	replaceQuery(sparseVecs)
	searchResult, err = c.Proxy.Search(ctx, searchReq)
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
	sparseVecs.Contents[0] = sparseVecs.Contents[0][:len(sparseVecs.Contents[0])-4]

	// empty row is not allowed
	sparseVecs.Contents[0] = []byte{}
	replaceQuery(sparseVecs)
	searchResult, err = c.Proxy.Search(ctx, searchReq)
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())

	// column index in the same row must be ordered
	sparseVecs.Contents[0] = make([]byte, 16)
	testutils.SparseFloatRowSetAt(sparseVecs.Contents[0], 0, 20, 0.1)
	testutils.SparseFloatRowSetAt(sparseVecs.Contents[0], 1, 10, 0.2)
	replaceQuery(sparseVecs)
	searchResult, err = c.Proxy.Search(ctx, searchReq)
	s.NoError(err)
	s.NotEqual(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode())
}
func TestSparse(t *testing.T) {
suite.Run(t, new(SparseTestSuite))
}

View File

@ -30,17 +30,19 @@ import (
) )
const ( const (
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
IndexScaNN = indexparamcheck.IndexScaNN IndexScaNN = indexparamcheck.IndexScaNN
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8 IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
IndexHNSW = indexparamcheck.IndexHNSW IndexHNSW = indexparamcheck.IndexHNSW
IndexDISKANN = indexparamcheck.IndexDISKANN IndexDISKANN = indexparamcheck.IndexDISKANN
IndexSparseInvertedIndex = indexparamcheck.IndexSparseInverted
IndexSparseWand = indexparamcheck.IndexSparseWand
) )
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) { func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
@ -166,6 +168,8 @@ func ConstructIndexParam(dim int, indexType string, metricType string) []*common
Key: "efConstruction", Key: "efConstruction",
Value: "200", Value: "200",
}) })
case IndexSparseInvertedIndex:
case IndexSparseWand:
case IndexDISKANN: case IndexDISKANN:
default: default:
panic(fmt.Sprintf("unimplemented index param for %s, please help to improve it", indexType)) panic(fmt.Sprintf("unimplemented index param for %s, please help to improve it", indexType))
@ -184,6 +188,9 @@ func GetSearchParams(indexType string, metricType string) map[string]any {
params["ef"] = 200 params["ef"] = 200
case IndexDISKANN: case IndexDISKANN:
params["search_list"] = 20 params["search_list"] = 20
case IndexSparseInvertedIndex:
case IndexSparseWand:
params["drop_ratio_search"] = 0.1
default: default:
panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType)) panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType))
} }

View File

@ -24,6 +24,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/testutils"
) )
func (s *MiniClusterSuite) WaitForFlush(ctx context.Context, segIDs []int64, flushTs uint64, dbName, collectionName string) { func (s *MiniClusterSuite) WaitForFlush(ctx context.Context, segIDs []int64, flushTs uint64, dbName, collectionName string) {
@ -176,6 +177,22 @@ func NewBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.Fiel
} }
} }
// NewSparseFloatVectorFieldData builds a FieldData column named fieldName that
// holds numRows generated sparse float vectors, wrapped in the protobuf
// vector-field envelope expected by insert requests.
func NewSparseFloatVectorFieldData(fieldName string, numRows int) *schemapb.FieldData {
	vecs := GenerateSparseFloatArray(numRows)
	vectorField := &schemapb.VectorField{
		Dim: vecs.Dim,
		Data: &schemapb.VectorField_SparseFloatVector{
			SparseFloatVector: vecs,
		},
	}
	return &schemapb.FieldData{
		Type:      schemapb.DataType_SparseFloatVector,
		FieldName: fieldName,
		Field:     &schemapb.FieldData_Vectors{Vectors: vectorField},
	}
}
func GenerateInt64Array(numRows int, start int64) []int64 { func GenerateInt64Array(numRows int, start int64) []int64 {
ret := make([]int64, numRows) ret := make([]int64, numRows)
for i := 0; i < numRows; i++ { for i := 0; i < numRows; i++ {
@ -229,6 +246,10 @@ func GenerateFloat16Vectors(numRows, dim int) []byte {
return ret return ret
} }
// GenerateSparseFloatArray returns numRows generated sparse float vectors,
// delegating to the shared testutils generator.
func GenerateSparseFloatArray(numRows int) *schemapb.SparseFloatArray {
	return testutils.GenerateSparseFloatVectors(numRows)
}
// func GenerateBFloat16Vectors(numRows, dim int) []byte { // func GenerateBFloat16Vectors(numRows, dim int) []byte {
// total := numRows * dim * 2 // total := numRows * dim * 2
// ret := make([]byte, total) // ret := make([]byte, total)

View File

@ -128,6 +128,7 @@ func ConstructSearchRequest(
}, },
TravelTimestamp: 0, TravelTimestamp: 0,
GuaranteeTimestamp: 0, GuaranteeTimestamp: 0,
Nq: int64(nq),
} }
} }
@ -243,6 +244,13 @@ func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commo
// } // }
// values = append(values, ret) // values = append(values, ret)
// } // }
case schemapb.DataType_SparseFloatVector:
// for sparse, all query rows are encoded in a single byte array
values = make([][]byte, 0, 1)
placeholderType = commonpb.PlaceholderType_SparseFloatVector
sparseVecs := GenerateSparseFloatArray(nq)
values = append(values, sparseVecs.Contents...)
default: default:
panic("invalid vector data type") panic("invalid vector data type")
} }

View File

@ -25,19 +25,20 @@ import (
) )
const ( const (
BoolField = "boolField" BoolField = "boolField"
Int8Field = "int8Field" Int8Field = "int8Field"
Int16Field = "int16Field" Int16Field = "int16Field"
Int32Field = "int32Field" Int32Field = "int32Field"
Int64Field = "int64Field" Int64Field = "int64Field"
FloatField = "floatField" FloatField = "floatField"
DoubleField = "doubleField" DoubleField = "doubleField"
VarCharField = "varCharField" VarCharField = "varCharField"
JSONField = "jsonField" JSONField = "jsonField"
FloatVecField = "floatVecField" FloatVecField = "floatVecField"
BinVecField = "binVecField" BinVecField = "binVecField"
Float16VecField = "float16VecField" Float16VecField = "float16VecField"
BFloat16VecField = "bfloat16VecField" BFloat16VecField = "bfloat16VecField"
SparseFloatVecField = "sparseFloatVecField"
) )
func ConstructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema { func ConstructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema {
@ -81,3 +82,47 @@ func ConstructSchema(collection string, dim int, autoID bool, fields ...*schemap
Fields: []*schemapb.FieldSchema{pk, fVec}, Fields: []*schemapb.FieldSchema{pk, fVec},
} }
} }
// ConstructSchemaOfVecDataType builds a two-field collection schema: an Int64
// primary key plus a single vector field of the requested dataType. The dim
// type param is only attached for float vectors; sparse float vector fields
// carry no type params. Panics on any other vector data type.
func ConstructSchemaOfVecDataType(collection string, dim int, autoID bool, dataType schemapb.DataType) *schemapb.CollectionSchema {
	pk := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         Int64Field,
		IsPrimaryKey: true,
		Description:  "",
		DataType:     schemapb.DataType_Int64,
		TypeParams:   nil,
		IndexParams:  nil,
		AutoID:       autoID,
	}

	var (
		fieldName   string
		fieldParams []*commonpb.KeyValuePair
	)
	switch dataType {
	case schemapb.DataType_FloatVector:
		fieldName = FloatVecField
		fieldParams = []*commonpb.KeyValuePair{
			{Key: common.DimKey, Value: fmt.Sprintf("%d", dim)},
		}
	case schemapb.DataType_SparseFloatVector:
		// sparse vectors have no fixed dimension, so no type params
		fieldName = SparseFloatVecField
		fieldParams = nil
	default:
		panic("unsupported data type")
	}

	vecField := &schemapb.FieldSchema{
		FieldID:      101,
		Name:         fieldName,
		IsPrimaryKey: false,
		Description:  "",
		DataType:     dataType,
		TypeParams:   fieldParams,
		IndexParams:  nil,
	}
	return &schemapb.CollectionSchema{
		Name:   collection,
		AutoID: autoID,
		Fields: []*schemapb.FieldSchema{pk, vecField},
	}
}