diff --git a/internal/proxy/highlighter.go b/internal/proxy/highlighter.go index 4262579a31..f304c82bed 100644 --- a/internal/proxy/highlighter.go +++ b/internal/proxy/highlighter.go @@ -14,9 +14,12 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proxy/shardclient" + "github.com/milvus-io/milvus/internal/util/function/highlight" + "github.com/milvus-io/milvus/internal/util/function/models" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) const ( @@ -413,3 +416,58 @@ func buildStringFragments(task *highlightTask, idx int, frags []*querypb.Highlig } return result } + +type SemanticHighlighter struct { + highlight *highlight.SemanticHighlight +} + +func newSemanticHighlighter(t *searchTask, extraInfo *models.ModelExtraInfo) (*SemanticHighlighter, error) { + conf := paramtable.Get().FunctionCfg.ZillizProviders.GetValue() + highlight, err := highlight.NewSemanticHighlight(t.schema.CollectionSchema, t.request.GetHighlighter().GetParams(), conf, extraInfo) + if err != nil { + return nil, err + } + return &SemanticHighlighter{highlight: highlight}, nil +} + +func (h *SemanticHighlighter) FieldIDs() []int64 { + return h.highlight.FieldIDs() +} + +func (h *SemanticHighlighter) AsSearchPipelineOperator(t *searchTask) (operator, error) { + return &semanticHighlightOperator{highlight: h.highlight}, nil +} + +type semanticHighlightOperator struct { + highlight *highlight.SemanticHighlight +} + +func (op *semanticHighlightOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + result := inputs[0].(*milvuspb.SearchResults) + datas := result.Results.GetFieldsData() + if len(datas) == 0 { + return []any{result}, nil + } + highlightResults := []*commonpb.HighlightResult{} + for _, fieldID := range op.highlight.FieldIDs() { + fieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldId == fieldID }) + if !ok { + return nil, errors.Errorf("get highlight failed, text field not in output field %d", fieldID) + } + texts := fieldDatas.GetScalars().GetStringData().GetData() + highlights, err := op.highlight.Process(ctx, result.Results.GetTopks(), texts, nil) + if err != nil { + return nil, err + } + singeFieldHighlights := &commonpb.HighlightResult{ + FieldName: fieldDatas.FieldName, + Datas: make([]*commonpb.HighlightData, 0, len(highlights)), + } + for _, highlight := range highlights { + singeFieldHighlights.Datas = append(singeFieldHighlights.Datas, &commonpb.HighlightData{Fragments: highlight}) + } + highlightResults = append(highlightResults, singeFieldHighlights) + } + result.Results.HighlightResults = highlightResults + return []any{result}, nil +} diff --git a/internal/proxy/search_pipeline_test.go b/internal/proxy/search_pipeline_test.go index bd618fcdbe..54486655e1 100644 --- a/internal/proxy/search_pipeline_test.go +++ b/internal/proxy/search_pipeline_test.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proxy/shardclient" + "github.com/milvus-io/milvus/internal/util/function/highlight" "github.com/milvus-io/milvus/internal/util/function/models" "github.com/milvus-io/milvus/internal/util/function/rerank" "github.com/milvus-io/milvus/internal/util/segcore" @@ -362,6 +363,251 @@ func (s *SearchPipelineSuite) TestHighlightOp() { s.NoError(err) } +func (s *SearchPipelineSuite) TestSemanticHighlightOp() { + ctx := context.Background() + + // Mock SemanticHighlight methods + mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To( + func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) { + return [][]string{ + {"highlighted text 1"}, + {"highlighted text 2"}, + {"highlighted text 3"}, + }, nil + }).Build() + defer mockProcess.UnPatch() + + mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).To(func(h *highlight.SemanticHighlight) []int64 { + return []int64{101} + }).Build() + defer mockFieldIDs.UnPatch() + + // Create operator + op := &semanticHighlightOperator{ + highlight: &highlight.SemanticHighlight{}, + } + + // Create search results with text data + searchResults := &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 3, + Topks: []int64{3}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 101, + FieldName: testVarCharField, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"text 1", "text 2", "text 3"}, + }, + }, + }, + }, + }, + }, + }, + } + + // Run the operator + results, err := op.run(ctx, s.span, searchResults) + s.NoError(err) + s.NotNil(results) + s.Len(results, 1) + + // Verify results + result := results[0].(*milvuspb.SearchResults) + s.NotNil(result.Results.HighlightResults) + s.Len(result.Results.HighlightResults, 1) + + highlightResult := result.Results.HighlightResults[0] + s.Equal(testVarCharField, highlightResult.FieldName) + s.Len(highlightResult.Datas, 3) + s.Equal([]string{"highlighted text 1"}, highlightResult.Datas[0].Fragments) + s.Equal([]string{"highlighted text 2"}, highlightResult.Datas[1].Fragments) + s.Equal([]string{"highlighted text 3"}, highlightResult.Datas[2].Fragments) +} + +func (s *SearchPipelineSuite) TestSemanticHighlightOpMissingField() { + ctx := context.Background() + + // Mock FieldIDs to return field 999 (not in results) + mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{999}).Build() + defer mockFieldIDs.UnPatch() + + op := &semanticHighlightOperator{ + highlight: &highlight.SemanticHighlight{}, + } + + // Create search results without the expected field + searchResults := &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + Topks: []int64{1}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 101, + FieldName: testVarCharField, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"text 1"}, + }, + }, + }, + }, + }, + }, + }, + } + + // Run the operator and expect error + _, err := op.run(ctx, s.span, searchResults) + s.Error(err) + s.Contains(err.Error(), "text field not in output field") +} + +func (s *SearchPipelineSuite) TestSemanticHighlightOpMultipleFields() { + ctx := context.Background() + + // Use a counter to return different results for different calls + callCount := 0 + mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To( + func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) { + callCount++ + return [][]string{ + {fmt.Sprintf("highlighted text field%d-1", callCount)}, + {fmt.Sprintf("highlighted text field%d-2", callCount)}, + }, nil + }).Build() + defer mockProcess.UnPatch() + + mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{101, 102}).Build() + defer mockFieldIDs.UnPatch() + + op := &semanticHighlightOperator{ + highlight: &highlight.SemanticHighlight{}, + } + + // Create search results with multiple text fields + searchResults := &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 2, + Topks: []int64{2}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 101, + FieldName: "field1", + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"text 1", "text 2"}, + }, + }, + }, + }, + }, + { + FieldId: 102, + FieldName: "field2", + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"another text 1", "another text 2"}, + }, + }, + }, + }, + }, + }, + }, + } + + // Run the operator + results, err := op.run(ctx, s.span, searchResults) + s.NoError(err) + s.NotNil(results) + + // Verify results + result := results[0].(*milvuspb.SearchResults) + s.NotNil(result.Results.HighlightResults) + s.Len(result.Results.HighlightResults, 2) + + // Verify first field + s.Equal("field1", result.Results.HighlightResults[0].FieldName) + s.Len(result.Results.HighlightResults[0].Datas, 2) + + // Verify second field + s.Equal("field2", result.Results.HighlightResults[1].FieldName) + s.Len(result.Results.HighlightResults[1].Datas, 2) +} + +func (s *SearchPipelineSuite) TestSemanticHighlightOpEmptyResults() { + ctx := context.Background() + + // Mock Process to return empty results + mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To( + func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) { + return [][]string{}, nil + }).Build() + defer mockProcess.UnPatch() + + mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{101}).Build() + defer mockFieldIDs.UnPatch() + + op := &semanticHighlightOperator{ + highlight: &highlight.SemanticHighlight{}, + } + + // Create empty search results + searchResults := &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 0, + Topks: []int64{0}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 101, + FieldName: testVarCharField, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{}, + }, + }, + }, + }, + }, + }, + }, + } + + // Run the operator + results, err := op.run(ctx, s.span, searchResults) + s.NoError(err) + s.NotNil(results) + + // Verify results + result := results[0].(*milvuspb.SearchResults) + s.NotNil(result.Results.HighlightResults) + s.Len(result.Results.HighlightResults, 1) + s.Equal(testVarCharField, result.Results.HighlightResults[0].FieldName) + s.Len(result.Results.HighlightResults[0].Datas, 0) +} + func (s *SearchPipelineSuite) TestSearchPipeline() { collectionName := "test" task := &searchTask{ diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 70b82d76d1..6da6e22f5d 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -618,6 +618,13 @@ func (t *searchTask) addHighlightTask(highlighter *commonpb.Highlighter, metricT switch highlighter.GetType() { case commonpb.HighlightType_Lexical: return t.createLexicalHighlighter(highlighter, metricType, annsField, placeholder, analyzerName) + case commonpb.HighlightType_Semantic: + h, err := newSemanticHighlighter(t, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()}) + if err != nil { + return merr.WrapErrParameterInvalidMsg("Create SemanticHighlight failed: %v ", err) + } + t.highlighter = h + return nil default: return merr.WrapErrParameterInvalidMsg("unsupported highlight type: %v", highlighter.GetType()) } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index cf05a6d06f..4c15710b25 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -17,6 +17,7 @@ package proxy import ( "context" + "encoding/json" "fmt" "math" "strconv" @@ -43,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/function/embedding" + "github.com/milvus-io/milvus/internal/util/function/highlight" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/v2/common" @@ -4920,7 +4922,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { }, } - schemaInfo := newSchemaInfo(schema) + info := newSchemaInfo(schema) placeholder := &commonpb.PlaceholderGroup{ Placeholders: []*commonpb.PlaceholderValue{{ @@ -4934,7 +4936,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("lexical highlight success", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -4954,7 +4956,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("Lexical highlight with custom tags", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -4975,7 +4977,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("lexical highlight with wrong metric type", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, SearchRequest: &internalpb.SearchRequest{}, request: &milvuspb.SearchRequest{}, } @@ -4991,7 +4993,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("lexical highlight with invalid pre_tags type", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -5015,9 +5017,9 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { }, } - schemaInfo := newSchemaInfo(schemaWithoutBM25) + info := newSchemaInfo(schemaWithoutBM25) task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -5031,7 +5033,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("highlight without highlight search text", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -5045,7 +5047,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("highlight with invalid highlight search key", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -5059,7 +5061,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { t.Run("highlight with unknown type", func(t *testing.T) { task := &searchTask{ - schema: schemaInfo, + schema: info, } highlighter := &commonpb.Highlighter{ @@ -5070,4 +5072,23 @@ func TestSearchTask_AddHighlightTask(t *testing.T) { err := task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "") assert.Error(t, err) }) + + t.Run("semantic highlight success", func(t *testing.T) { + task := &searchTask{ + schema: info, + } + + queriesJSON, _ := json.Marshal([]string{"test_query"}) + inputFieldsJSON, _ := json.Marshal([]string{"text_field"}) + + highlighter := &commonpb.Highlighter{ + Type: commonpb.HighlightType_Semantic, + Params: []*commonpb.KeyValuePair{{Key: "queries", Value: string(queriesJSON)}, {Key: "input_fields", Value: string(inputFieldsJSON)}}, + } + + mockSemanticHighlight := mockey.Mock(highlight.NewSemanticHighlight).Return(&highlight.SemanticHighlight{}, nil).Build() + defer mockSemanticHighlight.UnPatch() + task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "") + require.NotNil(t, task.highlighter) + }) } diff --git a/internal/util/function/highlight/semantic_highlight.go b/internal/util/function/highlight/semantic_highlight.go new file mode 100644 index 0000000000..170cebac48 --- /dev/null +++ b/internal/util/function/highlight/semantic_highlight.go @@ -0,0 +1,146 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package highlight + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function/models" +) + +type semanticHighlightProvider interface { + highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) + maxBatch() int +} + +type baseSemanticHighlightProvider struct { + batchSize int +} + +func (provider *baseSemanticHighlightProvider) maxBatch() int { + return provider.batchSize +} + +type SemanticHighlight struct { + fieldIDs []int64 + provider semanticHighlightProvider + queries []string +} + +const ( + queryKeyName string = "queries" + inputFieldKeyName string = "input_fields" +) + +func NewSemanticHighlight(collSchema *schemapb.CollectionSchema, params []*commonpb.KeyValuePair, conf map[string]string, extraInfo *models.ModelExtraInfo) (*SemanticHighlight, error) { + queries := []string{} + inputField := []string{} + for _, param := range params { + switch param.Key { + case queryKeyName: + if err := json.Unmarshal([]byte(param.Value), &queries); err != nil { + return nil, fmt.Errorf("Parse queries failed, err: %v", err) + } + case inputFieldKeyName: + if err := json.Unmarshal([]byte(param.Value), &inputField); err != nil { + return nil, fmt.Errorf("Parse input_field failed, err: %v", err) + } + } + } + + if len(queries) == 0 { + return nil, fmt.Errorf("queries is required") + } + + if len(inputField) == 0 { + return nil, fmt.Errorf("input_field is required") + } + + fieldIDMap := make(map[string]*schemapb.FieldSchema) + for _, field := range collSchema.Fields { + fieldIDMap[field.Name] = field + } + + fieldIDs := []int64{} + for _, fieldName := range inputField { + field, ok := fieldIDMap[fieldName] + if !ok { + return nil, fmt.Errorf("input_field %s not found", fieldName) + } + if field.DataType != schemapb.DataType_VarChar && field.DataType != schemapb.DataType_Text { + return nil, fmt.Errorf("input_field %s is not a VarChar or Text field", fieldName) + } + + fieldIDs = append(fieldIDs, field.FieldID) + } + + // TODO: support other providers if have more providers + provider, err := newZillizHighlightProvider(params, conf, extraInfo) + if err != nil { + return nil, err + } + + return &SemanticHighlight{fieldIDs: fieldIDs, provider: provider, queries: queries}, nil +} + +func (highlight *SemanticHighlight) FieldIDs() []int64 { + return highlight.fieldIDs +} + +func (highlight *SemanticHighlight) processOneQuery(ctx context.Context, query string, data []string, params map[string]string) ([][]string, error) { + if len(data) == 0 { + return [][]string{}, nil + } + highlights, err := highlight.provider.highlight(ctx, query, data, params) + if err != nil { + return nil, err + } + if len(highlights) != len(data) { + return nil, fmt.Errorf("Highlights size must equal to data size, but got highlights size [%d], data size [%d]", len(highlights), len(data)) + } + return highlights, nil +} + +func (highlight *SemanticHighlight) Process(ctx context.Context, topks []int64, data []string, params map[string]string) ([][]string, error) { + nq := len(topks) + if len(highlight.queries) != nq { + return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", nq, len(highlight.queries), highlight.queries) + } + if len(data) == 0 { + return [][]string{}, nil + } + + highlights := make([][]string, 0, len(data)) + start := int64(0) + + for i, query := range highlight.queries { + size := topks[i] + singleHighlights, err := highlight.processOneQuery(ctx, query, data[start:start+size], params) + if err != nil { + return nil, err + } + highlights = append(highlights, singleHighlights...) + start += size + } + return highlights, nil +} diff --git a/internal/util/function/highlight/semantic_highlight_test.go b/internal/util/function/highlight/semantic_highlight_test.go new file mode 100644 index 0000000000..12a1cb1f5f --- /dev/null +++ b/internal/util/function/highlight/semantic_highlight_test.go @@ -0,0 +1,545 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package highlight + +import ( + "context" + "encoding/json" + "testing" + + "github.com/bytedance/mockey" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function/models" + "github.com/milvus-io/milvus/internal/util/function/models/zilliz" +) + +func TestSemanticHighlight(t *testing.T) { + suite.Run(t, new(SemanticHighlightSuite)) +} + +type SemanticHighlightSuite struct { + suite.Suite + schema *schemapb.CollectionSchema +} + +func (s *SemanticHighlightSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test_collection", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "title", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "content", DataType: schemapb.DataType_Text}, + {FieldID: 103, Name: "description", DataType: schemapb.DataType_VarChar}, + {FieldID: 104, Name: "embedding", DataType: schemapb.DataType_FloatVector}, + }, + } +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_Success() { + queries := []string{"machine learning", "artificial intelligence"} + inputFields := []string{"title", "content"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.NoError(err) + s.NotNil(highlight) + s.Equal([]int64{101, 102}, highlight.FieldIDs()) + s.Equal(queries, highlight.queries) +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_MissingQueries() { + inputFields := []string{"title"} + inputFieldsJSON, _ := json.Marshal(inputFields) + + params := []*commonpb.KeyValuePair{ + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.Error(err) + s.Nil(highlight) + s.Contains(err.Error(), "queries is required") +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_MissingInputFields() { + queries := []string{"machine learning"} + queriesJSON, _ := json.Marshal(queries) + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.Error(err) + s.Nil(highlight) + s.Contains(err.Error(), "input_field is required") +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidQueriesJSON() { + inputFields := []string{"title"} + inputFieldsJSON, _ := json.Marshal(inputFields) + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: "invalid json"}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.Error(err) + s.Nil(highlight) + s.Contains(err.Error(), "Parse queries failed") +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidInputFieldsJSON() { + queries := []string{"machine learning"} + queriesJSON, _ := json.Marshal(queries) + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: "invalid json"}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.Error(err) + s.Nil(highlight) + s.Contains(err.Error(), "Parse input_field failed") +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_FieldNotFound() { + queries := []string{"machine learning"} + inputFields := []string{"nonexistent_field"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.Error(err) + s.Nil(highlight) + s.Contains(err.Error(), "not found") +} + +func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidFieldType() { + queries := []string{"machine learning"} + inputFields := []string{"embedding"} // FloatVector, not VarChar or Text + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + + s.Error(err) + s.Nil(highlight) + s.Contains(err.Error(), "is not a VarChar or Text field") +} + +func (s *SemanticHighlightSuite) TestProcessOneQuery_Success() { + queries := []string{"machine learning"} + inputFields := []string{"title"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + expectedHighlights := [][]string{ + {"machine learning"}, + {"machine"}, + } + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) { + return expectedHighlights, nil + }).Build() + defer mock2.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + s.NoError(err) + + ctx := context.Background() + data := []string{"Machine learning is a subset of AI", "Machine learning is powerful"} + highlights, err := highlight.processOneQuery(ctx, "machine learning", data, nil) + + s.NoError(err) + s.Equal(expectedHighlights, highlights) +} + +func (s *SemanticHighlightSuite) TestProcessOneQuery_Error() { + queries := []string{"test query"} + inputFields := []string{"title"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + expectedError := errors.New("highlight service error") + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) { + return nil, expectedError + }).Build() + defer mock2.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + s.NoError(err) + + ctx := context.Background() + data := []string{"test document"} + highlights, err := highlight.processOneQuery(ctx, "test query", data, nil) + + s.Error(err) + s.Nil(highlights) + s.Equal(expectedError, err) +} + +func (s *SemanticHighlightSuite) TestProcess_Success() { + queries := []string{"machine learning", "deep learning"} + inputFields := []string{"title"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + expectedHighlights1 := [][]string{ + {"machine learning", "deep learning"}, + } + expectedHighlights2 := [][]string{ + {"deep learning", "machine learning"}, + } + + callCount := 0 + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, query string, _ []string, _ map[string]string) ([][]string, error) { + callCount++ + if query == "machine learning" { + return expectedHighlights1, nil + } + return expectedHighlights2, nil + }).Build() + defer mock2.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + s.NoError(err) + + ctx := context.Background() + data := []string{"Machine learning document", "Deep learning document"} + highlights, err := highlight.Process(ctx, []int64{1, 1}, data, nil) + + s.NoError(err) + s.NotNil(highlights) + s.Equal(2, callCount, "Should call highlight twice for two queries") +} + +func (s *SemanticHighlightSuite) TestProcess_NqMismatch() { + queries := []string{"machine learning"} + inputFields := []string{"title"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + s.NoError(err) + + ctx := context.Background() + data := []string{"test document"} + highlights, err := highlight.Process(ctx, []int64{1, 1, 1}, data, nil) // nq=3 but queries has only 1 + + s.Error(err) + s.Nil(highlights) + s.Contains(err.Error(), "nq must equal to queries size") +} + +func (s *SemanticHighlightSuite) TestProcess_ProviderError() { + queries := []string{"test query"} + inputFields := []string{"title"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + expectedError := errors.New("provider error") + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) { + return nil, expectedError + }).Build() + defer mock2.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + s.NoError(err) + + ctx := context.Background() + data := []string{"test document"} + highlights, err := highlight.Process(ctx, []int64{1}, data, nil) + + s.Error(err) + s.Nil(highlights) + s.Equal(expectedError, err) +} + +func (s *SemanticHighlightSuite) TestProcess_EmptyData() { + queries := []string{"test query", "test query 2", "test query 3"} + inputFields := []string{"title"} + + queriesJSON, _ := json.Marshal(queries) + inputFieldsJSON, _ := json.Marshal(inputFields) + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, texts []string, _ map[string]string) ([][]string, error) { + return [][]string{texts}, nil + }).Build() + defer mock2.UnPatch() + + params := []*commonpb.KeyValuePair{ + {Key: queryKeyName, Value: string(queriesJSON)}, + {Key: inputFieldKeyName, Value: string(inputFieldsJSON)}, + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + } + + conf := map[string]string{ + "endpoint": "localhost:8080", + } + + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo) + s.NoError(err) + + ctx := context.Background() + data := []string{} + highlights, err := highlight.Process(ctx, []int64{0, 0, 0}, data, nil) + + s.NoError(err) + s.NotNil(highlights) + + data2 := []string{"test document"} + + highlights2, err := highlight.Process(ctx, []int64{0, 1, 0}, data2, nil) + + s.NoError(err) + s.Equal(1, len(highlights2)) + s.Equal([][]string{{"test document"}}, highlights2) +} + +func (s *SemanticHighlightSuite) TestBaseSemanticHighlightProvider_MaxBatch() { + provider := &baseSemanticHighlightProvider{batchSize: 128} + s.Equal(128, provider.maxBatch()) + + provider2 := &baseSemanticHighlightProvider{batchSize: 32} + s.Equal(32, provider2.maxBatch()) +} diff --git a/internal/util/function/highlight/zilliz_highlight_provider.go b/internal/util/function/highlight/zilliz_highlight_provider.go new file mode 100644 index 0000000000..0aa29b212d --- /dev/null +++ b/internal/util/function/highlight/zilliz_highlight_provider.go @@ -0,0 +1,78 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package highlight + +import ( + "context" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/util/function/models" + "github.com/milvus-io/milvus/internal/util/function/models/zilliz" +) + +type zillizHighlightProvider struct { + baseSemanticHighlightProvider + client *zilliz.ZillizClient + modelName string + queries []string + + modelParams map[string]string +} + +func newZillizHighlightProvider(params []*commonpb.KeyValuePair, conf map[string]string, extraInfo *models.ModelExtraInfo) (semanticHighlightProvider, error) { + var modelDeploymentID string + var err error + maxBatch := 64 + modelParams := map[string]string{} + for _, param := range params { + switch strings.ToLower(param.Key) { + case models.ModelDeploymentIDKey: + modelDeploymentID = param.Value + case models.MaxClientBatchSizeParamKey: + if maxBatch, err = strconv.Atoi(param.Value); err != nil { + return nil, err + } + + default: + modelParams[param.Key] = param.Value + } + } + + c, err := zilliz.NewZilliClient(modelDeploymentID, extraInfo.ClusterID, extraInfo.DBName, conf) + if err != nil { + return nil, err + } + + provider := zillizHighlightProvider{ + baseSemanticHighlightProvider: baseSemanticHighlightProvider{batchSize: maxBatch}, + client: c, + modelParams: modelParams, + } + return &provider, nil +} + +func (h *zillizHighlightProvider) highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) { + highlights, err := h.client.Highlight(ctx, query, texts, params) + if err != nil { + return nil, err + } + return highlights, nil +} diff --git a/internal/util/function/highlight/zilliz_highlight_provider_test.go b/internal/util/function/highlight/zilliz_highlight_provider_test.go new file mode 100644 index 0000000000..d85e4e8290 --- /dev/null +++ b/internal/util/function/highlight/zilliz_highlight_provider_test.go @@ -0,0 +1,240 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package highlight + +import ( + "context" + "testing" + + "github.com/bytedance/mockey" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/util/function/models" + "github.com/milvus-io/milvus/internal/util/function/models/zilliz" +) + +func TestZillizHighlightProvider(t *testing.T) { + suite.Run(t, new(ZillizHighlightProviderSuite)) +} + +type ZillizHighlightProviderSuite struct { + suite.Suite +} + +func (s *ZillizHighlightProviderSuite) TestNewZillizHighlightProvider_Success() { + tests := []struct { + name string + params []*commonpb.KeyValuePair + conf map[string]string + extraInfo *models.ModelExtraInfo + expectedBatch int + expectedParams map[string]string + }{ + { + name: "basic configuration", + params: []*commonpb.KeyValuePair{ + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + }, + conf: map[string]string{ + "endpoint": "localhost:8080", + }, + extraInfo: &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + }, + expectedBatch: 64, // default batch size + expectedParams: map[string]string{}, + }, + { + name: "with custom batch size", + params: []*commonpb.KeyValuePair{ + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + {Key: models.MaxClientBatchSizeParamKey, Value: "32"}, + }, + conf: map[string]string{ + "endpoint": "localhost:8080", + }, + extraInfo: &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + }, + expectedBatch: 32, + expectedParams: map[string]string{}, + }, + { + name: "with additional model parameters", + params: []*commonpb.KeyValuePair{ + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + {Key: models.MaxClientBatchSizeParamKey, Value: "16"}, + {Key: "threshold", Value: "0.7"}, + }, + conf: map[string]string{ + "endpoint": "localhost:8080", + }, + extraInfo: &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + }, + expectedBatch: 16, + expectedParams: map[string]string{ + "threshold": "0.7", + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + // Note: This test will fail with actual connection since we can't easily mock the global client manager + // But we can test the parameter parsing logic by checking the error message + provider, err := newZillizHighlightProvider(tt.params, tt.conf, tt.extraInfo) + + // Since we can't easily mock the zilliz client creation, we expect a connection error + // but we can verify that the parameters were parsed correctly by checking the error doesn't relate to parameter parsing + if err != nil { + // Connection errors are expected in unit tests + s.Contains(err.Error(), "Connect model serving failed", "Expected connection error, got: %v", err) + } else { + // If somehow the connection succeeds, verify the provider was created correctly + s.NotNil(provider) + zillizProvider, ok := provider.(*zillizHighlightProvider) + s.True(ok) + s.Equal(tt.expectedBatch, zillizProvider.maxBatch()) + s.Equal(tt.expectedParams, zillizProvider.modelParams) + } + }) + } +} + +func (s *ZillizHighlightProviderSuite) TestNewZillizHighlightProvider_InvalidBatchSize() { + tests := []struct { + name string + batchValue string + expectedErr string + }{ + { + name: "invalid number format", + batchValue: "not_a_number", + expectedErr: "invalid syntax", + }, + { + name: "empty batch size", + batchValue: "", + expectedErr: "invalid syntax", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + params := []*commonpb.KeyValuePair{ + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + {Key: models.MaxClientBatchSizeParamKey, Value: tt.batchValue}, + } + conf := map[string]string{ + "endpoint": "localhost:8080", + } + extraInfo := &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + } + + provider, err := newZillizHighlightProvider(params, conf, extraInfo) + s.Error(err) + s.Nil(provider) + s.Contains(err.Error(), tt.expectedErr) + }) + } +} + +func (s *ZillizHighlightProviderSuite) TestZillizHighlightProvider_Highlight_Success() { + // Set up expected behavior + ctx := context.Background() + query := "machine learning" + texts := []string{ + "Machine learning is a subset of artificial intelligence.", + "Deep learning is a type of machine learning.", + "Natural language processing uses machine learning techniques.", + } + params := map[string]string{"param1": "value1"} + expectedHighlights := [][]string{ + {"Machine learning", "artificial intelligence"}, + {"Deep learning", "machine learning"}, + {"Natural language processing", "machine learning"}, + } + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) { + return expectedHighlights, nil + }).Build() + defer mock2.UnPatch() + + provider, err := newZillizHighlightProvider([]*commonpb.KeyValuePair{ + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + }, map[string]string{"endpoint": "localhost:8080"}, &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + }) + s.NoError(err) + s.NotNil(provider) + + highlights, err := provider.highlight(ctx, query, texts, params) + + s.NoError(err) + s.Equal(expectedHighlights, highlights) +} + +func (s *ZillizHighlightProviderSuite) TestZillizHighlightProvider_Highlight_Error() { + // Set up expected behavior with error + ctx := context.Background() + query := "test query" + texts := []string{"doc1", "doc2", "doc3"} + params := map[string]string{"param1": "value1"} + expectedError := errors.New("highlight service error") + + mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) { + return &zilliz.ZillizClient{}, nil + }).Build() + defer mock1.UnPatch() + + mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) { + return nil, expectedError + }).Build() + defer mock2.UnPatch() + + provider, err := newZillizHighlightProvider([]*commonpb.KeyValuePair{ + {Key: models.ModelDeploymentIDKey, Value: "test-deployment"}, + }, map[string]string{"endpoint": "localhost:8080"}, &models.ModelExtraInfo{ + ClusterID: "test-cluster", + DBName: "test-db", + }) + s.NoError(err) + s.NotNil(provider) + + // Test the highlight method + highlights, err := provider.highlight(ctx, query, texts, params) + + s.Error(err) + s.Nil(highlights) + s.Equal(expectedError, err) +} diff --git a/internal/util/function/models/zilliz/zilliz_client.go b/internal/util/function/models/zilliz/zilliz_client.go index 4d9ef78fef..d3d1b74d93 100644 --- a/internal/util/function/models/zilliz/zilliz_client.go +++ b/internal/util/function/models/zilliz/zilliz_client.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/connectivity" @@ -33,6 +34,7 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" + "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/modelservicepb" ) @@ -93,7 +95,10 @@ func (m *clientManager) GetConn(clientConf *clientConfig) (*grpc.ClientConn, err if m.conn != nil { if clientConf.endpoint != m.config.endpoint { - _ = m.conn.Close() + err := m.conn.Close() + if err != nil { + log.Warn("Close connect failed", zap.String("endpoint", m.config.endpoint), zap.Error(err)) + } m.conn = nil } else { state := m.conn.GetState() @@ -247,3 +252,22 @@ func (c *ZillizClient) Rerank(ctx context.Context, query string, texts []string, } return res.Scores, nil } + +func (c *ZillizClient) Highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) { + stub := modelservicepb.NewHighlightServiceClient(c.conn) + req := &modelservicepb.HighlightRequest{ + Query: query, + Documents: texts, + Params: params, + } + ctx = c.setMeta(ctx) + res, err := stub.Highlight(ctx, req) + if err != nil { + return nil, err + } + highlights := make([][]string, 0, len(res.GetResults())) + for _, ret := range res.GetResults() { + highlights = append(highlights, ret.GetSentences()) + } + return highlights, nil +} diff --git a/internal/util/function/models/zilliz/zilliz_client_test.go b/internal/util/function/models/zilliz/zilliz_client_test.go index d9dc1f2af5..089f6ddba5 100644 --- a/internal/util/function/models/zilliz/zilliz_client_test.go +++ b/internal/util/function/models/zilliz/zilliz_client_test.go @@ -63,6 +63,19 @@ func (m *mockRerankServer) Rerank(ctx context.Context, req *modelservicepb.TextR return m.response, nil } +type mockHighlightServer struct { + modelservicepb.UnimplementedHighlightServiceServer + response *modelservicepb.HighlightResponse + err error +} + +func (m *mockHighlightServer) Highlight(ctx context.Context, req *modelservicepb.HighlightRequest) (*modelservicepb.HighlightResponse, error) { + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + func TestLoadConfig(t *testing.T) { tests := []struct { name string @@ -559,3 +572,103 @@ func TestZillizClient_Embedding_EmptyResponse(t *testing.T) { assert.NoError(t, err) assert.Empty(t, embeddings) } + +func TestZillizClient_Highlight(t *testing.T) { + // Setup mock server + s, lis, dialer := setupMockServer(t) + defer lis.Close() + defer s.Stop() + + mockServer := &mockHighlightServer{ + response: &modelservicepb.HighlightResponse{ + Status: &modelservicepb.Status{Code: 0, Msg: "success"}, + Results: []*modelservicepb.HighlightResult{ + { + Sentences: []string{"highlight1", "highlight2"}, + }, + { + Sentences: []string{"highlight3", "highlight4"}, + }, + }, + }, + } + + modelservicepb.RegisterHighlightServiceServer(s, mockServer) + + go func() { + if err := s.Serve(lis); err != nil { + fmt.Printf("Server exited with error: %v\n", err) + } + }() + + // Create connection + conn, err := grpc.DialContext( + context.Background(), + "bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) + require.NoError(t, err) + defer conn.Close() + + client := &ZillizClient{ + modelDeploymentID: "test-deployment", + clusterID: "test-cluster", + conn: conn, + } + + // Test successful highlight + ctx := context.Background() + query := "test query" + texts := []string{"doc1", "doc2", "doc3"} + params := map[string]string{"param1": "value1"} + highlights, err := client.Highlight(ctx, query, texts, params) + assert.NoError(t, err) + assert.Equal(t, [][]string{{"highlight1", "highlight2"}, {"highlight3", "highlight4"}}, highlights) +} + +func TestZillizClient_Highlight_Error(t *testing.T) { + // Setup mock server with error + s, lis, dialer := setupMockServer(t) + defer lis.Close() + defer s.Stop() + + mockServer := &mockHighlightServer{ + err: assert.AnError, + } + + modelservicepb.RegisterHighlightServiceServer(s, mockServer) + + go func() { + if err := s.Serve(lis); err != nil { + fmt.Printf("Server exited with error: %v\n", err) + } + }() + + // Create connection + conn, err := grpc.DialContext( + context.Background(), + "bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) + require.NoError(t, err) + defer conn.Close() + + client := &ZillizClient{ + modelDeploymentID: "test-deployment", + clusterID: "test-cluster", + conn: conn, + } + + // Test highlight with error + ctx := context.Background() + query := "test query" + texts := []string{"doc1", "doc2", "doc3"} + params := map[string]string{"param1": "value1"} + highlights, err := client.Highlight(ctx, query, texts, params) + assert.Error(t, err) + assert.Nil(t, highlights) +} diff --git a/pkg/proto/model_service.proto b/pkg/proto/model_service.proto index a3001b8af6..e1ee993079 100644 --- a/pkg/proto/model_service.proto +++ b/pkg/proto/model_service.proto @@ -11,6 +11,10 @@ service RerankService { rpc Rerank(TextRerankRequest) returns (TextRerankResponse); } +service HighlightService { + rpc Highlight(HighlightRequest) returns (HighlightResponse); +} + message Status { int32 code = 1; string msg = 2; @@ -68,3 +72,20 @@ message TextRerankResponse { map extra_info = 2; repeated float scores = 3; } + +message HighlightRequest { + string model = 1; + string query = 2; + repeated string documents = 3; + map params = 4; +} + +message HighlightResult { + repeated string sentences = 1; +} + +message HighlightResponse { + Status status = 1; + map extra_info = 2; + repeated HighlightResult results = 3; +} diff --git a/pkg/proto/modelservicepb/model_service.pb.go b/pkg/proto/modelservicepb/model_service.pb.go index 2d26b7adbd..b58990e314 100644 --- a/pkg/proto/modelservicepb/model_service.pb.go +++ b/pkg/proto/modelservicepb/model_service.pb.go @@ -620,6 +620,187 @@ func (x *TextRerankResponse) GetScores() []float32 { return nil } +type HighlightRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Model string `protobuf:"bytes,1,opt,name=model,proto3" json:"model,omitempty"` + Query string `protobuf:"bytes,2,opt,name=query,proto3" json:"query,omitempty"` + Documents []string `protobuf:"bytes,3,rep,name=documents,proto3" json:"documents,omitempty"` + Params map[string]string `protobuf:"bytes,4,rep,name=params,proto3" json:"params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (x *HighlightRequest) Reset() { + *x = HighlightRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_model_service_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HighlightRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HighlightRequest) ProtoMessage() {} + +func (x *HighlightRequest) ProtoReflect() protoreflect.Message { + mi := &file_model_service_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HighlightRequest.ProtoReflect.Descriptor instead. +func (*HighlightRequest) Descriptor() ([]byte, []int) { + return file_model_service_proto_rawDescGZIP(), []int{9} +} + +func (x *HighlightRequest) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *HighlightRequest) GetQuery() string { + if x != nil { + return x.Query + } + return "" +} + +func (x *HighlightRequest) GetDocuments() []string { + if x != nil { + return x.Documents + } + return nil +} + +func (x *HighlightRequest) GetParams() map[string]string { + if x != nil { + return x.Params + } + return nil +} + +type HighlightResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Sentences []string `protobuf:"bytes,1,rep,name=sentences,proto3" json:"sentences,omitempty"` +} + +func (x *HighlightResult) Reset() { + *x = HighlightResult{} + if protoimpl.UnsafeEnabled { + mi := &file_model_service_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HighlightResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HighlightResult) ProtoMessage() {} + +func (x *HighlightResult) ProtoReflect() protoreflect.Message { + mi := &file_model_service_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HighlightResult.ProtoReflect.Descriptor instead. +func (*HighlightResult) Descriptor() ([]byte, []int) { + return file_model_service_proto_rawDescGZIP(), []int{10} +} + +func (x *HighlightResult) GetSentences() []string { + if x != nil { + return x.Sentences + } + return nil +} + +type HighlightResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status *Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` + ExtraInfo map[string]string `protobuf:"bytes,2,rep,name=extra_info,json=extraInfo,proto3" json:"extra_info,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Results []*HighlightResult `protobuf:"bytes,3,rep,name=results,proto3" json:"results,omitempty"` +} + +func (x *HighlightResponse) Reset() { + *x = HighlightResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_model_service_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HighlightResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HighlightResponse) ProtoMessage() {} + +func (x *HighlightResponse) ProtoReflect() protoreflect.Message { + mi := &file_model_service_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HighlightResponse.ProtoReflect.Descriptor instead. +func (*HighlightResponse) Descriptor() ([]byte, []int) { + return file_model_service_proto_rawDescGZIP(), []int{11} +} + +func (x *HighlightResponse) GetStatus() *Status { + if x != nil { + return x.Status + } + return nil +} + +func (x *HighlightResponse) GetExtraInfo() map[string]string { + if x != nil { + return x.ExtraInfo + } + return nil +} + +func (x *HighlightResponse) GetResults() []*HighlightResult { + if x != nil { + return x.Results + } + return nil +} + var File_model_service_proto protoreflect.FileDescriptor var file_model_service_proto_rawDesc = []byte{ @@ -729,27 +910,72 @@ var file_model_service_proto_rawDesc = []byte{ 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e, 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0x86, 0x01, 0x0a, 0x14, 0x54, 0x65, - 0x78, 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x12, 0x6e, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, - 0x2f, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, - 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, - 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x30, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, - 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, - 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x32, 0x76, 0x0a, 0x0d, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x12, 0x65, 0x0a, 0x06, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x12, 0x2c, 0x2e, - 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, - 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, - 0x72, 0x61, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x6d, 0x69, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xe8, 0x01, 0x0a, 0x10, 0x48, 0x69, + 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, + 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x14, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x6f, + 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x64, + 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x4f, 0x0a, 0x06, 0x70, 0x61, 0x72, 0x61, + 0x6d, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, + 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x52, 0x06, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a, 0x39, 0x0a, 0x0b, 0x50, 0x61, 0x72, + 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2f, 0x0a, 0x0f, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, + 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x65, 0x6e, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x6e, 0x74, + 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0xae, 0x02, 0x0a, 0x11, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, + 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, - 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, - 0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x39, 0x5a, 0x37, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, - 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, - 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, + 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x5a, 0x0a, 0x0a, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f, + 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x3b, 0x2e, 0x6d, 0x69, 0x6c, + 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x45, 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e, + 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e, + 0x66, 0x6f, 0x12, 0x44, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, + 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, + 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x1a, 0x3c, 0x0a, 0x0e, 0x45, 0x78, 0x74, 0x72, + 0x61, 0x49, 0x6e, 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0x86, 0x01, 0x0a, 0x14, 0x54, 0x65, 0x78, 0x74, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, + 0x6e, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x2f, 0x2e, 0x6d, + 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, + 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30, 0x2e, + 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, + 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x45, 0x6d, + 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, + 0x76, 0x0a, 0x0d, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x65, 0x0a, 0x06, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x12, 0x2c, 0x2e, 0x6d, 0x69, 0x6c, + 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, 0x6e, + 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, + 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x7a, 0x0a, 0x10, 0x48, 0x69, 0x67, 0x68, 0x6c, + 0x69, 0x67, 0x68, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x66, 0x0a, 0x09, 0x48, + 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x12, 0x2b, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, + 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2c, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x42, 0x39, 0x5a, 0x37, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, + 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, + 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x70, 0x62, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -765,7 +991,7 @@ func file_model_service_proto_rawDescGZIP() []byte { } var file_model_service_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_model_service_proto_msgTypes = make([]protoimpl.MessageInfo, 13) +var file_model_service_proto_msgTypes = make([]protoimpl.MessageInfo, 18) var file_model_service_proto_goTypes = []interface{}{ (DenseVector_DType)(0), // 0: milvus.proto.modelservice.DenseVector.DType (*Status)(nil), // 1: milvus.proto.modelservice.Status @@ -777,33 +1003,44 @@ var file_model_service_proto_goTypes = []interface{}{ (*TextEmbeddingResponse)(nil), // 7: milvus.proto.modelservice.TextEmbeddingResponse (*TextRerankRequest)(nil), // 8: milvus.proto.modelservice.TextRerankRequest (*TextRerankResponse)(nil), // 9: milvus.proto.modelservice.TextRerankResponse - nil, // 10: milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry - nil, // 11: milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry - nil, // 12: milvus.proto.modelservice.TextRerankRequest.ParamsEntry - nil, // 13: milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry + (*HighlightRequest)(nil), // 10: milvus.proto.modelservice.HighlightRequest + (*HighlightResult)(nil), // 11: milvus.proto.modelservice.HighlightResult + (*HighlightResponse)(nil), // 12: milvus.proto.modelservice.HighlightResponse + nil, // 13: milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry + nil, // 14: milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry + nil, // 15: milvus.proto.modelservice.TextRerankRequest.ParamsEntry + nil, // 16: milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry + nil, // 17: milvus.proto.modelservice.HighlightRequest.ParamsEntry + nil, // 18: milvus.proto.modelservice.HighlightResponse.ExtraInfoEntry } var file_model_service_proto_depIdxs = []int32{ 0, // 0: milvus.proto.modelservice.DenseVector.dtype:type_name -> milvus.proto.modelservice.DenseVector.DType 2, // 1: milvus.proto.modelservice.MultiVector.token_vectors:type_name -> milvus.proto.modelservice.DenseVector - 10, // 2: milvus.proto.modelservice.TextEmbeddingRequest.params:type_name -> milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry + 13, // 2: milvus.proto.modelservice.TextEmbeddingRequest.params:type_name -> milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry 2, // 3: milvus.proto.modelservice.EmbeddingResult.dense:type_name -> milvus.proto.modelservice.DenseVector 3, // 4: milvus.proto.modelservice.EmbeddingResult.sparse:type_name -> milvus.proto.modelservice.SparseVector 4, // 5: milvus.proto.modelservice.EmbeddingResult.multi_vector:type_name -> milvus.proto.modelservice.MultiVector 1, // 6: milvus.proto.modelservice.TextEmbeddingResponse.status:type_name -> milvus.proto.modelservice.Status - 11, // 7: milvus.proto.modelservice.TextEmbeddingResponse.extra_info:type_name -> milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry + 14, // 7: milvus.proto.modelservice.TextEmbeddingResponse.extra_info:type_name -> milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry 6, // 8: milvus.proto.modelservice.TextEmbeddingResponse.results:type_name -> milvus.proto.modelservice.EmbeddingResult - 12, // 9: milvus.proto.modelservice.TextRerankRequest.params:type_name -> milvus.proto.modelservice.TextRerankRequest.ParamsEntry + 15, // 9: milvus.proto.modelservice.TextRerankRequest.params:type_name -> milvus.proto.modelservice.TextRerankRequest.ParamsEntry 1, // 10: milvus.proto.modelservice.TextRerankResponse.status:type_name -> milvus.proto.modelservice.Status - 13, // 11: milvus.proto.modelservice.TextRerankResponse.extra_info:type_name -> milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry - 5, // 12: milvus.proto.modelservice.TextEmbeddingService.Embedding:input_type -> milvus.proto.modelservice.TextEmbeddingRequest - 8, // 13: milvus.proto.modelservice.RerankService.Rerank:input_type -> milvus.proto.modelservice.TextRerankRequest - 7, // 14: milvus.proto.modelservice.TextEmbeddingService.Embedding:output_type -> milvus.proto.modelservice.TextEmbeddingResponse - 9, // 15: milvus.proto.modelservice.RerankService.Rerank:output_type -> milvus.proto.modelservice.TextRerankResponse - 14, // [14:16] is the sub-list for method output_type - 12, // [12:14] is the sub-list for method input_type - 12, // [12:12] is the sub-list for extension type_name - 12, // [12:12] is the sub-list for extension extendee - 0, // [0:12] is the sub-list for field type_name + 16, // 11: milvus.proto.modelservice.TextRerankResponse.extra_info:type_name -> milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry + 17, // 12: milvus.proto.modelservice.HighlightRequest.params:type_name -> milvus.proto.modelservice.HighlightRequest.ParamsEntry + 1, // 13: milvus.proto.modelservice.HighlightResponse.status:type_name -> milvus.proto.modelservice.Status + 18, // 14: milvus.proto.modelservice.HighlightResponse.extra_info:type_name -> milvus.proto.modelservice.HighlightResponse.ExtraInfoEntry + 11, // 15: milvus.proto.modelservice.HighlightResponse.results:type_name -> milvus.proto.modelservice.HighlightResult + 5, // 16: milvus.proto.modelservice.TextEmbeddingService.Embedding:input_type -> milvus.proto.modelservice.TextEmbeddingRequest + 8, // 17: milvus.proto.modelservice.RerankService.Rerank:input_type -> milvus.proto.modelservice.TextRerankRequest + 10, // 18: milvus.proto.modelservice.HighlightService.Highlight:input_type -> milvus.proto.modelservice.HighlightRequest + 7, // 19: milvus.proto.modelservice.TextEmbeddingService.Embedding:output_type -> milvus.proto.modelservice.TextEmbeddingResponse + 9, // 20: milvus.proto.modelservice.RerankService.Rerank:output_type -> milvus.proto.modelservice.TextRerankResponse + 12, // 21: milvus.proto.modelservice.HighlightService.Highlight:output_type -> milvus.proto.modelservice.HighlightResponse + 19, // [19:22] is the sub-list for method output_type + 16, // [16:19] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_model_service_proto_init() } @@ -920,6 +1157,42 @@ func file_model_service_proto_init() { return nil } } + file_model_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HighlightRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_model_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HighlightResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_model_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HighlightResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -927,9 +1200,9 @@ func file_model_service_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_model_service_proto_rawDesc, NumEnums: 1, - NumMessages: 13, + NumMessages: 18, NumExtensions: 0, - NumServices: 2, + NumServices: 3, }, GoTypes: file_model_service_proto_goTypes, DependencyIndexes: file_model_service_proto_depIdxs, diff --git a/pkg/proto/modelservicepb/model_service_grpc.pb.go b/pkg/proto/modelservicepb/model_service_grpc.pb.go index 1a22d3082e..17e3a480c6 100644 --- a/pkg/proto/modelservicepb/model_service_grpc.pb.go +++ b/pkg/proto/modelservicepb/model_service_grpc.pb.go @@ -193,3 +193,91 @@ var RerankService_ServiceDesc = grpc.ServiceDesc{ Streams: []grpc.StreamDesc{}, Metadata: "model_service.proto", } + +const ( + HighlightService_Highlight_FullMethodName = "/milvus.proto.modelservice.HighlightService/Highlight" +) + +// HighlightServiceClient is the client API for HighlightService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type HighlightServiceClient interface { + Highlight(ctx context.Context, in *HighlightRequest, opts ...grpc.CallOption) (*HighlightResponse, error) +} + +type highlightServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewHighlightServiceClient(cc grpc.ClientConnInterface) HighlightServiceClient { + return &highlightServiceClient{cc} +} + +func (c *highlightServiceClient) Highlight(ctx context.Context, in *HighlightRequest, opts ...grpc.CallOption) (*HighlightResponse, error) { + out := new(HighlightResponse) + err := c.cc.Invoke(ctx, HighlightService_Highlight_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// HighlightServiceServer is the server API for HighlightService service. +// All implementations should embed UnimplementedHighlightServiceServer +// for forward compatibility +type HighlightServiceServer interface { + Highlight(context.Context, *HighlightRequest) (*HighlightResponse, error) +} + +// UnimplementedHighlightServiceServer should be embedded to have forward compatible implementations. +type UnimplementedHighlightServiceServer struct { +} + +func (UnimplementedHighlightServiceServer) Highlight(context.Context, *HighlightRequest) (*HighlightResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Highlight not implemented") +} + +// UnsafeHighlightServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to HighlightServiceServer will +// result in compilation errors. +type UnsafeHighlightServiceServer interface { + mustEmbedUnimplementedHighlightServiceServer() +} + +func RegisterHighlightServiceServer(s grpc.ServiceRegistrar, srv HighlightServiceServer) { + s.RegisterService(&HighlightService_ServiceDesc, srv) +} + +func _HighlightService_Highlight_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HighlightRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HighlightServiceServer).Highlight(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HighlightService_Highlight_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HighlightServiceServer).Highlight(ctx, req.(*HighlightRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// HighlightService_ServiceDesc is the grpc.ServiceDesc for HighlightService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var HighlightService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "milvus.proto.modelservice.HighlightService", + HandlerType: (*HighlightServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Highlight", + Handler: _HighlightService_Highlight_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "model_service.proto", +}