mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
feat: Restful support for BM25 function (#36713)
issue: https://github.com/milvus-io/milvus/issues/35853 Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
This commit is contained in:
parent
e170991a10
commit
16b533cbf0
@ -95,14 +95,22 @@ const (
|
|||||||
|
|
||||||
HTTPReturnHas = "has"
|
HTTPReturnHas = "has"
|
||||||
|
|
||||||
HTTPReturnFieldName = "name"
|
HTTPReturnFieldName = "name"
|
||||||
HTTPReturnFieldID = "id"
|
HTTPReturnFieldID = "id"
|
||||||
HTTPReturnFieldType = "type"
|
HTTPReturnFieldType = "type"
|
||||||
HTTPReturnFieldPrimaryKey = "primaryKey"
|
HTTPReturnFieldPrimaryKey = "primaryKey"
|
||||||
HTTPReturnFieldPartitionKey = "partitionKey"
|
HTTPReturnFieldPartitionKey = "partitionKey"
|
||||||
HTTPReturnFieldAutoID = "autoId"
|
HTTPReturnFieldAutoID = "autoId"
|
||||||
HTTPReturnFieldElementType = "elementType"
|
HTTPReturnFieldElementType = "elementType"
|
||||||
HTTPReturnDescription = "description"
|
HTTPReturnDescription = "description"
|
||||||
|
HTTPReturnFieldIsFunctionOutput = "isFunctionOutput"
|
||||||
|
|
||||||
|
HTTPReturnFunctionName = "name"
|
||||||
|
HTTPReturnFunctionID = "id"
|
||||||
|
HTTPReturnFunctionType = "type"
|
||||||
|
HTTPReturnFunctionInputFieldNames = "inputFieldNames"
|
||||||
|
HTTPReturnFunctionOutputFieldNames = "outputFieldNames"
|
||||||
|
HTTPReturnFunctionParams = "params"
|
||||||
|
|
||||||
HTTPReturnIndexMetricType = "metricType"
|
HTTPReturnIndexMetricType = "metricType"
|
||||||
HTTPReturnIndexType = "indexType"
|
HTTPReturnIndexType = "indexType"
|
||||||
|
|||||||
@ -437,6 +437,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
|
|||||||
HTTPReturnDescription: coll.Schema.Description,
|
HTTPReturnDescription: coll.Schema.Description,
|
||||||
HTTPReturnFieldAutoID: autoID,
|
HTTPReturnFieldAutoID: autoID,
|
||||||
"fields": printFieldsV2(coll.Schema.Fields),
|
"fields": printFieldsV2(coll.Schema.Fields),
|
||||||
|
"functions": printFunctionDetails(coll.Schema.Functions),
|
||||||
"aliases": aliases,
|
"aliases": aliases,
|
||||||
"indexes": indexDesc,
|
"indexes": indexDesc,
|
||||||
"load": collLoadState,
|
"load": collLoadState,
|
||||||
@ -897,7 +898,21 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
|
|||||||
if !typeutil.IsSparseFloatVectorType(vectorField.DataType) {
|
if !typeutil.IsSparseFloatVectorType(vectorField.DataType) {
|
||||||
dim, _ = getDim(vectorField)
|
dim, _ = getDim(vectorField)
|
||||||
}
|
}
|
||||||
phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim)
|
|
||||||
|
dataType := vectorField.DataType
|
||||||
|
|
||||||
|
if vectorField.GetIsFunctionOutput() {
|
||||||
|
for _, function := range collSchema.Functions {
|
||||||
|
if function.Type == schemapb.FunctionType_BM25 {
|
||||||
|
// TODO: currently only BM25 function is supported, thus guarantees one input field to one output field
|
||||||
|
if function.OutputFieldNames[0] == vectorField.Name {
|
||||||
|
dataType = schemapb.DataType_VarChar
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
phv, err := convertQueries2Placeholder(body, dataType, dim)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1086,6 +1101,17 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||||||
fieldNames := map[string]bool{}
|
fieldNames := map[string]bool{}
|
||||||
partitionsNum := int64(-1)
|
partitionsNum := int64(-1)
|
||||||
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
||||||
|
if len(httpReq.Schema.Functions) > 0 {
|
||||||
|
err := merr.WrapErrParameterInvalid("schema", "functions",
|
||||||
|
"functions are not supported for quickly create collection")
|
||||||
|
log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq))
|
||||||
|
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||||
|
HTTPReturnCode: merr.Code(err),
|
||||||
|
HTTPReturnMessage: err.Error(),
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if httpReq.Dimension == 0 {
|
if httpReq.Dimension == 0 {
|
||||||
err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName",
|
err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName",
|
||||||
"dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")")
|
"dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")")
|
||||||
@ -1162,8 +1188,40 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||||||
Name: httpReq.CollectionName,
|
Name: httpReq.CollectionName,
|
||||||
AutoID: httpReq.Schema.AutoId,
|
AutoID: httpReq.Schema.AutoId,
|
||||||
Fields: []*schemapb.FieldSchema{},
|
Fields: []*schemapb.FieldSchema{},
|
||||||
|
Functions: []*schemapb.FunctionSchema{},
|
||||||
EnableDynamicField: httpReq.Schema.EnableDynamicField,
|
EnableDynamicField: httpReq.Schema.EnableDynamicField,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allOutputFields := []string{}
|
||||||
|
|
||||||
|
for _, function := range httpReq.Schema.Functions {
|
||||||
|
functionTypeValue, ok := schemapb.FunctionType_value[function.FunctionType]
|
||||||
|
if !ok {
|
||||||
|
log.Ctx(ctx).Warn("function's data type is invalid(case sensitive).", zap.Any("function.DataType", function.FunctionType), zap.Any("function", function))
|
||||||
|
err := merr.WrapErrParameterInvalid("FunctionType", function.FunctionType, "function data type is invalid(case sensitive)")
|
||||||
|
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||||
|
HTTPReturnCode: merr.Code(merr.ErrParameterInvalid),
|
||||||
|
HTTPReturnMessage: err.Error(),
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
functionType := schemapb.FunctionType(functionTypeValue)
|
||||||
|
description := function.Description
|
||||||
|
params := []*commonpb.KeyValuePair{}
|
||||||
|
for key, value := range function.Params {
|
||||||
|
params = append(params, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||||
|
}
|
||||||
|
collSchema.Functions = append(collSchema.Functions, &schemapb.FunctionSchema{
|
||||||
|
Name: function.FunctionName,
|
||||||
|
Description: description,
|
||||||
|
Type: functionType,
|
||||||
|
InputFieldNames: function.InputFieldNames,
|
||||||
|
OutputFieldNames: function.OutputFieldNames,
|
||||||
|
Params: params,
|
||||||
|
})
|
||||||
|
allOutputFields = append(allOutputFields, function.OutputFieldNames...)
|
||||||
|
}
|
||||||
|
|
||||||
for _, field := range httpReq.Schema.Fields {
|
for _, field := range httpReq.Schema.Fields {
|
||||||
fieldDataType, ok := schemapb.DataType_value[field.DataType]
|
fieldDataType, ok := schemapb.DataType_value[field.DataType]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -1218,6 +1276,9 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||||||
for key, fieldParam := range field.ElementTypeParams {
|
for key, fieldParam := range field.ElementTypeParams {
|
||||||
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", fieldParam)})
|
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", fieldParam)})
|
||||||
}
|
}
|
||||||
|
if lo.Contains(allOutputFields, field.FieldName) {
|
||||||
|
fieldSchema.IsFunctionOutput = true
|
||||||
|
}
|
||||||
collSchema.Fields = append(collSchema.Fields, &fieldSchema)
|
collSchema.Fields = append(collSchema.Fields, &fieldSchema)
|
||||||
fieldNames[field.FieldName] = true
|
fieldNames[field.FieldName] = true
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,6 +57,22 @@ func init() {
|
|||||||
paramtable.Init()
|
paramtable.Init()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sendReqAndVerify(t *testing.T, testEngine *gin.Engine, testName, method string, testcase requestBodyTestCase) {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(method, testcase.path, bytes.NewReader(testcase.requestBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
testEngine.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
returnBody := &ReturnErrMsg{}
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), returnBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, testcase.errCode, returnBody.Code)
|
||||||
|
if testcase.errCode != 0 {
|
||||||
|
assert.Contains(t, returnBody.Message, testcase.errMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPWrapper(t *testing.T) {
|
func TestHTTPWrapper(t *testing.T) {
|
||||||
postTestCases := []requestBodyTestCase{}
|
postTestCases := []requestBodyTestCase{}
|
||||||
postTestCasesTrace := []requestBodyTestCase{}
|
postTestCasesTrace := []requestBodyTestCase{}
|
||||||
@ -468,6 +484,230 @@ func TestDatabaseWrapper(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDocInDocOutCreateCollection(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
// disable rate limit
|
||||||
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||||
|
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||||
|
|
||||||
|
postTestCases := []requestBodyTestCase{}
|
||||||
|
mp := mocks.NewMockProxy(t)
|
||||||
|
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(1)
|
||||||
|
testEngine := initHTTPServerV2(mp, false)
|
||||||
|
path := versionalV2(CollectionCategory, CreateAction)
|
||||||
|
|
||||||
|
const baseRequestBody = `{
|
||||||
|
"collectionName": "doc_in_doc_out_demo",
|
||||||
|
"schema": {
|
||||||
|
"autoId": false,
|
||||||
|
"enableDynamicField": false,
|
||||||
|
"fields": [
|
||||||
|
{
|
||||||
|
"fieldName": "my_id",
|
||||||
|
"dataType": "Int64",
|
||||||
|
"isPrimary": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fieldName": "document_content",
|
||||||
|
"dataType": "VarChar",
|
||||||
|
"elementTypeParams": {
|
||||||
|
"max_length": "9000"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fieldName": "sparse_vector_1",
|
||||||
|
"dataType": "SparseFloatVector"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"functions": %s
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||||
|
path: path,
|
||||||
|
requestBody: []byte(fmt.Sprintf(baseRequestBody, `[
|
||||||
|
{
|
||||||
|
"name": "bm25_fn_1",
|
||||||
|
"type": "BM25",
|
||||||
|
"inputFieldNames": ["document_content"],
|
||||||
|
"outputFieldNames": ["sparse_vector_1"]
|
||||||
|
}
|
||||||
|
]`)),
|
||||||
|
})
|
||||||
|
|
||||||
|
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||||
|
path: path,
|
||||||
|
requestBody: []byte(fmt.Sprintf(baseRequestBody, `[
|
||||||
|
{
|
||||||
|
"name": "bm25_fn_1",
|
||||||
|
"type": "BM25_",
|
||||||
|
"inputFieldNames": ["document_content"],
|
||||||
|
"outputFieldNames": ["sparse_vector_1"]
|
||||||
|
}
|
||||||
|
]`)),
|
||||||
|
errMsg: "actual=BM25_",
|
||||||
|
errCode: 1100,
|
||||||
|
})
|
||||||
|
|
||||||
|
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||||
|
path: path,
|
||||||
|
requestBody: []byte(fmt.Sprintf(baseRequestBody, `[
|
||||||
|
{
|
||||||
|
"name": "bm25_fn_1",
|
||||||
|
"inputFieldNames": ["document_content"],
|
||||||
|
"outputFieldNames": ["sparse_vector_1"]
|
||||||
|
}
|
||||||
|
]`)),
|
||||||
|
errMsg: "actual=", // unprovided function type is empty string
|
||||||
|
errCode: 1100,
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, testcase := range postTestCases {
|
||||||
|
sendReqAndVerify(t, testEngine, "post"+testcase.path, http.MethodPost, testcase)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocInDocOutCreateCollectionQuickDisallowFunction(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
// disable rate limit
|
||||||
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||||
|
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||||
|
|
||||||
|
mp := mocks.NewMockProxy(t)
|
||||||
|
testEngine := initHTTPServerV2(mp, false)
|
||||||
|
path := versionalV2(CollectionCategory, CreateAction)
|
||||||
|
|
||||||
|
const baseRequestBody = `{
|
||||||
|
"collectionName": "doc_in_doc_out_demo",
|
||||||
|
"dimension": 2,
|
||||||
|
"idType": "Varchar",
|
||||||
|
"schema": {
|
||||||
|
"autoId": false,
|
||||||
|
"enableDynamicField": false,
|
||||||
|
"functions": [
|
||||||
|
{
|
||||||
|
"name": "bm25_fn_1",
|
||||||
|
"type": "BM25",
|
||||||
|
"inputFieldNames": ["document_content"],
|
||||||
|
"outputFieldNames": ["sparse_vector_1"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
testcase := requestBodyTestCase{
|
||||||
|
path: path,
|
||||||
|
requestBody: []byte(baseRequestBody),
|
||||||
|
errMsg: "functions are not supported for quickly create collection",
|
||||||
|
errCode: 1100,
|
||||||
|
}
|
||||||
|
|
||||||
|
sendReqAndVerify(t, testEngine, "post"+testcase.path, http.MethodPost, testcase)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocInDocOutDescribeCollection(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
mp := mocks.NewMockProxy(t)
|
||||||
|
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||||
|
CollectionName: DefaultCollectionName,
|
||||||
|
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||||
|
ShardsNum: ShardNumDefault,
|
||||||
|
Status: &StatusSuccess,
|
||||||
|
}, nil).Once()
|
||||||
|
mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&DefaultLoadStateResp, nil).Once()
|
||||||
|
mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&DefaultDescIndexesReqp, nil).Once()
|
||||||
|
mp.EXPECT().ListAliases(mock.Anything, mock.Anything).Return(&milvuspb.ListAliasesResponse{
|
||||||
|
Status: &StatusSuccess,
|
||||||
|
Aliases: []string{DefaultAliasName},
|
||||||
|
}, nil).Once()
|
||||||
|
testEngine := initHTTPServerV2(mp, false)
|
||||||
|
testcase := requestBodyTestCase{
|
||||||
|
path: versionalV2(CollectionCategory, DescribeAction),
|
||||||
|
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`),
|
||||||
|
}
|
||||||
|
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocInDocOutInsert(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
// disable rate limit
|
||||||
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||||
|
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||||
|
|
||||||
|
mp := mocks.NewMockProxy(t)
|
||||||
|
testEngine := initHTTPServerV2(mp, false)
|
||||||
|
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||||
|
CollectionName: DefaultCollectionName,
|
||||||
|
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||||
|
ShardsNum: ShardNumDefault,
|
||||||
|
Status: &StatusSuccess,
|
||||||
|
}, nil).Once()
|
||||||
|
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
|
||||||
|
|
||||||
|
testcase := requestBodyTestCase{
|
||||||
|
path: versionalV2(EntityCategory, InsertAction),
|
||||||
|
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "varchar_field": "some text"}]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocInDocOutInsertInvalid(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
// disable rate limit
|
||||||
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||||
|
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||||
|
|
||||||
|
mp := mocks.NewMockProxy(t)
|
||||||
|
testEngine := initHTTPServerV2(mp, false)
|
||||||
|
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||||
|
CollectionName: DefaultCollectionName,
|
||||||
|
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||||
|
ShardsNum: ShardNumDefault,
|
||||||
|
Status: &StatusSuccess,
|
||||||
|
}, nil).Once()
|
||||||
|
// invlaid insert request, will not be sent to proxy
|
||||||
|
|
||||||
|
testcase := requestBodyTestCase{
|
||||||
|
path: versionalV2(EntityCategory, InsertAction),
|
||||||
|
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": {"1": 0.1}, "varchar_field": "some text"}]}`),
|
||||||
|
errCode: 1804,
|
||||||
|
errMsg: "not allowed to provide input data for function output field",
|
||||||
|
}
|
||||||
|
|
||||||
|
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocInDocOutSearch(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
// disable rate limit
|
||||||
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||||
|
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||||
|
|
||||||
|
mp := mocks.NewMockProxy(t)
|
||||||
|
testEngine := initHTTPServerV2(mp, false)
|
||||||
|
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||||
|
CollectionName: DefaultCollectionName,
|
||||||
|
Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64),
|
||||||
|
ShardsNum: ShardNumDefault,
|
||||||
|
Status: &StatusSuccess,
|
||||||
|
}, nil).Once()
|
||||||
|
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{
|
||||||
|
TopK: int64(3),
|
||||||
|
OutputFields: []string{FieldWordCount},
|
||||||
|
FieldsData: generateFieldData(),
|
||||||
|
Ids: generateIDs(schemapb.DataType_Int64, 3),
|
||||||
|
Scores: DefaultScores,
|
||||||
|
}}, nil).Once()
|
||||||
|
|
||||||
|
testcase := requestBodyTestCase{
|
||||||
|
path: versionalV2(EntityCategory, SearchAction),
|
||||||
|
requestBody: []byte(`{"collectionName": "book", "data": ["query data"], "limit": 4, "outputFields": ["word_count"]}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateCollection(t *testing.T) {
|
func TestCreateCollection(t *testing.T) {
|
||||||
paramtable.Init()
|
paramtable.Init()
|
||||||
// disable rate limit
|
// disable rate limit
|
||||||
@ -1054,7 +1294,6 @@ func TestMethodGet(t *testing.T) {
|
|||||||
if testcase.errCode != 0 {
|
if testcase.errCode != 0 {
|
||||||
assert.Equal(t, testcase.errMsg, returnBody.Message)
|
assert.Equal(t, testcase.errMsg, returnBody.Message)
|
||||||
}
|
}
|
||||||
fmt.Println(w.Body.String())
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -324,10 +324,20 @@ type FieldSchema struct {
|
|||||||
DefaultValue interface{} `json:"defaultValue" binding:"required"`
|
DefaultValue interface{} `json:"defaultValue" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FunctionSchema struct {
|
||||||
|
FunctionName string `json:"name" binding:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
FunctionType string `json:"type" binding:"required"`
|
||||||
|
InputFieldNames []string `json:"inputFieldNames" binding:"required"`
|
||||||
|
OutputFieldNames []string `json:"outputFieldNames" binding:"required"`
|
||||||
|
Params map[string]interface{} `json:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
type CollectionSchema struct {
|
type CollectionSchema struct {
|
||||||
Fields []FieldSchema `json:"fields"`
|
Fields []FieldSchema `json:"fields"`
|
||||||
AutoId bool `json:"autoID"`
|
Functions []FunctionSchema `json:"functions"`
|
||||||
EnableDynamicField bool `json:"enableDynamicField"`
|
AutoId bool `json:"autoID"`
|
||||||
|
EnableDynamicField bool `json:"enableDynamicField"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CollectionReq struct {
|
type CollectionReq struct {
|
||||||
|
|||||||
@ -147,52 +147,77 @@ func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result)
|
|||||||
// --------------------- collection details --------------------- //
|
// --------------------- collection details --------------------- //
|
||||||
|
|
||||||
func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
||||||
return printFieldDetails(fields, true)
|
var res []gin.H
|
||||||
|
for _, field := range fields {
|
||||||
|
fieldDetail := printFieldDetail(field, true)
|
||||||
|
res = append(res, fieldDetail)
|
||||||
|
}
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func printFieldsV2(fields []*schemapb.FieldSchema) []gin.H {
|
func printFieldsV2(fields []*schemapb.FieldSchema) []gin.H {
|
||||||
return printFieldDetails(fields, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H {
|
|
||||||
var res []gin.H
|
var res []gin.H
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
fieldDetail := gin.H{
|
fieldDetail := printFieldDetail(field, false)
|
||||||
HTTPReturnFieldName: field.Name,
|
|
||||||
HTTPReturnFieldPrimaryKey: field.IsPrimaryKey,
|
|
||||||
HTTPReturnFieldPartitionKey: field.IsPartitionKey,
|
|
||||||
HTTPReturnFieldAutoID: field.AutoID,
|
|
||||||
HTTPReturnDescription: field.Description,
|
|
||||||
}
|
|
||||||
if typeutil.IsVectorType(field.DataType) {
|
|
||||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
|
||||||
if oldVersion {
|
|
||||||
dim, _ := getDim(field)
|
|
||||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
|
||||||
}
|
|
||||||
} else if field.DataType == schemapb.DataType_VarChar {
|
|
||||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
|
||||||
if oldVersion {
|
|
||||||
maxLength, _ := parameterutil.GetMaxLength(field)
|
|
||||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
|
||||||
}
|
|
||||||
if !oldVersion {
|
|
||||||
fieldDetail[HTTPReturnFieldID] = field.FieldID
|
|
||||||
if field.TypeParams != nil {
|
|
||||||
fieldDetail[Params] = field.TypeParams
|
|
||||||
}
|
|
||||||
if field.DataType == schemapb.DataType_Array {
|
|
||||||
fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
res = append(res, fieldDetail)
|
res = append(res, fieldDetail)
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func printFieldDetail(field *schemapb.FieldSchema, oldVersion bool) gin.H {
|
||||||
|
fieldDetail := gin.H{
|
||||||
|
HTTPReturnFieldName: field.Name,
|
||||||
|
HTTPReturnFieldPrimaryKey: field.IsPrimaryKey,
|
||||||
|
HTTPReturnFieldPartitionKey: field.IsPartitionKey,
|
||||||
|
HTTPReturnFieldAutoID: field.AutoID,
|
||||||
|
HTTPReturnDescription: field.Description,
|
||||||
|
}
|
||||||
|
if field.GetIsFunctionOutput() {
|
||||||
|
fieldDetail[HTTPReturnFieldIsFunctionOutput] = true
|
||||||
|
}
|
||||||
|
if typeutil.IsVectorType(field.DataType) {
|
||||||
|
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||||
|
if oldVersion {
|
||||||
|
dim, _ := getDim(field)
|
||||||
|
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
||||||
|
}
|
||||||
|
} else if field.DataType == schemapb.DataType_VarChar {
|
||||||
|
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||||
|
if oldVersion {
|
||||||
|
maxLength, _ := parameterutil.GetMaxLength(field)
|
||||||
|
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||||
|
}
|
||||||
|
if !oldVersion {
|
||||||
|
fieldDetail[HTTPReturnFieldID] = field.FieldID
|
||||||
|
if field.TypeParams != nil {
|
||||||
|
fieldDetail[Params] = field.TypeParams
|
||||||
|
}
|
||||||
|
if field.DataType == schemapb.DataType_Array {
|
||||||
|
fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fieldDetail
|
||||||
|
}
|
||||||
|
|
||||||
|
func printFunctionDetails(functions []*schemapb.FunctionSchema) []gin.H {
|
||||||
|
var res []gin.H
|
||||||
|
for _, function := range functions {
|
||||||
|
res = append(res, gin.H{
|
||||||
|
HTTPReturnFunctionName: function.Name,
|
||||||
|
HTTPReturnDescription: function.Description,
|
||||||
|
HTTPReturnFunctionType: function.Type,
|
||||||
|
HTTPReturnFunctionID: function.Id,
|
||||||
|
HTTPReturnFunctionInputFieldNames: function.InputFieldNames,
|
||||||
|
HTTPReturnFunctionOutputFieldNames: function.OutputFieldNames,
|
||||||
|
HTTPReturnFunctionParams: function.Params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
func getMetricType(pairs []*commonpb.KeyValuePair) string {
|
func getMetricType(pairs []*commonpb.KeyValuePair) string {
|
||||||
metricType := DefaultMetricType
|
metricType := DefaultMetricType
|
||||||
for _, pair := range pairs {
|
for _, pair := range pairs {
|
||||||
@ -258,6 +283,14 @@ func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if field is a function output field, user must not provide data for it
|
||||||
|
if field.GetIsFunctionOutput() {
|
||||||
|
if dataString != "" {
|
||||||
|
return merr.WrapErrParameterInvalid("", "not allowed to provide input data for function output field: "+fieldName), reallyDataArray, validDataMap
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch fieldType {
|
switch fieldType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
if dataString == "" {
|
if dataString == "" {
|
||||||
@ -626,11 +659,16 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
|||||||
nameColumns := make(map[string]interface{})
|
nameColumns := make(map[string]interface{})
|
||||||
nameDims := make(map[string]int64)
|
nameDims := make(map[string]int64)
|
||||||
fieldData := make(map[string]*schemapb.FieldData)
|
fieldData := make(map[string]*schemapb.FieldData)
|
||||||
|
|
||||||
for _, field := range sch.Fields {
|
for _, field := range sch.Fields {
|
||||||
// skip auto id pk field
|
// skip auto id pk field
|
||||||
if (field.IsPrimaryKey && field.AutoID) || field.IsDynamic {
|
if (field.IsPrimaryKey && field.AutoID) || field.IsDynamic {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// skip function output field
|
||||||
|
if field.GetIsFunctionOutput() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
var data interface{}
|
var data interface{}
|
||||||
switch field.DataType {
|
switch field.DataType {
|
||||||
case schemapb.DataType_Bool:
|
case schemapb.DataType_Bool:
|
||||||
@ -685,8 +723,8 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
|||||||
IsDynamic: field.IsDynamic,
|
IsDynamic: field.IsDynamic,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(nameDims) == 0 {
|
if len(nameDims) == 0 && len(sch.Functions) == 0 {
|
||||||
return nil, fmt.Errorf("collection: %s has no vector field", sch.Name)
|
return nil, fmt.Errorf("collection: %s has no vector field or functions", sch.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
dynamicCol := make([][]byte, 0, rowsLen)
|
dynamicCol := make([][]byte, 0, rowsLen)
|
||||||
@ -709,6 +747,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
|||||||
if (field.Nullable || field.DefaultValue != nil) && !ok {
|
if (field.Nullable || field.DefaultValue != nil) && !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if field.GetIsFunctionOutput() {
|
||||||
|
if ok {
|
||||||
|
return nil, fmt.Errorf("row %d has data provided for function output field %s", idx, field.Name)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
|
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
|
||||||
}
|
}
|
||||||
@ -1035,7 +1079,7 @@ func serializeSparseFloatVectors(vectors []gjson.Result, dataType schemapb.DataT
|
|||||||
return values, nil
|
return values, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
|
func convertQueries2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
|
||||||
var valueType commonpb.PlaceholderType
|
var valueType commonpb.PlaceholderType
|
||||||
var values [][]byte
|
var values [][]byte
|
||||||
var err error
|
var err error
|
||||||
@ -1055,6 +1099,12 @@ func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimensi
|
|||||||
case schemapb.DataType_SparseFloatVector:
|
case schemapb.DataType_SparseFloatVector:
|
||||||
valueType = commonpb.PlaceholderType_SparseFloatVector
|
valueType = commonpb.PlaceholderType_SparseFloatVector
|
||||||
values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType)
|
values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType)
|
||||||
|
case schemapb.DataType_VarChar:
|
||||||
|
valueType = commonpb.PlaceholderType_VarChar
|
||||||
|
res := gjson.Get(body, HTTPRequestData).Array()
|
||||||
|
for _, v := range res {
|
||||||
|
values = append(values, []byte(v.String()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -23,6 +23,7 @@ const (
|
|||||||
FieldWordCount = "word_count"
|
FieldWordCount = "word_count"
|
||||||
FieldBookID = "book_id"
|
FieldBookID = "book_id"
|
||||||
FieldBookIntro = "book_intro"
|
FieldBookIntro = "book_intro"
|
||||||
|
FieldVarchar = "varchar_field"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DefaultScores = []float32{0.01, 0.04, 0.09}
|
var DefaultScores = []float32{0.01, 0.04, 0.09}
|
||||||
@ -74,17 +75,21 @@ func generateVectorFieldSchema(dataType schemapb.DataType) *schemapb.FieldSchema
|
|||||||
if dataType == schemapb.DataType_BinaryVector {
|
if dataType == schemapb.DataType_BinaryVector {
|
||||||
dim = "8"
|
dim = "8"
|
||||||
}
|
}
|
||||||
|
typeParams := []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: common.DimKey,
|
||||||
|
Value: dim,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if dataType == schemapb.DataType_SparseFloatVector {
|
||||||
|
typeParams = nil
|
||||||
|
}
|
||||||
return &schemapb.FieldSchema{
|
return &schemapb.FieldSchema{
|
||||||
FieldID: common.StartOfUserFieldID + int64(dataType),
|
FieldID: common.StartOfUserFieldID + int64(dataType),
|
||||||
IsPrimaryKey: false,
|
IsPrimaryKey: false,
|
||||||
DataType: dataType,
|
DataType: dataType,
|
||||||
AutoID: false,
|
AutoID: false,
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
TypeParams: typeParams,
|
||||||
{
|
|
||||||
Key: common.DimKey,
|
|
||||||
Value: dim,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,6 +115,44 @@ func generateCollectionSchema(primaryDataType schemapb.DataType) *schemapb.Colle
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDocInDocOutCollectionSchema(primaryDataType schemapb.DataType) *schemapb.CollectionSchema {
|
||||||
|
primaryField := generatePrimaryField(primaryDataType)
|
||||||
|
vectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector)
|
||||||
|
vectorField.Name = FieldBookIntro
|
||||||
|
vectorField.IsFunctionOutput = true
|
||||||
|
return &schemapb.CollectionSchema{
|
||||||
|
Name: DefaultCollectionName,
|
||||||
|
Description: "",
|
||||||
|
AutoID: false,
|
||||||
|
Fields: []*schemapb.FieldSchema{
|
||||||
|
primaryField, {
|
||||||
|
FieldID: common.StartOfUserFieldID + 1,
|
||||||
|
Name: FieldWordCount,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
Description: "",
|
||||||
|
DataType: 5,
|
||||||
|
AutoID: false,
|
||||||
|
}, vectorField, {
|
||||||
|
FieldID: common.StartOfUserFieldID + 2,
|
||||||
|
Name: FieldVarchar,
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
Description: "",
|
||||||
|
DataType: schemapb.DataType_VarChar,
|
||||||
|
AutoID: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Functions: []*schemapb.FunctionSchema{
|
||||||
|
{
|
||||||
|
Name: "sum",
|
||||||
|
Type: schemapb.FunctionType_BM25,
|
||||||
|
InputFieldNames: []string{FieldVarchar},
|
||||||
|
OutputFieldNames: []string{FieldBookIntro},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
EnableDynamicField: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func generateIndexes() []*milvuspb.IndexDescription {
|
func generateIndexes() []*milvuspb.IndexDescription {
|
||||||
return []*milvuspb.IndexDescription{
|
return []*milvuspb.IndexDescription{
|
||||||
{
|
{
|
||||||
|
|||||||
@ -86,7 +86,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool {
|
|||||||
var consistencyLevel commonpb.ConsistencyLevel
|
var consistencyLevel commonpb.ConsistencyLevel
|
||||||
useDefaultConsistency := t.request.GetUseDefaultConsistency()
|
useDefaultConsistency := t.request.GetUseDefaultConsistency()
|
||||||
if !useDefaultConsistency {
|
if !useDefaultConsistency {
|
||||||
// legacy SDK & resultful behavior
|
// legacy SDK & restful behavior
|
||||||
if t.request.GetConsistencyLevel() == commonpb.ConsistencyLevel_Strong && t.request.GetGuaranteeTimestamp() > 0 {
|
if t.request.GetConsistencyLevel() == commonpb.ConsistencyLevel_Strong && t.request.GetGuaranteeTimestamp() > 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -373,7 +373,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||||||
internalSubReq.FieldId = queryInfo.GetQueryFieldId()
|
internalSubReq.FieldId = queryInfo.GetQueryFieldId()
|
||||||
// set PartitionIDs for sub search
|
// set PartitionIDs for sub search
|
||||||
if t.partitionKeyMode {
|
if t.partitionKeyMode {
|
||||||
// isolatioin has tighter constraint, check first
|
// isolation has tighter constraint, check first
|
||||||
mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
|
mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
|
||||||
if mvErr != nil {
|
if mvErr != nil {
|
||||||
return mvErr
|
return mvErr
|
||||||
@ -453,7 +453,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||||||
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
|
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
|
||||||
|
|
||||||
if t.partitionKeyMode {
|
if t.partitionKeyMode {
|
||||||
// isolatioin has tighter constraint, check first
|
// isolation has tighter constraint, check first
|
||||||
mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
|
mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
|
||||||
if mvErr != nil {
|
if mvErr != nil {
|
||||||
return mvErr
|
return mvErr
|
||||||
|
|||||||
@ -296,7 +296,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||||||
defer sd.lifetime.Done()
|
defer sd.lifetime.Done()
|
||||||
|
|
||||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||||
log.Warn("deletgator received search request not belongs to it",
|
log.Warn("delegator received search request not belongs to it",
|
||||||
zap.Strings("reqChannels", req.GetDmlChannels()),
|
zap.Strings("reqChannels", req.GetDmlChannels()),
|
||||||
)
|
)
|
||||||
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user