feat: Add semantic highlight (#46189)

https://github.com/milvus-io/milvus/issues/42589

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Semantic Highlighting Feature

**Core Invariant**: Semantic highlighting operates on a per-field basis
with independent text processing through an external Zilliz highlight
provider. The implementation maintains field ID to field name mapping
and correlates highlight results back to original field outputs.

**What is Added**: This PR introduces semantic highlighting capability
for search results alongside the existing lexical highlighting. The
feature consists of:
- New `SemanticHighlight` orchestrator that validates queries/input
fields against collection schema, instantiates a Zilliz-based provider,
and batches text processing across multiple queries
- New `SemanticHighlighter` proxy wrapper implementing the `Highlighter`
interface for search pipeline integration
- New `semanticHighlightOperator` that processes search results by
delegating per-field text processing to the provider and attaching
correlated `HighlightResult` data to search outputs
- New gRPC service definition (`HighlightService`) and
`ZillizClient.Highlight()` method for external provider communication

**No Data Loss or Regression**: The change is purely additive without
modifying existing logic:
- Lexical highlighting path remains unchanged (separate switch case in
`createHighlightTask`)
- New `HighlightResults` field is only populated when semantic
highlighting is explicitly requested via `HighlightType_Semantic` enum
value
- Gracefully handles missing fields by returning explicit errors rather
than silent failures
- Pipeline operator integration follows existing patterns and only
processes when semantic highlighter is instantiated

**Why This Design**: Semantic highlighting is routed through the same
pipeline operator pattern as lexical highlighting, ensuring consistent
integration into search workflows. The per-field model allows flexible
highlighting across different text columns and batch processing ensures
efficient handling of multiple queries with configurable provider
constraints.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-12-31 11:41:22 +08:00 committed by GitHub
parent 49939f5f2b
commit 1100d8f7e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1911 additions and 51 deletions

View File

@ -14,9 +14,12 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/function/highlight"
"github.com/milvus-io/milvus/internal/util/function/models"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
const (
@ -413,3 +416,58 @@ func buildStringFragments(task *highlightTask, idx int, frags []*querypb.Highlig
}
return result
}
type SemanticHighlighter struct {
highlight *highlight.SemanticHighlight
}
func newSemanticHighlighter(t *searchTask, extraInfo *models.ModelExtraInfo) (*SemanticHighlighter, error) {
conf := paramtable.Get().FunctionCfg.ZillizProviders.GetValue()
highlight, err := highlight.NewSemanticHighlight(t.schema.CollectionSchema, t.request.GetHighlighter().GetParams(), conf, extraInfo)
if err != nil {
return nil, err
}
return &SemanticHighlighter{highlight: highlight}, nil
}
func (h *SemanticHighlighter) FieldIDs() []int64 {
return h.highlight.FieldIDs()
}
func (h *SemanticHighlighter) AsSearchPipelineOperator(t *searchTask) (operator, error) {
return &semanticHighlightOperator{highlight: h.highlight}, nil
}
type semanticHighlightOperator struct {
highlight *highlight.SemanticHighlight
}
func (op *semanticHighlightOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
result := inputs[0].(*milvuspb.SearchResults)
datas := result.Results.GetFieldsData()
if len(datas) == 0 {
return []any{result}, nil
}
highlightResults := []*commonpb.HighlightResult{}
for _, fieldID := range op.highlight.FieldIDs() {
fieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldId == fieldID })
if !ok {
return nil, errors.Errorf("get highlight failed, text field not in output field %d", fieldID)
}
texts := fieldDatas.GetScalars().GetStringData().GetData()
highlights, err := op.highlight.Process(ctx, result.Results.GetTopks(), texts, nil)
if err != nil {
return nil, err
}
singeFieldHighlights := &commonpb.HighlightResult{
FieldName: fieldDatas.FieldName,
Datas: make([]*commonpb.HighlightData, 0, len(highlights)),
}
for _, highlight := range highlights {
singeFieldHighlights.Datas = append(singeFieldHighlights.Datas, &commonpb.HighlightData{Fragments: highlight})
}
highlightResults = append(highlightResults, singeFieldHighlights)
}
result.Results.HighlightResults = highlightResults
return []any{result}, nil
}

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/function/highlight"
"github.com/milvus-io/milvus/internal/util/function/models"
"github.com/milvus-io/milvus/internal/util/function/rerank"
"github.com/milvus-io/milvus/internal/util/segcore"
@ -362,6 +363,251 @@ func (s *SearchPipelineSuite) TestHighlightOp() {
s.NoError(err)
}
func (s *SearchPipelineSuite) TestSemanticHighlightOp() {
ctx := context.Background()
// Mock SemanticHighlight methods
mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To(
func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) {
return [][]string{
{"<em>highlighted</em> text 1"},
{"<em>highlighted</em> text 2"},
{"<em>highlighted</em> text 3"},
}, nil
}).Build()
defer mockProcess.UnPatch()
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).To(func(h *highlight.SemanticHighlight) []int64 {
return []int64{101}
}).Build()
defer mockFieldIDs.UnPatch()
// Create operator
op := &semanticHighlightOperator{
highlight: &highlight.SemanticHighlight{},
}
// Create search results with text data
searchResults := &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 3,
Topks: []int64{3},
FieldsData: []*schemapb.FieldData{
{
FieldId: 101,
FieldName: testVarCharField,
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"text 1", "text 2", "text 3"},
},
},
},
},
},
},
},
}
// Run the operator
results, err := op.run(ctx, s.span, searchResults)
s.NoError(err)
s.NotNil(results)
s.Len(results, 1)
// Verify results
result := results[0].(*milvuspb.SearchResults)
s.NotNil(result.Results.HighlightResults)
s.Len(result.Results.HighlightResults, 1)
highlightResult := result.Results.HighlightResults[0]
s.Equal(testVarCharField, highlightResult.FieldName)
s.Len(highlightResult.Datas, 3)
s.Equal([]string{"<em>highlighted</em> text 1"}, highlightResult.Datas[0].Fragments)
s.Equal([]string{"<em>highlighted</em> text 2"}, highlightResult.Datas[1].Fragments)
s.Equal([]string{"<em>highlighted</em> text 3"}, highlightResult.Datas[2].Fragments)
}
func (s *SearchPipelineSuite) TestSemanticHighlightOpMissingField() {
ctx := context.Background()
// Mock FieldIDs to return field 999 (not in results)
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{999}).Build()
defer mockFieldIDs.UnPatch()
op := &semanticHighlightOperator{
highlight: &highlight.SemanticHighlight{},
}
// Create search results without the expected field
searchResults := &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 1,
Topks: []int64{1},
FieldsData: []*schemapb.FieldData{
{
FieldId: 101,
FieldName: testVarCharField,
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"text 1"},
},
},
},
},
},
},
},
}
// Run the operator and expect error
_, err := op.run(ctx, s.span, searchResults)
s.Error(err)
s.Contains(err.Error(), "text field not in output field")
}
func (s *SearchPipelineSuite) TestSemanticHighlightOpMultipleFields() {
ctx := context.Background()
// Use a counter to return different results for different calls
callCount := 0
mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To(
func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) {
callCount++
return [][]string{
{fmt.Sprintf("<em>highlighted</em> text field%d-1", callCount)},
{fmt.Sprintf("<em>highlighted</em> text field%d-2", callCount)},
}, nil
}).Build()
defer mockProcess.UnPatch()
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{101, 102}).Build()
defer mockFieldIDs.UnPatch()
op := &semanticHighlightOperator{
highlight: &highlight.SemanticHighlight{},
}
// Create search results with multiple text fields
searchResults := &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 2,
Topks: []int64{2},
FieldsData: []*schemapb.FieldData{
{
FieldId: 101,
FieldName: "field1",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"text 1", "text 2"},
},
},
},
},
},
{
FieldId: 102,
FieldName: "field2",
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{"another text 1", "another text 2"},
},
},
},
},
},
},
},
}
// Run the operator
results, err := op.run(ctx, s.span, searchResults)
s.NoError(err)
s.NotNil(results)
// Verify results
result := results[0].(*milvuspb.SearchResults)
s.NotNil(result.Results.HighlightResults)
s.Len(result.Results.HighlightResults, 2)
// Verify first field
s.Equal("field1", result.Results.HighlightResults[0].FieldName)
s.Len(result.Results.HighlightResults[0].Datas, 2)
// Verify second field
s.Equal("field2", result.Results.HighlightResults[1].FieldName)
s.Len(result.Results.HighlightResults[1].Datas, 2)
}
func (s *SearchPipelineSuite) TestSemanticHighlightOpEmptyResults() {
ctx := context.Background()
// Mock Process to return empty results
mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To(
func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) {
return [][]string{}, nil
}).Build()
defer mockProcess.UnPatch()
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{101}).Build()
defer mockFieldIDs.UnPatch()
op := &semanticHighlightOperator{
highlight: &highlight.SemanticHighlight{},
}
// Create empty search results
searchResults := &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 0,
Topks: []int64{0},
FieldsData: []*schemapb.FieldData{
{
FieldId: 101,
FieldName: testVarCharField,
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: []string{},
},
},
},
},
},
},
},
}
// Run the operator
results, err := op.run(ctx, s.span, searchResults)
s.NoError(err)
s.NotNil(results)
// Verify results
result := results[0].(*milvuspb.SearchResults)
s.NotNil(result.Results.HighlightResults)
s.Len(result.Results.HighlightResults, 1)
s.Equal(testVarCharField, result.Results.HighlightResults[0].FieldName)
s.Len(result.Results.HighlightResults[0].Datas, 0)
}
func (s *SearchPipelineSuite) TestSearchPipeline() {
collectionName := "test"
task := &searchTask{

View File

@ -618,6 +618,13 @@ func (t *searchTask) addHighlightTask(highlighter *commonpb.Highlighter, metricT
switch highlighter.GetType() {
case commonpb.HighlightType_Lexical:
return t.createLexicalHighlighter(highlighter, metricType, annsField, placeholder, analyzerName)
case commonpb.HighlightType_Semantic:
h, err := newSemanticHighlighter(t, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()})
if err != nil {
return merr.WrapErrParameterInvalidMsg("Create SemanticHighlight failed: %v ", err)
}
t.highlighter = h
return nil
default:
return merr.WrapErrParameterInvalidMsg("unsupported highlight type: %v", highlighter.GetType())
}

View File

@ -17,6 +17,7 @@ package proxy
import (
"context"
"encoding/json"
"fmt"
"math"
"strconv"
@ -43,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/internal/util/function/highlight"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/v2/common"
@ -4920,7 +4922,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
},
}
schemaInfo := newSchemaInfo(schema)
info := newSchemaInfo(schema)
placeholder := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{{
@ -4934,7 +4936,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("lexical highlight success", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -4954,7 +4956,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("Lexical highlight with custom tags", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -4975,7 +4977,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("lexical highlight with wrong metric type", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
SearchRequest: &internalpb.SearchRequest{},
request: &milvuspb.SearchRequest{},
}
@ -4991,7 +4993,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("lexical highlight with invalid pre_tags type", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -5015,9 +5017,9 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
},
}
schemaInfo := newSchemaInfo(schemaWithoutBM25)
info := newSchemaInfo(schemaWithoutBM25)
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -5031,7 +5033,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("highlight without highlight search text", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -5045,7 +5047,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("highlight with invalid highlight search key", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -5059,7 +5061,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
t.Run("highlight with unknown type", func(t *testing.T) {
task := &searchTask{
schema: schemaInfo,
schema: info,
}
highlighter := &commonpb.Highlighter{
@ -5070,4 +5072,23 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
err := task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "")
assert.Error(t, err)
})
t.Run("semantic highlight success", func(t *testing.T) {
task := &searchTask{
schema: info,
}
queriesJSON, _ := json.Marshal([]string{"test_query"})
inputFieldsJSON, _ := json.Marshal([]string{"text_field"})
highlighter := &commonpb.Highlighter{
Type: commonpb.HighlightType_Semantic,
Params: []*commonpb.KeyValuePair{{Key: "queries", Value: string(queriesJSON)}, {Key: "input_fields", Value: string(inputFieldsJSON)}},
}
mockSemanticHighlight := mockey.Mock(highlight.NewSemanticHighlight).Return(&highlight.SemanticHighlight{}, nil).Build()
defer mockSemanticHighlight.UnPatch()
task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "")
require.NotNil(t, task.highlighter)
})
}

View File

@ -0,0 +1,146 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package highlight
import (
"context"
"encoding/json"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models"
)
type semanticHighlightProvider interface {
highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error)
maxBatch() int
}
type baseSemanticHighlightProvider struct {
batchSize int
}
func (provider *baseSemanticHighlightProvider) maxBatch() int {
return provider.batchSize
}
type SemanticHighlight struct {
fieldIDs []int64
provider semanticHighlightProvider
queries []string
}
const (
queryKeyName string = "queries"
inputFieldKeyName string = "input_fields"
)
func NewSemanticHighlight(collSchema *schemapb.CollectionSchema, params []*commonpb.KeyValuePair, conf map[string]string, extraInfo *models.ModelExtraInfo) (*SemanticHighlight, error) {
queries := []string{}
inputField := []string{}
for _, param := range params {
switch param.Key {
case queryKeyName:
if err := json.Unmarshal([]byte(param.Value), &queries); err != nil {
return nil, fmt.Errorf("Parse queries failed, err: %v", err)
}
case inputFieldKeyName:
if err := json.Unmarshal([]byte(param.Value), &inputField); err != nil {
return nil, fmt.Errorf("Parse input_field failed, err: %v", err)
}
}
}
if len(queries) == 0 {
return nil, fmt.Errorf("queries is required")
}
if len(inputField) == 0 {
return nil, fmt.Errorf("input_field is required")
}
fieldIDMap := make(map[string]*schemapb.FieldSchema)
for _, field := range collSchema.Fields {
fieldIDMap[field.Name] = field
}
fieldIDs := []int64{}
for _, fieldName := range inputField {
field, ok := fieldIDMap[fieldName]
if !ok {
return nil, fmt.Errorf("input_field %s not found", fieldName)
}
if field.DataType != schemapb.DataType_VarChar && field.DataType != schemapb.DataType_Text {
return nil, fmt.Errorf("input_field %s is not a VarChar or Text field", fieldName)
}
fieldIDs = append(fieldIDs, field.FieldID)
}
// TODO: support other providers if have more providers
provider, err := newZillizHighlightProvider(params, conf, extraInfo)
if err != nil {
return nil, err
}
return &SemanticHighlight{fieldIDs: fieldIDs, provider: provider, queries: queries}, nil
}
func (highlight *SemanticHighlight) FieldIDs() []int64 {
return highlight.fieldIDs
}
func (highlight *SemanticHighlight) processOneQuery(ctx context.Context, query string, data []string, params map[string]string) ([][]string, error) {
if len(data) == 0 {
return [][]string{}, nil
}
highlights, err := highlight.provider.highlight(ctx, query, data, params)
if err != nil {
return nil, err
}
if len(highlights) != len(data) {
return nil, fmt.Errorf("Highlights size must equal to data size, but got highlights size [%d], data size [%d]", len(highlights), len(data))
}
return highlights, nil
}
func (highlight *SemanticHighlight) Process(ctx context.Context, topks []int64, data []string, params map[string]string) ([][]string, error) {
nq := len(topks)
if len(highlight.queries) != nq {
return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", nq, len(highlight.queries), highlight.queries)
}
if len(data) == 0 {
return [][]string{}, nil
}
highlights := make([][]string, 0, len(data))
start := int64(0)
for i, query := range highlight.queries {
size := topks[i]
singleHighlights, err := highlight.processOneQuery(ctx, query, data[start:start+size], params)
if err != nil {
return nil, err
}
highlights = append(highlights, singleHighlights...)
start += size
}
return highlights, nil
}

View File

@ -0,0 +1,545 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package highlight
import (
"context"
"encoding/json"
"testing"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function/models"
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
)
func TestSemanticHighlight(t *testing.T) {
suite.Run(t, new(SemanticHighlightSuite))
}
type SemanticHighlightSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
}
func (s *SemanticHighlightSuite) SetupTest() {
s.schema = &schemapb.CollectionSchema{
Name: "test_collection",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "title", DataType: schemapb.DataType_VarChar},
{FieldID: 102, Name: "content", DataType: schemapb.DataType_Text},
{FieldID: 103, Name: "description", DataType: schemapb.DataType_VarChar},
{FieldID: 104, Name: "embedding", DataType: schemapb.DataType_FloatVector},
},
}
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_Success() {
queries := []string{"machine learning", "artificial intelligence"}
inputFields := []string{"title", "content"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
s.NotNil(highlight)
s.Equal([]int64{101, 102}, highlight.FieldIDs())
s.Equal(queries, highlight.queries)
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_MissingQueries() {
inputFields := []string{"title"}
inputFieldsJSON, _ := json.Marshal(inputFields)
params := []*commonpb.KeyValuePair{
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.Error(err)
s.Nil(highlight)
s.Contains(err.Error(), "queries is required")
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_MissingInputFields() {
queries := []string{"machine learning"}
queriesJSON, _ := json.Marshal(queries)
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.Error(err)
s.Nil(highlight)
s.Contains(err.Error(), "input_field is required")
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidQueriesJSON() {
inputFields := []string{"title"}
inputFieldsJSON, _ := json.Marshal(inputFields)
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: "invalid json"},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.Error(err)
s.Nil(highlight)
s.Contains(err.Error(), "Parse queries failed")
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidInputFieldsJSON() {
queries := []string{"machine learning"}
queriesJSON, _ := json.Marshal(queries)
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: "invalid json"},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.Error(err)
s.Nil(highlight)
s.Contains(err.Error(), "Parse input_field failed")
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_FieldNotFound() {
queries := []string{"machine learning"}
inputFields := []string{"nonexistent_field"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.Error(err)
s.Nil(highlight)
s.Contains(err.Error(), "not found")
}
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidFieldType() {
queries := []string{"machine learning"}
inputFields := []string{"embedding"} // FloatVector, not VarChar or Text
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.Error(err)
s.Nil(highlight)
s.Contains(err.Error(), "is not a VarChar or Text field")
}
func (s *SemanticHighlightSuite) TestProcessOneQuery_Success() {
queries := []string{"machine learning"}
inputFields := []string{"title"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
expectedHighlights := [][]string{
{"machine learning"},
{"machine"},
}
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
return expectedHighlights, nil
}).Build()
defer mock2.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
ctx := context.Background()
data := []string{"Machine learning is a subset of AI", "Machine learning is powerful"}
highlights, err := highlight.processOneQuery(ctx, "machine learning", data, nil)
s.NoError(err)
s.Equal(expectedHighlights, highlights)
}
func (s *SemanticHighlightSuite) TestProcessOneQuery_Error() {
queries := []string{"test query"}
inputFields := []string{"title"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
expectedError := errors.New("highlight service error")
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
return nil, expectedError
}).Build()
defer mock2.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
ctx := context.Background()
data := []string{"test document"}
highlights, err := highlight.processOneQuery(ctx, "test query", data, nil)
s.Error(err)
s.Nil(highlights)
s.Equal(expectedError, err)
}
func (s *SemanticHighlightSuite) TestProcess_Success() {
queries := []string{"machine learning", "deep learning"}
inputFields := []string{"title"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
expectedHighlights1 := [][]string{
{"machine learning", "deep learning"},
}
expectedHighlights2 := [][]string{
{"deep learning", "machine learning"},
}
callCount := 0
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, query string, _ []string, _ map[string]string) ([][]string, error) {
callCount++
if query == "machine learning" {
return expectedHighlights1, nil
}
return expectedHighlights2, nil
}).Build()
defer mock2.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
ctx := context.Background()
data := []string{"Machine learning document", "Deep learning document"}
highlights, err := highlight.Process(ctx, []int64{1, 1}, data, nil)
s.NoError(err)
s.NotNil(highlights)
s.Equal(2, callCount, "Should call highlight twice for two queries")
}
func (s *SemanticHighlightSuite) TestProcess_NqMismatch() {
queries := []string{"machine learning"}
inputFields := []string{"title"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
ctx := context.Background()
data := []string{"test document"}
highlights, err := highlight.Process(ctx, []int64{1, 1, 1}, data, nil) // nq=3 but queries has only 1
s.Error(err)
s.Nil(highlights)
s.Contains(err.Error(), "nq must equal to queries size")
}
func (s *SemanticHighlightSuite) TestProcess_ProviderError() {
queries := []string{"test query"}
inputFields := []string{"title"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
expectedError := errors.New("provider error")
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
return nil, expectedError
}).Build()
defer mock2.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
ctx := context.Background()
data := []string{"test document"}
highlights, err := highlight.Process(ctx, []int64{1}, data, nil)
s.Error(err)
s.Nil(highlights)
s.Equal(expectedError, err)
}
func (s *SemanticHighlightSuite) TestProcess_EmptyData() {
queries := []string{"test query", "test query 2", "test query 3"}
inputFields := []string{"title"}
queriesJSON, _ := json.Marshal(queries)
inputFieldsJSON, _ := json.Marshal(inputFields)
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, texts []string, _ map[string]string) ([][]string, error) {
return [][]string{texts}, nil
}).Build()
defer mock2.UnPatch()
params := []*commonpb.KeyValuePair{
{Key: queryKeyName, Value: string(queriesJSON)},
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
s.NoError(err)
ctx := context.Background()
data := []string{}
highlights, err := highlight.Process(ctx, []int64{0, 0, 0}, data, nil)
s.NoError(err)
s.NotNil(highlights)
data2 := []string{"test document"}
highlights2, err := highlight.Process(ctx, []int64{0, 1, 0}, data2, nil)
s.NoError(err)
s.Equal(1, len(highlights2))
s.Equal([][]string{{"test document"}}, highlights2)
}
func (s *SemanticHighlightSuite) TestBaseSemanticHighlightProvider_MaxBatch() {
provider := &baseSemanticHighlightProvider{batchSize: 128}
s.Equal(128, provider.maxBatch())
provider2 := &baseSemanticHighlightProvider{batchSize: 32}
s.Equal(32, provider2.maxBatch())
}

View File

@ -0,0 +1,78 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package highlight
import (
"context"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/util/function/models"
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
)
type zillizHighlightProvider struct {
baseSemanticHighlightProvider
client *zilliz.ZillizClient
modelName string
queries []string
modelParams map[string]string
}
func newZillizHighlightProvider(params []*commonpb.KeyValuePair, conf map[string]string, extraInfo *models.ModelExtraInfo) (semanticHighlightProvider, error) {
var modelDeploymentID string
var err error
maxBatch := 64
modelParams := map[string]string{}
for _, param := range params {
switch strings.ToLower(param.Key) {
case models.ModelDeploymentIDKey:
modelDeploymentID = param.Value
case models.MaxClientBatchSizeParamKey:
if maxBatch, err = strconv.Atoi(param.Value); err != nil {
return nil, err
}
default:
modelParams[param.Key] = param.Value
}
}
c, err := zilliz.NewZilliClient(modelDeploymentID, extraInfo.ClusterID, extraInfo.DBName, conf)
if err != nil {
return nil, err
}
provider := zillizHighlightProvider{
baseSemanticHighlightProvider: baseSemanticHighlightProvider{batchSize: maxBatch},
client: c,
modelParams: modelParams,
}
return &provider, nil
}
func (h *zillizHighlightProvider) highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) {
highlights, err := h.client.Highlight(ctx, query, texts, params)
if err != nil {
return nil, err
}
return highlights, nil
}

View File

@ -0,0 +1,240 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package highlight
import (
"context"
"testing"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/util/function/models"
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
)
func TestZillizHighlightProvider(t *testing.T) {
suite.Run(t, new(ZillizHighlightProviderSuite))
}
type ZillizHighlightProviderSuite struct {
suite.Suite
}
func (s *ZillizHighlightProviderSuite) TestNewZillizHighlightProvider_Success() {
tests := []struct {
name string
params []*commonpb.KeyValuePair
conf map[string]string
extraInfo *models.ModelExtraInfo
expectedBatch int
expectedParams map[string]string
}{
{
name: "basic configuration",
params: []*commonpb.KeyValuePair{
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
},
conf: map[string]string{
"endpoint": "localhost:8080",
},
extraInfo: &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
},
expectedBatch: 64, // default batch size
expectedParams: map[string]string{},
},
{
name: "with custom batch size",
params: []*commonpb.KeyValuePair{
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
{Key: models.MaxClientBatchSizeParamKey, Value: "32"},
},
conf: map[string]string{
"endpoint": "localhost:8080",
},
extraInfo: &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
},
expectedBatch: 32,
expectedParams: map[string]string{},
},
{
name: "with additional model parameters",
params: []*commonpb.KeyValuePair{
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
{Key: models.MaxClientBatchSizeParamKey, Value: "16"},
{Key: "threshold", Value: "0.7"},
},
conf: map[string]string{
"endpoint": "localhost:8080",
},
extraInfo: &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
},
expectedBatch: 16,
expectedParams: map[string]string{
"threshold": "0.7",
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// Note: This test will fail with actual connection since we can't easily mock the global client manager
// But we can test the parameter parsing logic by checking the error message
provider, err := newZillizHighlightProvider(tt.params, tt.conf, tt.extraInfo)
// Since we can't easily mock the zilliz client creation, we expect a connection error
// but we can verify that the parameters were parsed correctly by checking the error doesn't relate to parameter parsing
if err != nil {
// Connection errors are expected in unit tests
s.Contains(err.Error(), "Connect model serving failed", "Expected connection error, got: %v", err)
} else {
// If somehow the connection succeeds, verify the provider was created correctly
s.NotNil(provider)
zillizProvider, ok := provider.(*zillizHighlightProvider)
s.True(ok)
s.Equal(tt.expectedBatch, zillizProvider.maxBatch())
s.Equal(tt.expectedParams, zillizProvider.modelParams)
}
})
}
}
func (s *ZillizHighlightProviderSuite) TestNewZillizHighlightProvider_InvalidBatchSize() {
tests := []struct {
name string
batchValue string
expectedErr string
}{
{
name: "invalid number format",
batchValue: "not_a_number",
expectedErr: "invalid syntax",
},
{
name: "empty batch size",
batchValue: "",
expectedErr: "invalid syntax",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
params := []*commonpb.KeyValuePair{
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
{Key: models.MaxClientBatchSizeParamKey, Value: tt.batchValue},
}
conf := map[string]string{
"endpoint": "localhost:8080",
}
extraInfo := &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
}
provider, err := newZillizHighlightProvider(params, conf, extraInfo)
s.Error(err)
s.Nil(provider)
s.Contains(err.Error(), tt.expectedErr)
})
}
}
func (s *ZillizHighlightProviderSuite) TestZillizHighlightProvider_Highlight_Success() {
// Set up expected behavior
ctx := context.Background()
query := "machine learning"
texts := []string{
"Machine learning is a subset of artificial intelligence.",
"Deep learning is a type of machine learning.",
"Natural language processing uses machine learning techniques.",
}
params := map[string]string{"param1": "value1"}
expectedHighlights := [][]string{
{"Machine learning", "artificial intelligence"},
{"Deep learning", "machine learning"},
{"Natural language processing", "machine learning"},
}
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
return expectedHighlights, nil
}).Build()
defer mock2.UnPatch()
provider, err := newZillizHighlightProvider([]*commonpb.KeyValuePair{
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}, map[string]string{"endpoint": "localhost:8080"}, &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
})
s.NoError(err)
s.NotNil(provider)
highlights, err := provider.highlight(ctx, query, texts, params)
s.NoError(err)
s.Equal(expectedHighlights, highlights)
}
func (s *ZillizHighlightProviderSuite) TestZillizHighlightProvider_Highlight_Error() {
// Set up expected behavior with error
ctx := context.Background()
query := "test query"
texts := []string{"doc1", "doc2", "doc3"}
params := map[string]string{"param1": "value1"}
expectedError := errors.New("highlight service error")
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
return &zilliz.ZillizClient{}, nil
}).Build()
defer mock1.UnPatch()
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
return nil, expectedError
}).Build()
defer mock2.UnPatch()
provider, err := newZillizHighlightProvider([]*commonpb.KeyValuePair{
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
}, map[string]string{"endpoint": "localhost:8080"}, &models.ModelExtraInfo{
ClusterID: "test-cluster",
DBName: "test-db",
})
s.NoError(err)
s.NotNil(provider)
// Test the highlight method
highlights, err := provider.highlight(ctx, query, texts, params)
s.Error(err)
s.Nil(highlights)
s.Equal(expectedError, err)
}

View File

@ -25,6 +25,7 @@ import (
"sync"
"time"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/connectivity"
@ -33,6 +34,7 @@ import (
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/modelservicepb"
)
@ -93,7 +95,10 @@ func (m *clientManager) GetConn(clientConf *clientConfig) (*grpc.ClientConn, err
if m.conn != nil {
if clientConf.endpoint != m.config.endpoint {
_ = m.conn.Close()
err := m.conn.Close()
if err != nil {
log.Warn("Close connect failed", zap.String("endpoint", m.config.endpoint), zap.Error(err))
}
m.conn = nil
} else {
state := m.conn.GetState()
@ -247,3 +252,22 @@ func (c *ZillizClient) Rerank(ctx context.Context, query string, texts []string,
}
return res.Scores, nil
}
func (c *ZillizClient) Highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) {
stub := modelservicepb.NewHighlightServiceClient(c.conn)
req := &modelservicepb.HighlightRequest{
Query: query,
Documents: texts,
Params: params,
}
ctx = c.setMeta(ctx)
res, err := stub.Highlight(ctx, req)
if err != nil {
return nil, err
}
highlights := make([][]string, 0, len(res.GetResults()))
for _, ret := range res.GetResults() {
highlights = append(highlights, ret.GetSentences())
}
return highlights, nil
}

View File

@ -63,6 +63,19 @@ func (m *mockRerankServer) Rerank(ctx context.Context, req *modelservicepb.TextR
return m.response, nil
}
type mockHighlightServer struct {
modelservicepb.UnimplementedHighlightServiceServer
response *modelservicepb.HighlightResponse
err error
}
func (m *mockHighlightServer) Highlight(ctx context.Context, req *modelservicepb.HighlightRequest) (*modelservicepb.HighlightResponse, error) {
if m.err != nil {
return nil, m.err
}
return m.response, nil
}
func TestLoadConfig(t *testing.T) {
tests := []struct {
name string
@ -559,3 +572,103 @@ func TestZillizClient_Embedding_EmptyResponse(t *testing.T) {
assert.NoError(t, err)
assert.Empty(t, embeddings)
}
func TestZillizClient_Highlight(t *testing.T) {
// Setup mock server
s, lis, dialer := setupMockServer(t)
defer lis.Close()
defer s.Stop()
mockServer := &mockHighlightServer{
response: &modelservicepb.HighlightResponse{
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
Results: []*modelservicepb.HighlightResult{
{
Sentences: []string{"highlight1", "highlight2"},
},
{
Sentences: []string{"highlight3", "highlight4"},
},
},
},
}
modelservicepb.RegisterHighlightServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
fmt.Printf("Server exited with error: %v\n", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test successful highlight
ctx := context.Background()
query := "test query"
texts := []string{"doc1", "doc2", "doc3"}
params := map[string]string{"param1": "value1"}
highlights, err := client.Highlight(ctx, query, texts, params)
assert.NoError(t, err)
assert.Equal(t, [][]string{{"highlight1", "highlight2"}, {"highlight3", "highlight4"}}, highlights)
}
func TestZillizClient_Highlight_Error(t *testing.T) {
// Setup mock server with error
s, lis, dialer := setupMockServer(t)
defer lis.Close()
defer s.Stop()
mockServer := &mockHighlightServer{
err: assert.AnError,
}
modelservicepb.RegisterHighlightServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
fmt.Printf("Server exited with error: %v\n", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test highlight with error
ctx := context.Background()
query := "test query"
texts := []string{"doc1", "doc2", "doc3"}
params := map[string]string{"param1": "value1"}
highlights, err := client.Highlight(ctx, query, texts, params)
assert.Error(t, err)
assert.Nil(t, highlights)
}

View File

@ -11,6 +11,10 @@ service RerankService {
rpc Rerank(TextRerankRequest) returns (TextRerankResponse);
}
service HighlightService {
rpc Highlight(HighlightRequest) returns (HighlightResponse);
}
message Status {
int32 code = 1;
string msg = 2;
@ -68,3 +72,20 @@ message TextRerankResponse {
map<string, string> extra_info = 2;
repeated float scores = 3;
}
message HighlightRequest {
string model = 1;
string query = 2;
repeated string documents = 3;
map<string, string> params = 4;
}
message HighlightResult {
repeated string sentences = 1;
}
message HighlightResponse {
Status status = 1;
map<string, string> extra_info = 2;
repeated HighlightResult results = 3;
}

View File

@ -620,6 +620,187 @@ func (x *TextRerankResponse) GetScores() []float32 {
return nil
}
type HighlightRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Model string `protobuf:"bytes,1,opt,name=model,proto3" json:"model,omitempty"`
Query string `protobuf:"bytes,2,opt,name=query,proto3" json:"query,omitempty"`
Documents []string `protobuf:"bytes,3,rep,name=documents,proto3" json:"documents,omitempty"`
Params map[string]string `protobuf:"bytes,4,rep,name=params,proto3" json:"params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
}
func (x *HighlightRequest) Reset() {
*x = HighlightRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_model_service_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *HighlightRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HighlightRequest) ProtoMessage() {}
func (x *HighlightRequest) ProtoReflect() protoreflect.Message {
mi := &file_model_service_proto_msgTypes[9]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HighlightRequest.ProtoReflect.Descriptor instead.
func (*HighlightRequest) Descriptor() ([]byte, []int) {
return file_model_service_proto_rawDescGZIP(), []int{9}
}
func (x *HighlightRequest) GetModel() string {
if x != nil {
return x.Model
}
return ""
}
func (x *HighlightRequest) GetQuery() string {
if x != nil {
return x.Query
}
return ""
}
func (x *HighlightRequest) GetDocuments() []string {
if x != nil {
return x.Documents
}
return nil
}
func (x *HighlightRequest) GetParams() map[string]string {
if x != nil {
return x.Params
}
return nil
}
type HighlightResult struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Sentences []string `protobuf:"bytes,1,rep,name=sentences,proto3" json:"sentences,omitempty"`
}
func (x *HighlightResult) Reset() {
*x = HighlightResult{}
if protoimpl.UnsafeEnabled {
mi := &file_model_service_proto_msgTypes[10]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *HighlightResult) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HighlightResult) ProtoMessage() {}
func (x *HighlightResult) ProtoReflect() protoreflect.Message {
mi := &file_model_service_proto_msgTypes[10]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HighlightResult.ProtoReflect.Descriptor instead.
func (*HighlightResult) Descriptor() ([]byte, []int) {
return file_model_service_proto_rawDescGZIP(), []int{10}
}
func (x *HighlightResult) GetSentences() []string {
if x != nil {
return x.Sentences
}
return nil
}
type HighlightResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Status *Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"`
ExtraInfo map[string]string `protobuf:"bytes,2,rep,name=extra_info,json=extraInfo,proto3" json:"extra_info,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
Results []*HighlightResult `protobuf:"bytes,3,rep,name=results,proto3" json:"results,omitempty"`
}
func (x *HighlightResponse) Reset() {
*x = HighlightResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_model_service_proto_msgTypes[11]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *HighlightResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HighlightResponse) ProtoMessage() {}
func (x *HighlightResponse) ProtoReflect() protoreflect.Message {
mi := &file_model_service_proto_msgTypes[11]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HighlightResponse.ProtoReflect.Descriptor instead.
func (*HighlightResponse) Descriptor() ([]byte, []int) {
return file_model_service_proto_rawDescGZIP(), []int{11}
}
func (x *HighlightResponse) GetStatus() *Status {
if x != nil {
return x.Status
}
return nil
}
func (x *HighlightResponse) GetExtraInfo() map[string]string {
if x != nil {
return x.ExtraInfo
}
return nil
}
func (x *HighlightResponse) GetResults() []*HighlightResult {
if x != nil {
return x.Results
}
return nil
}
var File_model_service_proto protoreflect.FileDescriptor
var file_model_service_proto_rawDesc = []byte{
@ -729,27 +910,72 @@ var file_model_service_proto_rawDesc = []byte{
0x78, 0x74, 0x72, 0x61, 0x49, 0x6e, 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a,
0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12,
0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0x86, 0x01, 0x0a, 0x14, 0x54, 0x65,
0x78, 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x12, 0x6e, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12,
0x2f, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d,
0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74,
0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x30, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78,
0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x32, 0x76, 0x0a, 0x0d, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x12, 0x65, 0x0a, 0x06, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x12, 0x2c, 0x2e,
0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64,
0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65,
0x72, 0x61, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x6d, 0x69,
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xe8, 0x01, 0x0a, 0x10, 0x48, 0x69,
0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14,
0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d,
0x6f, 0x64, 0x65, 0x6c, 0x12, 0x14, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x6f,
0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x64,
0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x4f, 0x0a, 0x06, 0x70, 0x61, 0x72, 0x61,
0x6d, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72,
0x79, 0x52, 0x06, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a, 0x39, 0x0a, 0x0b, 0x50, 0x61, 0x72,
0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61,
0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
0x3a, 0x02, 0x38, 0x01, 0x22, 0x2f, 0x0a, 0x0f, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68,
0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x65, 0x6e, 0x74, 0x65,
0x6e, 0x63, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x6e, 0x74,
0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0xae, 0x02, 0x0a, 0x11, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69,
0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x06, 0x73,
0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x6d, 0x69,
0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c,
0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61,
0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x39, 0x5a, 0x37, 0x67, 0x69,
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d,
0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32,
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06,
0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x5a, 0x0a, 0x0a, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f,
0x69, 0x6e, 0x66, 0x6f, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x3b, 0x2e, 0x6d, 0x69, 0x6c,
0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x45, 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e,
0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e,
0x66, 0x6f, 0x12, 0x44, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, 0x20,
0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e,
0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52,
0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x1a, 0x3c, 0x0a, 0x0e, 0x45, 0x78, 0x74, 0x72,
0x61, 0x49, 0x6e, 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05,
0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c,
0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0x86, 0x01, 0x0a, 0x14, 0x54, 0x65, 0x78, 0x74, 0x45,
0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
0x6e, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x2f, 0x2e, 0x6d,
0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65,
0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x45, 0x6d, 0x62,
0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30, 0x2e,
0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64,
0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x45, 0x6d,
0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32,
0x76, 0x0a, 0x0d, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x12, 0x65, 0x0a, 0x06, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x12, 0x2c, 0x2e, 0x6d, 0x69, 0x6c,
0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, 0x6e,
0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x7a, 0x0a, 0x10, 0x48, 0x69, 0x67, 0x68, 0x6c,
0x69, 0x67, 0x68, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x66, 0x0a, 0x09, 0x48,
0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x12, 0x2b, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2c, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63,
0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x42, 0x39, 0x5a, 0x37, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76,
0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f,
0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x70, 0x62, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@ -765,7 +991,7 @@ func file_model_service_proto_rawDescGZIP() []byte {
}
var file_model_service_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_model_service_proto_msgTypes = make([]protoimpl.MessageInfo, 13)
var file_model_service_proto_msgTypes = make([]protoimpl.MessageInfo, 18)
var file_model_service_proto_goTypes = []interface{}{
(DenseVector_DType)(0), // 0: milvus.proto.modelservice.DenseVector.DType
(*Status)(nil), // 1: milvus.proto.modelservice.Status
@ -777,33 +1003,44 @@ var file_model_service_proto_goTypes = []interface{}{
(*TextEmbeddingResponse)(nil), // 7: milvus.proto.modelservice.TextEmbeddingResponse
(*TextRerankRequest)(nil), // 8: milvus.proto.modelservice.TextRerankRequest
(*TextRerankResponse)(nil), // 9: milvus.proto.modelservice.TextRerankResponse
nil, // 10: milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
nil, // 11: milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
nil, // 12: milvus.proto.modelservice.TextRerankRequest.ParamsEntry
nil, // 13: milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
(*HighlightRequest)(nil), // 10: milvus.proto.modelservice.HighlightRequest
(*HighlightResult)(nil), // 11: milvus.proto.modelservice.HighlightResult
(*HighlightResponse)(nil), // 12: milvus.proto.modelservice.HighlightResponse
nil, // 13: milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
nil, // 14: milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
nil, // 15: milvus.proto.modelservice.TextRerankRequest.ParamsEntry
nil, // 16: milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
nil, // 17: milvus.proto.modelservice.HighlightRequest.ParamsEntry
nil, // 18: milvus.proto.modelservice.HighlightResponse.ExtraInfoEntry
}
var file_model_service_proto_depIdxs = []int32{
0, // 0: milvus.proto.modelservice.DenseVector.dtype:type_name -> milvus.proto.modelservice.DenseVector.DType
2, // 1: milvus.proto.modelservice.MultiVector.token_vectors:type_name -> milvus.proto.modelservice.DenseVector
10, // 2: milvus.proto.modelservice.TextEmbeddingRequest.params:type_name -> milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
13, // 2: milvus.proto.modelservice.TextEmbeddingRequest.params:type_name -> milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
2, // 3: milvus.proto.modelservice.EmbeddingResult.dense:type_name -> milvus.proto.modelservice.DenseVector
3, // 4: milvus.proto.modelservice.EmbeddingResult.sparse:type_name -> milvus.proto.modelservice.SparseVector
4, // 5: milvus.proto.modelservice.EmbeddingResult.multi_vector:type_name -> milvus.proto.modelservice.MultiVector
1, // 6: milvus.proto.modelservice.TextEmbeddingResponse.status:type_name -> milvus.proto.modelservice.Status
11, // 7: milvus.proto.modelservice.TextEmbeddingResponse.extra_info:type_name -> milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
14, // 7: milvus.proto.modelservice.TextEmbeddingResponse.extra_info:type_name -> milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
6, // 8: milvus.proto.modelservice.TextEmbeddingResponse.results:type_name -> milvus.proto.modelservice.EmbeddingResult
12, // 9: milvus.proto.modelservice.TextRerankRequest.params:type_name -> milvus.proto.modelservice.TextRerankRequest.ParamsEntry
15, // 9: milvus.proto.modelservice.TextRerankRequest.params:type_name -> milvus.proto.modelservice.TextRerankRequest.ParamsEntry
1, // 10: milvus.proto.modelservice.TextRerankResponse.status:type_name -> milvus.proto.modelservice.Status
13, // 11: milvus.proto.modelservice.TextRerankResponse.extra_info:type_name -> milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
5, // 12: milvus.proto.modelservice.TextEmbeddingService.Embedding:input_type -> milvus.proto.modelservice.TextEmbeddingRequest
8, // 13: milvus.proto.modelservice.RerankService.Rerank:input_type -> milvus.proto.modelservice.TextRerankRequest
7, // 14: milvus.proto.modelservice.TextEmbeddingService.Embedding:output_type -> milvus.proto.modelservice.TextEmbeddingResponse
9, // 15: milvus.proto.modelservice.RerankService.Rerank:output_type -> milvus.proto.modelservice.TextRerankResponse
14, // [14:16] is the sub-list for method output_type
12, // [12:14] is the sub-list for method input_type
12, // [12:12] is the sub-list for extension type_name
12, // [12:12] is the sub-list for extension extendee
0, // [0:12] is the sub-list for field type_name
16, // 11: milvus.proto.modelservice.TextRerankResponse.extra_info:type_name -> milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
17, // 12: milvus.proto.modelservice.HighlightRequest.params:type_name -> milvus.proto.modelservice.HighlightRequest.ParamsEntry
1, // 13: milvus.proto.modelservice.HighlightResponse.status:type_name -> milvus.proto.modelservice.Status
18, // 14: milvus.proto.modelservice.HighlightResponse.extra_info:type_name -> milvus.proto.modelservice.HighlightResponse.ExtraInfoEntry
11, // 15: milvus.proto.modelservice.HighlightResponse.results:type_name -> milvus.proto.modelservice.HighlightResult
5, // 16: milvus.proto.modelservice.TextEmbeddingService.Embedding:input_type -> milvus.proto.modelservice.TextEmbeddingRequest
8, // 17: milvus.proto.modelservice.RerankService.Rerank:input_type -> milvus.proto.modelservice.TextRerankRequest
10, // 18: milvus.proto.modelservice.HighlightService.Highlight:input_type -> milvus.proto.modelservice.HighlightRequest
7, // 19: milvus.proto.modelservice.TextEmbeddingService.Embedding:output_type -> milvus.proto.modelservice.TextEmbeddingResponse
9, // 20: milvus.proto.modelservice.RerankService.Rerank:output_type -> milvus.proto.modelservice.TextRerankResponse
12, // 21: milvus.proto.modelservice.HighlightService.Highlight:output_type -> milvus.proto.modelservice.HighlightResponse
19, // [19:22] is the sub-list for method output_type
16, // [16:19] is the sub-list for method input_type
16, // [16:16] is the sub-list for extension type_name
16, // [16:16] is the sub-list for extension extendee
0, // [0:16] is the sub-list for field type_name
}
func init() { file_model_service_proto_init() }
@ -920,6 +1157,42 @@ func file_model_service_proto_init() {
return nil
}
}
file_model_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HighlightRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_model_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HighlightResult); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_model_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HighlightResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -927,9 +1200,9 @@ func file_model_service_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_model_service_proto_rawDesc,
NumEnums: 1,
NumMessages: 13,
NumMessages: 18,
NumExtensions: 0,
NumServices: 2,
NumServices: 3,
},
GoTypes: file_model_service_proto_goTypes,
DependencyIndexes: file_model_service_proto_depIdxs,

View File

@ -193,3 +193,91 @@ var RerankService_ServiceDesc = grpc.ServiceDesc{
Streams: []grpc.StreamDesc{},
Metadata: "model_service.proto",
}
const (
HighlightService_Highlight_FullMethodName = "/milvus.proto.modelservice.HighlightService/Highlight"
)
// HighlightServiceClient is the client API for HighlightService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type HighlightServiceClient interface {
Highlight(ctx context.Context, in *HighlightRequest, opts ...grpc.CallOption) (*HighlightResponse, error)
}
type highlightServiceClient struct {
cc grpc.ClientConnInterface
}
func NewHighlightServiceClient(cc grpc.ClientConnInterface) HighlightServiceClient {
return &highlightServiceClient{cc}
}
func (c *highlightServiceClient) Highlight(ctx context.Context, in *HighlightRequest, opts ...grpc.CallOption) (*HighlightResponse, error) {
out := new(HighlightResponse)
err := c.cc.Invoke(ctx, HighlightService_Highlight_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// HighlightServiceServer is the server API for HighlightService service.
// All implementations should embed UnimplementedHighlightServiceServer
// for forward compatibility
type HighlightServiceServer interface {
Highlight(context.Context, *HighlightRequest) (*HighlightResponse, error)
}
// UnimplementedHighlightServiceServer should be embedded to have forward compatible implementations.
type UnimplementedHighlightServiceServer struct {
}
func (UnimplementedHighlightServiceServer) Highlight(context.Context, *HighlightRequest) (*HighlightResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Highlight not implemented")
}
// UnsafeHighlightServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to HighlightServiceServer will
// result in compilation errors.
type UnsafeHighlightServiceServer interface {
mustEmbedUnimplementedHighlightServiceServer()
}
func RegisterHighlightServiceServer(s grpc.ServiceRegistrar, srv HighlightServiceServer) {
s.RegisterService(&HighlightService_ServiceDesc, srv)
}
func _HighlightService_Highlight_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HighlightRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HighlightServiceServer).Highlight(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: HighlightService_Highlight_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HighlightServiceServer).Highlight(ctx, req.(*HighlightRequest))
}
return interceptor(ctx, in, info, handler)
}
// HighlightService_ServiceDesc is the grpc.ServiceDesc for HighlightService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var HighlightService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "milvus.proto.modelservice.HighlightService",
HandlerType: (*HighlightServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Highlight",
Handler: _HighlightService_Highlight_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "model_service.proto",
}