diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index d02c45419d..a07bdbec65 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -95,14 +95,22 @@ const ( HTTPReturnHas = "has" - HTTPReturnFieldName = "name" - HTTPReturnFieldID = "id" - HTTPReturnFieldType = "type" - HTTPReturnFieldPrimaryKey = "primaryKey" - HTTPReturnFieldPartitionKey = "partitionKey" - HTTPReturnFieldAutoID = "autoId" - HTTPReturnFieldElementType = "elementType" - HTTPReturnDescription = "description" + HTTPReturnFieldName = "name" + HTTPReturnFieldID = "id" + HTTPReturnFieldType = "type" + HTTPReturnFieldPrimaryKey = "primaryKey" + HTTPReturnFieldPartitionKey = "partitionKey" + HTTPReturnFieldAutoID = "autoId" + HTTPReturnFieldElementType = "elementType" + HTTPReturnDescription = "description" + HTTPReturnFieldIsFunctionOutput = "isFunctionOutput" + + HTTPReturnFunctionName = "name" + HTTPReturnFunctionID = "id" + HTTPReturnFunctionType = "type" + HTTPReturnFunctionInputFieldNames = "inputFieldNames" + HTTPReturnFunctionOutputFieldNames = "outputFieldNames" + HTTPReturnFunctionParams = "params" HTTPReturnIndexMetricType = "metricType" HTTPReturnIndexType = "indexType" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 610c114d05..9de12ec0ec 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -437,6 +437,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a HTTPReturnDescription: coll.Schema.Description, HTTPReturnFieldAutoID: autoID, "fields": printFieldsV2(coll.Schema.Fields), + "functions": printFunctionDetails(coll.Schema.Functions), "aliases": aliases, "indexes": indexDesc, "load": collLoadState, @@ -897,7 +898,21 @@ func generatePlaceholderGroup(ctx context.Context, body string, 
collSchema *sche if !typeutil.IsSparseFloatVectorType(vectorField.DataType) { dim, _ = getDim(vectorField) } - phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim) + + dataType := vectorField.DataType + + if vectorField.GetIsFunctionOutput() { + for _, function := range collSchema.Functions { + if function.Type == schemapb.FunctionType_BM25 { + // TODO: currently only BM25 function is supported, thus guarantees one input field to one output field + if function.OutputFieldNames[0] == vectorField.Name { + dataType = schemapb.DataType_VarChar + } + } + } + } + + phv, err := convertQueries2Placeholder(body, dataType, dim) if err != nil { return nil, err } @@ -1086,6 +1101,17 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe fieldNames := map[string]bool{} partitionsNum := int64(-1) if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + if len(httpReq.Schema.Functions) > 0 { + err := merr.WrapErrParameterInvalid("schema", "functions", + "functions are not supported for quickly create collection") + log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return nil, err + } + if httpReq.Dimension == 0 { err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName", "dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")") @@ -1162,8 +1188,40 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe Name: httpReq.CollectionName, AutoID: httpReq.Schema.AutoId, Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, EnableDynamicField: httpReq.Schema.EnableDynamicField, } + + allOutputFields := []string{} + + for _, function := range httpReq.Schema.Functions { + functionTypeValue, ok := 
schemapb.FunctionType_value[function.FunctionType] + if !ok { + log.Ctx(ctx).Warn("function's data type is invalid(case sensitive).", zap.Any("function.DataType", function.FunctionType), zap.Any("function", function)) + err := merr.WrapErrParameterInvalid("FunctionType", function.FunctionType, "function data type is invalid(case sensitive)") + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrParameterInvalid), + HTTPReturnMessage: err.Error(), + }) + return nil, err + } + functionType := schemapb.FunctionType(functionTypeValue) + description := function.Description + params := []*commonpb.KeyValuePair{} + for key, value := range function.Params { + params = append(params, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + collSchema.Functions = append(collSchema.Functions, &schemapb.FunctionSchema{ + Name: function.FunctionName, + Description: description, + Type: functionType, + InputFieldNames: function.InputFieldNames, + OutputFieldNames: function.OutputFieldNames, + Params: params, + }) + allOutputFields = append(allOutputFields, function.OutputFieldNames...) 
+ } + for _, field := range httpReq.Schema.Fields { fieldDataType, ok := schemapb.DataType_value[field.DataType] if !ok { @@ -1218,6 +1276,9 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe for key, fieldParam := range field.ElementTypeParams { fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", fieldParam)}) } + if lo.Contains(allOutputFields, field.FieldName) { + fieldSchema.IsFunctionOutput = true + } collSchema.Fields = append(collSchema.Fields, &fieldSchema) fieldNames[field.FieldName] = true } diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index feee4f6969..67e876bada 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -57,6 +57,22 @@ func init() { paramtable.Init() } +func sendReqAndVerify(t *testing.T, testEngine *gin.Engine, testName, method string, testcase requestBodyTestCase) { + t.Run(testName, func(t *testing.T) { + req := httptest.NewRequest(method, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + if testcase.errCode != 0 { + assert.Contains(t, returnBody.Message, testcase.errMsg) + } + }) +} + func TestHTTPWrapper(t *testing.T) { postTestCases := []requestBodyTestCase{} postTestCasesTrace := []requestBodyTestCase{} @@ -468,6 +484,230 @@ func TestDatabaseWrapper(t *testing.T) { } } +func TestDocInDocOutCreateCollection(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer 
paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(1) + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(CollectionCategory, CreateAction) + + const baseRequestBody = `{ + "collectionName": "doc_in_doc_out_demo", + "schema": { + "autoId": false, + "enableDynamicField": false, + "fields": [ + { + "fieldName": "my_id", + "dataType": "Int64", + "isPrimary": true + }, + { + "fieldName": "document_content", + "dataType": "VarChar", + "elementTypeParams": { + "max_length": "9000" + } + }, + { + "fieldName": "sparse_vector_1", + "dataType": "SparseFloatVector" + } + ], + "functions": %s + } + }` + + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(fmt.Sprintf(baseRequestBody, `[ + { + "name": "bm25_fn_1", + "type": "BM25", + "inputFieldNames": ["document_content"], + "outputFieldNames": ["sparse_vector_1"] + } + ]`)), + }) + + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(fmt.Sprintf(baseRequestBody, `[ + { + "name": "bm25_fn_1", + "type": "BM25_", + "inputFieldNames": ["document_content"], + "outputFieldNames": ["sparse_vector_1"] + } + ]`)), + errMsg: "actual=BM25_", + errCode: 1100, + }) + + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(fmt.Sprintf(baseRequestBody, `[ + { + "name": "bm25_fn_1", + "inputFieldNames": ["document_content"], + "outputFieldNames": ["sparse_vector_1"] + } + ]`)), + errMsg: "actual=", // unprovided function type is empty string + errCode: 1100, + }) + + for _, testcase := range postTestCases { + sendReqAndVerify(t, testEngine, "post"+testcase.path, http.MethodPost, testcase) + } +} + +func TestDocInDocOutCreateCollectionQuickDisallowFunction(t *testing.T) { + paramtable.Init() + // 
disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + mp := mocks.NewMockProxy(t) + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(CollectionCategory, CreateAction) + + const baseRequestBody = `{ + "collectionName": "doc_in_doc_out_demo", + "dimension": 2, + "idType": "Varchar", + "schema": { + "autoId": false, + "enableDynamicField": false, + "functions": [ + { + "name": "bm25_fn_1", + "type": "BM25", + "inputFieldNames": ["document_content"], + "outputFieldNames": ["sparse_vector_1"] + } + ] + } + }` + + testcase := requestBodyTestCase{ + path: path, + requestBody: []byte(baseRequestBody), + errMsg: "functions are not supported for quickly create collection", + errCode: 1100, + } + + sendReqAndVerify(t, testEngine, "post"+testcase.path, http.MethodPost, testcase) +} + +func TestDocInDocOutDescribeCollection(t *testing.T) { + paramtable.Init() + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&DefaultLoadStateResp, nil).Once() + mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&DefaultDescIndexesReqp, nil).Once() + mp.EXPECT().ListAliases(mock.Anything, mock.Anything).Return(&milvuspb.ListAliasesResponse{ + Status: &StatusSuccess, + Aliases: []string{DefaultAliasName}, + }, nil).Once() + testEngine := initHTTPServerV2(mp, false) + testcase := requestBodyTestCase{ + path: versionalV2(CollectionCategory, DescribeAction), + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + } + sendReqAndVerify(t, testEngine, testcase.path, 
http.MethodPost, testcase) +} + +func TestDocInDocOutInsert(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + mp := mocks.NewMockProxy(t) + testEngine := initHTTPServerV2(mp, false) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once() + + testcase := requestBodyTestCase{ + path: versionalV2(EntityCategory, InsertAction), + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "varchar_field": "some text"}]}`), + } + + sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase) +} + +func TestDocInDocOutInsertInvalid(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + mp := mocks.NewMockProxy(t) + testEngine := initHTTPServerV2(mp, false) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + // invalid insert request, will not be sent to proxy + + testcase := requestBodyTestCase{ + path: versionalV2(EntityCategory, InsertAction), + 
requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": {"1": 0.1}, "varchar_field": "some text"}]}`), + errCode: 1804, + errMsg: "not allowed to provide input data for function output field", + } + + sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase) +} + +func TestDocInDocOutSearch(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + mp := mocks.NewMockProxy(t) + testEngine := initHTTPServerV2(mp, false) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateDocInDocOutCollectionSchema(schemapb.DataType_Int64), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{ + TopK: int64(3), + OutputFields: []string{FieldWordCount}, + FieldsData: generateFieldData(), + Ids: generateIDs(schemapb.DataType_Int64, 3), + Scores: DefaultScores, + }}, nil).Once() + + testcase := requestBodyTestCase{ + path: versionalV2(EntityCategory, SearchAction), + requestBody: []byte(`{"collectionName": "book", "data": ["query data"], "limit": 4, "outputFields": ["word_count"]}`), + } + + sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase) +} + func TestCreateCollection(t *testing.T) { paramtable.Init() // disable rate limit @@ -1054,7 +1294,6 @@ func TestMethodGet(t *testing.T) { if testcase.errCode != 0 { assert.Equal(t, testcase.errMsg, returnBody.Message) } - fmt.Println(w.Body.String()) }) } } diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 6f37730d8e..d73bcfca1e 
100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -324,10 +324,20 @@ type FieldSchema struct { DefaultValue interface{} `json:"defaultValue" binding:"required"` } +type FunctionSchema struct { + FunctionName string `json:"name" binding:"required"` + Description string `json:"description"` + FunctionType string `json:"type" binding:"required"` + InputFieldNames []string `json:"inputFieldNames" binding:"required"` + OutputFieldNames []string `json:"outputFieldNames" binding:"required"` + Params map[string]interface{} `json:"params"` +} + type CollectionSchema struct { - Fields []FieldSchema `json:"fields"` - AutoId bool `json:"autoID"` - EnableDynamicField bool `json:"enableDynamicField"` + Fields []FieldSchema `json:"fields"` + Functions []FunctionSchema `json:"functions"` + AutoId bool `json:"autoID"` + EnableDynamicField bool `json:"enableDynamicField"` } type CollectionReq struct { diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index b27314c9de..63ec81049d 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -147,52 +147,77 @@ func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result) // --------------------- collection details --------------------- // func printFields(fields []*schemapb.FieldSchema) []gin.H { - return printFieldDetails(fields, true) + var res []gin.H + for _, field := range fields { + fieldDetail := printFieldDetail(field, true) + res = append(res, fieldDetail) + } + return res } func printFieldsV2(fields []*schemapb.FieldSchema) []gin.H { - return printFieldDetails(fields, false) -} - -func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H { var res []gin.H for _, field := range fields { - fieldDetail := gin.H{ - HTTPReturnFieldName: field.Name, - HTTPReturnFieldPrimaryKey: field.IsPrimaryKey, - 
HTTPReturnFieldPartitionKey: field.IsPartitionKey, - HTTPReturnFieldAutoID: field.AutoID, - HTTPReturnDescription: field.Description, - } - if typeutil.IsVectorType(field.DataType) { - fieldDetail[HTTPReturnFieldType] = field.DataType.String() - if oldVersion { - dim, _ := getDim(field) - fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")" - } - } else if field.DataType == schemapb.DataType_VarChar { - fieldDetail[HTTPReturnFieldType] = field.DataType.String() - if oldVersion { - maxLength, _ := parameterutil.GetMaxLength(field) - fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")" - } - } else { - fieldDetail[HTTPReturnFieldType] = field.DataType.String() - } - if !oldVersion { - fieldDetail[HTTPReturnFieldID] = field.FieldID - if field.TypeParams != nil { - fieldDetail[Params] = field.TypeParams - } - if field.DataType == schemapb.DataType_Array { - fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String() - } - } + fieldDetail := printFieldDetail(field, false) res = append(res, fieldDetail) } return res } +func printFieldDetail(field *schemapb.FieldSchema, oldVersion bool) gin.H { + fieldDetail := gin.H{ + HTTPReturnFieldName: field.Name, + HTTPReturnFieldPrimaryKey: field.IsPrimaryKey, + HTTPReturnFieldPartitionKey: field.IsPartitionKey, + HTTPReturnFieldAutoID: field.AutoID, + HTTPReturnDescription: field.Description, + } + if field.GetIsFunctionOutput() { + fieldDetail[HTTPReturnFieldIsFunctionOutput] = true + } + if typeutil.IsVectorType(field.DataType) { + fieldDetail[HTTPReturnFieldType] = field.DataType.String() + if oldVersion { + dim, _ := getDim(field) + fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")" + } + } else if field.DataType == schemapb.DataType_VarChar { + fieldDetail[HTTPReturnFieldType] = field.DataType.String() + if oldVersion { + maxLength, _ := 
parameterutil.GetMaxLength(field) + fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")" + } + } else { + fieldDetail[HTTPReturnFieldType] = field.DataType.String() + } + if !oldVersion { + fieldDetail[HTTPReturnFieldID] = field.FieldID + if field.TypeParams != nil { + fieldDetail[Params] = field.TypeParams + } + if field.DataType == schemapb.DataType_Array { + fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String() + } + } + return fieldDetail +} + +func printFunctionDetails(functions []*schemapb.FunctionSchema) []gin.H { + var res []gin.H + for _, function := range functions { + res = append(res, gin.H{ + HTTPReturnFunctionName: function.Name, + HTTPReturnDescription: function.Description, + HTTPReturnFunctionType: function.Type, + HTTPReturnFunctionID: function.Id, + HTTPReturnFunctionInputFieldNames: function.InputFieldNames, + HTTPReturnFunctionOutputFieldNames: function.OutputFieldNames, + HTTPReturnFunctionParams: function.Params, + }) + } + return res +} + func getMetricType(pairs []*commonpb.KeyValuePair) string { metricType := DefaultMetricType for _, pair := range pairs { @@ -258,6 +283,14 @@ func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error, continue } + // if field is a function output field, user must not provide data for it + if field.GetIsFunctionOutput() { + if dataString != "" { + return merr.WrapErrParameterInvalid("", "not allowed to provide input data for function output field: "+fieldName), reallyDataArray, validDataMap + } + continue + } + switch fieldType { case schemapb.DataType_FloatVector: if dataString == "" { @@ -626,11 +659,16 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, nameColumns := make(map[string]interface{}) nameDims := make(map[string]int64) fieldData := make(map[string]*schemapb.FieldData) + for _, field := range sch.Fields { // skip auto id pk field if (field.IsPrimaryKey && field.AutoID) 
|| field.IsDynamic { continue } + // skip function output field + if field.GetIsFunctionOutput() { + continue + } var data interface{} switch field.DataType { case schemapb.DataType_Bool: @@ -685,8 +723,8 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, IsDynamic: field.IsDynamic, } } - if len(nameDims) == 0 { - return nil, fmt.Errorf("collection: %s has no vector field", sch.Name) + if len(nameDims) == 0 && len(sch.Functions) == 0 { + return nil, fmt.Errorf("collection: %s has no vector field or functions", sch.Name) } dynamicCol := make([][]byte, 0, rowsLen) @@ -709,6 +747,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, if (field.Nullable || field.DefaultValue != nil) && !ok { continue } + if field.GetIsFunctionOutput() { + if ok { + return nil, fmt.Errorf("row %d has data provided for function output field %s", idx, field.Name) + } + continue + } if !ok { return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name) } @@ -1035,7 +1079,7 @@ func serializeSparseFloatVectors(vectors []gjson.Result, dataType schemapb.DataT return values, nil } -func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) { +func convertQueries2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) { var valueType commonpb.PlaceholderType var values [][]byte var err error @@ -1055,6 +1099,12 @@ func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimensi case schemapb.DataType_SparseFloatVector: valueType = commonpb.PlaceholderType_SparseFloatVector values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType) + case schemapb.DataType_VarChar: + valueType = commonpb.PlaceholderType_VarChar + res := gjson.Get(body, HTTPRequestData).Array() + for _, v := range res { + values = append(values, []byte(v.String())) + } } if err != nil { 
return nil, err diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index d6070ed9aa..6ebdeb4d69 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -23,6 +23,7 @@ const ( FieldWordCount = "word_count" FieldBookID = "book_id" FieldBookIntro = "book_intro" + FieldVarchar = "varchar_field" ) var DefaultScores = []float32{0.01, 0.04, 0.09} @@ -74,17 +75,21 @@ func generateVectorFieldSchema(dataType schemapb.DataType) *schemapb.FieldSchema if dataType == schemapb.DataType_BinaryVector { dim = "8" } + typeParams := []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: dim, + }, + } + if dataType == schemapb.DataType_SparseFloatVector { + typeParams = nil + } return &schemapb.FieldSchema{ FieldID: common.StartOfUserFieldID + int64(dataType), IsPrimaryKey: false, DataType: dataType, AutoID: false, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: dim, - }, - }, + TypeParams: typeParams, } } @@ -110,6 +115,44 @@ func generateCollectionSchema(primaryDataType schemapb.DataType) *schemapb.Colle } } +func generateDocInDocOutCollectionSchema(primaryDataType schemapb.DataType) *schemapb.CollectionSchema { + primaryField := generatePrimaryField(primaryDataType) + vectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector) + vectorField.Name = FieldBookIntro + vectorField.IsFunctionOutput = true + return &schemapb.CollectionSchema{ + Name: DefaultCollectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + primaryField, { + FieldID: common.StartOfUserFieldID + 1, + Name: FieldWordCount, + IsPrimaryKey: false, + Description: "", + DataType: 5, + AutoID: false, + }, vectorField, { + FieldID: common.StartOfUserFieldID + 2, + Name: FieldVarchar, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_VarChar, + AutoID: false, + }, + }, + Functions: 
[]*schemapb.FunctionSchema{ + { + Name: "sum", + Type: schemapb.FunctionType_BM25, + InputFieldNames: []string{FieldVarchar}, + OutputFieldNames: []string{FieldBookIntro}, + }, + }, + EnableDynamicField: true, + } +} + func generateIndexes() []*milvuspb.IndexDescription { return []*milvuspb.IndexDescription{ { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 9bf90be527..f9728aa9a3 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -86,7 +86,7 @@ func (t *searchTask) CanSkipAllocTimestamp() bool { var consistencyLevel commonpb.ConsistencyLevel useDefaultConsistency := t.request.GetUseDefaultConsistency() if !useDefaultConsistency { - // legacy SDK & resultful behavior + // legacy SDK & restful behavior if t.request.GetConsistencyLevel() == commonpb.ConsistencyLevel_Strong && t.request.GetGuaranteeTimestamp() > 0 { return true } @@ -373,7 +373,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { internalSubReq.FieldId = queryInfo.GetQueryFieldId() // set PartitionIDs for sub search if t.partitionKeyMode { - // isolatioin has tighter constraint, check first + // isolation has tighter constraint, check first mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan) if mvErr != nil { return mvErr @@ -453,7 +453,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { t.SearchRequest.FieldId = queryInfo.GetQueryFieldId() if t.partitionKeyMode { - // isolatioin has tighter constraint, check first + // isolation has tighter constraint, check first mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan) if mvErr != nil { return mvErr diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 417f65bd75..524bce4fb5 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -296,7 +296,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest defer 
sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { - log.Warn("deletgator received search request not belongs to it", + log.Warn("delegator received search request not belongs to it", zap.Strings("reqChannels", req.GetDmlChannels()), ) return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())