diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 9980a41a19..545328aa27 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -553,7 +553,7 @@ func (h *HandlersV1) query(c *gin.Context) { } else { queryResp := response.(*milvuspb.QueryResults) allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) + outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS, nil) if err != nil { log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ @@ -632,7 +632,7 @@ func (h *HandlersV1) get(c *gin.Context) { } else { queryResp := response.(*milvuspb.QueryResults) allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) + outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS, nil) if err != nil { log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ @@ -1006,7 +1006,7 @@ func (h *HandlersV1) search(c *gin.Context) { HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) } else { allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(searchResp.Results.TopK, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) + outputData, err := buildQueryResp(searchResp.Results.TopK, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS, nil) if err != nil { log.Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 05550ddd1d..af08ac1a8f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -790,6 +790,11 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa QueryParams: []*commonpb.KeyValuePair{}, UseDefaultConsistency: true, } + var err error + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } req.ExprTemplateValues = generateExpressionTemplate(httpReq.ExprParams) c.Set(ContextRequest, req) if httpReq.Offset > 0 { @@ -804,7 +809,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa if err == nil { queryResp := resp.(*milvuspb.QueryResults) allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) + outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS, collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with query result", zap.Any("response", resp), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ @@ -852,7 +857,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName if err == nil { queryResp := resp.(*milvuspb.QueryResults) allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) + outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS, collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with get result", zap.Any("response", resp), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ @@ -1181,7 +1186,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) } else { allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) + outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS, collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ @@ -1286,7 +1291,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) } else { allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) - outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) + outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS, collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) HTTPReturn(c, http.StatusOK, gin.H{ diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 462da7ebe4..c75037b2be 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -2059,6 +2059,12 @@ func TestDML(t *testing.T) { paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateCollectionSchema(schemapb.DataType_Int64, false, true), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Times(3) mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ CollectionName: DefaultCollectionName, Schema: generateCollectionSchema(schemapb.DataType_Int64, false, true), diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 3e8e0826af..998db7a33a 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1244,7 +1244,9 @@ func genDynamicFields(fields []string, list []*schemapb.FieldData) []string { return dynamicFields } -func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemapb.FieldData, ids *schemapb.IDs, scores []float32, enableInt64 bool) ([]map[string]interface{}, error) { +func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemapb.FieldData, ids *schemapb.IDs, + scores []float32, enableInt64 bool, collectionSchema *schemapb.CollectionSchema, +) ([]map[string]interface{}, error) { columnNum := len(fieldDataList) if rowsNum == int64(0) { // always if columnNum > 0 { @@ -1302,6 +1304,17 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap } queryResp := make([]map[string]interface{}, 0, rowsNum) dynamicOutputFields := genDynamicFields(needFields, fieldDataList) + + pkFieldName := DefaultPrimaryFieldName + if collectionSchema != nil { + fieldsSchema := collectionSchema.GetFields() + for _, field := range fieldsSchema { + if field.GetIsPrimaryKey() { + pkFieldName = field.GetName() + break + } + } + } for i := int64(0); i < rowsNum; i++ { row := map[string]interface{}{} if columnNum > 0 { @@ -1420,13 +1433,13 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap case *schemapb.IDs_IntId: int64Pks := ids.GetIntId().GetData() if enableInt64 { - row[DefaultPrimaryFieldName] = int64Pks[i] + row[pkFieldName] = int64Pks[i] } else { - row[DefaultPrimaryFieldName] = strconv.FormatInt(int64Pks[i], 10) + row[pkFieldName] = strconv.FormatInt(int64Pks[i], 10) } case *schemapb.IDs_StrId: stringPks := ids.GetStrId().GetData() - row[DefaultPrimaryFieldName] = stringPks[i] + row[pkFieldName] = stringPks[i] default: return nil, errors.New("the type of primary key(id) is not supported, use other sdk please") } diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index ab4e923ccf..a8d2f0f79a 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1264,7 +1264,7 @@ func compareRows(row1 []map[string]interface{}, row2 []map[string]interface{}, c func TestBuildQueryResp(t *testing.T) { outputFields := []string{FieldBookID, FieldWordCount, "author", "date"} - rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3} + rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true, nil) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3} assert.Equal(t, nil, err) exceptRows := generateSearchResult(schemapb.DataType_Int64) assert.Equal(t, true, compareRows(rows, exceptRows, compareRow)) @@ -2163,7 +2163,7 @@ func TestBuildQueryResps(t *testing.T) { outputFields := []string{"XXX", "YYY"} outputFieldsList := [][]string{outputFields, {"$meta"}, {"$meta", FieldBookID, FieldBookIntro, "YYY"}} for _, theOutputFields := range outputFieldsList { - rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) + rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true, nil) assert.Equal(t, nil, err) exceptRows := newSearchResult(generateSearchResult(schemapb.DataType_Int64)) assert.Equal(t, true, compareRows(rows, exceptRows, compareRow)) @@ -2177,30 +2177,30 @@ func TestBuildQueryResps(t *testing.T) { schemapb.DataType_JSON, schemapb.DataType_Array, } for _, dateType := range dataTypes { - _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) + _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true, nil) assert.Equal(t, nil, err) } - _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) + _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true, nil) assert.Equal(t, "the type(1000) of field(wrong-field-type) is not supported, use other sdk please", err.Error()) - res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true) + res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, true, nil) assert.Equal(t, 3, len(res)) assert.Equal(t, nil, err) - res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false) + res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false, nil) assert.Equal(t, 3, len(res)) assert.Equal(t, nil, err) - res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_VarChar, 3), DefaultScores, true) + res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIDs(schemapb.DataType_VarChar, 3), DefaultScores, true, nil) assert.Equal(t, 3, len(res)) assert.Equal(t, nil, err) - _, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false) + _, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIDs(schemapb.DataType_Int64, 3), DefaultScores, false, nil) assert.Equal(t, nil, err) // len(rows) != len(scores), didn't show distance - _, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true) + _, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIDs(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true, nil) assert.Equal(t, nil, err) } diff --git a/tests/python_client/Dockerfile b/tests/python_client/Dockerfile index 24fad77b31..531067eab3 100644 --- a/tests/python_client/Dockerfile +++ b/tests/python_client/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-buster +FROM python:3.10-bullseye RUN apt-get update && apt-get install -y jq