mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat: Add semantic highlight (#46189)
https://github.com/milvus-io/milvus/issues/42589 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Semantic Highlighting Feature **Core Invariant**: Semantic highlighting operates on a per-field basis with independent text processing through an external Zilliz highlight provider. The implementation maintains field ID to field name mapping and correlates highlight results back to original field outputs. **What is Added**: This PR introduces semantic highlighting capability for search results alongside the existing lexical highlighting. The feature consists of: - New `SemanticHighlight` orchestrator that validates queries/input fields against collection schema, instantiates a Zilliz-based provider, and batches text processing across multiple queries - New `SemanticHighlighter` proxy wrapper implementing the `Highlighter` interface for search pipeline integration - New `semanticHighlightOperator` that processes search results by delegating per-field text processing to the provider and attaching correlated `HighlightResult` data to search outputs - New gRPC service definition (`HighlightService`) and `ZillizClient.Highlight()` method for external provider communication **No Data Loss or Regression**: The change is purely additive without modifying existing logic: - Lexical highlighting path remains unchanged (separate switch case in `createHighlightTask`) - New `HighlightResults` field is only populated when semantic highlighting is explicitly requested via `HighlightType_Semantic` enum value - Gracefully handles missing fields by returning explicit errors rather than silent failures - Pipeline operator integration follows existing patterns and only processes when semantic highlighter is instantiated **Why This Design**: Semantic highlighting is routed through the same pipeline operator pattern as lexical highlighting, ensuring consistent integration into search workflows. The per-field model allows flexible highlighting across different text columns and batch processing ensures efficient handling of multiple queries with configurable provider constraints. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
49939f5f2b
commit
1100d8f7e2
@ -14,9 +14,12 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/util/function/highlight"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -413,3 +416,58 @@ func buildStringFragments(task *highlightTask, idx int, frags []*querypb.Highlig
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type SemanticHighlighter struct {
|
||||
highlight *highlight.SemanticHighlight
|
||||
}
|
||||
|
||||
func newSemanticHighlighter(t *searchTask, extraInfo *models.ModelExtraInfo) (*SemanticHighlighter, error) {
|
||||
conf := paramtable.Get().FunctionCfg.ZillizProviders.GetValue()
|
||||
highlight, err := highlight.NewSemanticHighlight(t.schema.CollectionSchema, t.request.GetHighlighter().GetParams(), conf, extraInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SemanticHighlighter{highlight: highlight}, nil
|
||||
}
|
||||
|
||||
func (h *SemanticHighlighter) FieldIDs() []int64 {
|
||||
return h.highlight.FieldIDs()
|
||||
}
|
||||
|
||||
func (h *SemanticHighlighter) AsSearchPipelineOperator(t *searchTask) (operator, error) {
|
||||
return &semanticHighlightOperator{highlight: h.highlight}, nil
|
||||
}
|
||||
|
||||
type semanticHighlightOperator struct {
|
||||
highlight *highlight.SemanticHighlight
|
||||
}
|
||||
|
||||
func (op *semanticHighlightOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
|
||||
result := inputs[0].(*milvuspb.SearchResults)
|
||||
datas := result.Results.GetFieldsData()
|
||||
if len(datas) == 0 {
|
||||
return []any{result}, nil
|
||||
}
|
||||
highlightResults := []*commonpb.HighlightResult{}
|
||||
for _, fieldID := range op.highlight.FieldIDs() {
|
||||
fieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldId == fieldID })
|
||||
if !ok {
|
||||
return nil, errors.Errorf("get highlight failed, text field not in output field %d", fieldID)
|
||||
}
|
||||
texts := fieldDatas.GetScalars().GetStringData().GetData()
|
||||
highlights, err := op.highlight.Process(ctx, result.Results.GetTopks(), texts, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
singeFieldHighlights := &commonpb.HighlightResult{
|
||||
FieldName: fieldDatas.FieldName,
|
||||
Datas: make([]*commonpb.HighlightData, 0, len(highlights)),
|
||||
}
|
||||
for _, highlight := range highlights {
|
||||
singeFieldHighlights.Datas = append(singeFieldHighlights.Datas, &commonpb.HighlightData{Fragments: highlight})
|
||||
}
|
||||
highlightResults = append(highlightResults, singeFieldHighlights)
|
||||
}
|
||||
result.Results.HighlightResults = highlightResults
|
||||
return []any{result}, nil
|
||||
}
|
||||
|
||||
@ -34,6 +34,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/util/function/highlight"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models"
|
||||
"github.com/milvus-io/milvus/internal/util/function/rerank"
|
||||
"github.com/milvus-io/milvus/internal/util/segcore"
|
||||
@ -362,6 +363,251 @@ func (s *SearchPipelineSuite) TestHighlightOp() {
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSemanticHighlightOp() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock SemanticHighlight methods
|
||||
mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To(
|
||||
func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) {
|
||||
return [][]string{
|
||||
{"<em>highlighted</em> text 1"},
|
||||
{"<em>highlighted</em> text 2"},
|
||||
{"<em>highlighted</em> text 3"},
|
||||
}, nil
|
||||
}).Build()
|
||||
defer mockProcess.UnPatch()
|
||||
|
||||
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).To(func(h *highlight.SemanticHighlight) []int64 {
|
||||
return []int64{101}
|
||||
}).Build()
|
||||
defer mockFieldIDs.UnPatch()
|
||||
|
||||
// Create operator
|
||||
op := &semanticHighlightOperator{
|
||||
highlight: &highlight.SemanticHighlight{},
|
||||
}
|
||||
|
||||
// Create search results with text data
|
||||
searchResults := &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 3,
|
||||
Topks: []int64{3},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: 101,
|
||||
FieldName: testVarCharField,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"text 1", "text 2", "text 3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Run the operator
|
||||
results, err := op.run(ctx, s.span, searchResults)
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
s.Len(results, 1)
|
||||
|
||||
// Verify results
|
||||
result := results[0].(*milvuspb.SearchResults)
|
||||
s.NotNil(result.Results.HighlightResults)
|
||||
s.Len(result.Results.HighlightResults, 1)
|
||||
|
||||
highlightResult := result.Results.HighlightResults[0]
|
||||
s.Equal(testVarCharField, highlightResult.FieldName)
|
||||
s.Len(highlightResult.Datas, 3)
|
||||
s.Equal([]string{"<em>highlighted</em> text 1"}, highlightResult.Datas[0].Fragments)
|
||||
s.Equal([]string{"<em>highlighted</em> text 2"}, highlightResult.Datas[1].Fragments)
|
||||
s.Equal([]string{"<em>highlighted</em> text 3"}, highlightResult.Datas[2].Fragments)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSemanticHighlightOpMissingField() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock FieldIDs to return field 999 (not in results)
|
||||
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{999}).Build()
|
||||
defer mockFieldIDs.UnPatch()
|
||||
|
||||
op := &semanticHighlightOperator{
|
||||
highlight: &highlight.SemanticHighlight{},
|
||||
}
|
||||
|
||||
// Create search results without the expected field
|
||||
searchResults := &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 1,
|
||||
Topks: []int64{1},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: 101,
|
||||
FieldName: testVarCharField,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"text 1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Run the operator and expect error
|
||||
_, err := op.run(ctx, s.span, searchResults)
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "text field not in output field")
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSemanticHighlightOpMultipleFields() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Use a counter to return different results for different calls
|
||||
callCount := 0
|
||||
mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To(
|
||||
func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) {
|
||||
callCount++
|
||||
return [][]string{
|
||||
{fmt.Sprintf("<em>highlighted</em> text field%d-1", callCount)},
|
||||
{fmt.Sprintf("<em>highlighted</em> text field%d-2", callCount)},
|
||||
}, nil
|
||||
}).Build()
|
||||
defer mockProcess.UnPatch()
|
||||
|
||||
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{101, 102}).Build()
|
||||
defer mockFieldIDs.UnPatch()
|
||||
|
||||
op := &semanticHighlightOperator{
|
||||
highlight: &highlight.SemanticHighlight{},
|
||||
}
|
||||
|
||||
// Create search results with multiple text fields
|
||||
searchResults := &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 2,
|
||||
Topks: []int64{2},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: 101,
|
||||
FieldName: "field1",
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"text 1", "text 2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldId: 102,
|
||||
FieldName: "field2",
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"another text 1", "another text 2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Run the operator
|
||||
results, err := op.run(ctx, s.span, searchResults)
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
|
||||
// Verify results
|
||||
result := results[0].(*milvuspb.SearchResults)
|
||||
s.NotNil(result.Results.HighlightResults)
|
||||
s.Len(result.Results.HighlightResults, 2)
|
||||
|
||||
// Verify first field
|
||||
s.Equal("field1", result.Results.HighlightResults[0].FieldName)
|
||||
s.Len(result.Results.HighlightResults[0].Datas, 2)
|
||||
|
||||
// Verify second field
|
||||
s.Equal("field2", result.Results.HighlightResults[1].FieldName)
|
||||
s.Len(result.Results.HighlightResults[1].Datas, 2)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSemanticHighlightOpEmptyResults() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock Process to return empty results
|
||||
mockProcess := mockey.Mock((*highlight.SemanticHighlight).Process).To(
|
||||
func(h *highlight.SemanticHighlight, ctx context.Context, topks []int64, texts []string, params map[string]string) ([][]string, error) {
|
||||
return [][]string{}, nil
|
||||
}).Build()
|
||||
defer mockProcess.UnPatch()
|
||||
|
||||
mockFieldIDs := mockey.Mock((*highlight.SemanticHighlight).FieldIDs).Return([]int64{101}).Build()
|
||||
defer mockFieldIDs.UnPatch()
|
||||
|
||||
op := &semanticHighlightOperator{
|
||||
highlight: &highlight.SemanticHighlight{},
|
||||
}
|
||||
|
||||
// Create empty search results
|
||||
searchResults := &milvuspb.SearchResults{
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 0,
|
||||
Topks: []int64{0},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldId: 101,
|
||||
FieldName: testVarCharField,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Run the operator
|
||||
results, err := op.run(ctx, s.span, searchResults)
|
||||
s.NoError(err)
|
||||
s.NotNil(results)
|
||||
|
||||
// Verify results
|
||||
result := results[0].(*milvuspb.SearchResults)
|
||||
s.NotNil(result.Results.HighlightResults)
|
||||
s.Len(result.Results.HighlightResults, 1)
|
||||
s.Equal(testVarCharField, result.Results.HighlightResults[0].FieldName)
|
||||
s.Len(result.Results.HighlightResults[0].Datas, 0)
|
||||
}
|
||||
|
||||
func (s *SearchPipelineSuite) TestSearchPipeline() {
|
||||
collectionName := "test"
|
||||
task := &searchTask{
|
||||
|
||||
@ -618,6 +618,13 @@ func (t *searchTask) addHighlightTask(highlighter *commonpb.Highlighter, metricT
|
||||
switch highlighter.GetType() {
|
||||
case commonpb.HighlightType_Lexical:
|
||||
return t.createLexicalHighlighter(highlighter, metricType, annsField, placeholder, analyzerName)
|
||||
case commonpb.HighlightType_Semantic:
|
||||
h, err := newSemanticHighlighter(t, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()})
|
||||
if err != nil {
|
||||
return merr.WrapErrParameterInvalidMsg("Create SemanticHighlight failed: %v ", err)
|
||||
}
|
||||
t.highlighter = h
|
||||
return nil
|
||||
default:
|
||||
return merr.WrapErrParameterInvalidMsg("unsupported highlight type: %v", highlighter.GetType())
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
@ -43,6 +44,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
||||
"github.com/milvus-io/milvus/internal/util/function/highlight"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/internal/util/segcore"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
@ -4920,7 +4922,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
schemaInfo := newSchemaInfo(schema)
|
||||
info := newSchemaInfo(schema)
|
||||
|
||||
placeholder := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{{
|
||||
@ -4934,7 +4936,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("lexical highlight success", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -4954,7 +4956,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("Lexical highlight with custom tags", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -4975,7 +4977,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("lexical highlight with wrong metric type", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
SearchRequest: &internalpb.SearchRequest{},
|
||||
request: &milvuspb.SearchRequest{},
|
||||
}
|
||||
@ -4991,7 +4993,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("lexical highlight with invalid pre_tags type", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -5015,9 +5017,9 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
schemaInfo := newSchemaInfo(schemaWithoutBM25)
|
||||
info := newSchemaInfo(schemaWithoutBM25)
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -5031,7 +5033,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("highlight without highlight search text", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -5045,7 +5047,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("highlight with invalid highlight search key", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -5059,7 +5061,7 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
|
||||
t.Run("highlight with unknown type", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: schemaInfo,
|
||||
schema: info,
|
||||
}
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
@ -5070,4 +5072,23 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
|
||||
err := task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("semantic highlight success", func(t *testing.T) {
|
||||
task := &searchTask{
|
||||
schema: info,
|
||||
}
|
||||
|
||||
queriesJSON, _ := json.Marshal([]string{"test_query"})
|
||||
inputFieldsJSON, _ := json.Marshal([]string{"text_field"})
|
||||
|
||||
highlighter := &commonpb.Highlighter{
|
||||
Type: commonpb.HighlightType_Semantic,
|
||||
Params: []*commonpb.KeyValuePair{{Key: "queries", Value: string(queriesJSON)}, {Key: "input_fields", Value: string(inputFieldsJSON)}},
|
||||
}
|
||||
|
||||
mockSemanticHighlight := mockey.Mock(highlight.NewSemanticHighlight).Return(&highlight.SemanticHighlight{}, nil).Build()
|
||||
defer mockSemanticHighlight.UnPatch()
|
||||
task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "")
|
||||
require.NotNil(t, task.highlighter)
|
||||
})
|
||||
}
|
||||
|
||||
146
internal/util/function/highlight/semantic_highlight.go
Normal file
146
internal/util/function/highlight/semantic_highlight.go
Normal file
@ -0,0 +1,146 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package highlight
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models"
|
||||
)
|
||||
|
||||
type semanticHighlightProvider interface {
|
||||
highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error)
|
||||
maxBatch() int
|
||||
}
|
||||
|
||||
type baseSemanticHighlightProvider struct {
|
||||
batchSize int
|
||||
}
|
||||
|
||||
func (provider *baseSemanticHighlightProvider) maxBatch() int {
|
||||
return provider.batchSize
|
||||
}
|
||||
|
||||
type SemanticHighlight struct {
|
||||
fieldIDs []int64
|
||||
provider semanticHighlightProvider
|
||||
queries []string
|
||||
}
|
||||
|
||||
const (
|
||||
queryKeyName string = "queries"
|
||||
inputFieldKeyName string = "input_fields"
|
||||
)
|
||||
|
||||
func NewSemanticHighlight(collSchema *schemapb.CollectionSchema, params []*commonpb.KeyValuePair, conf map[string]string, extraInfo *models.ModelExtraInfo) (*SemanticHighlight, error) {
|
||||
queries := []string{}
|
||||
inputField := []string{}
|
||||
for _, param := range params {
|
||||
switch param.Key {
|
||||
case queryKeyName:
|
||||
if err := json.Unmarshal([]byte(param.Value), &queries); err != nil {
|
||||
return nil, fmt.Errorf("Parse queries failed, err: %v", err)
|
||||
}
|
||||
case inputFieldKeyName:
|
||||
if err := json.Unmarshal([]byte(param.Value), &inputField); err != nil {
|
||||
return nil, fmt.Errorf("Parse input_field failed, err: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(queries) == 0 {
|
||||
return nil, fmt.Errorf("queries is required")
|
||||
}
|
||||
|
||||
if len(inputField) == 0 {
|
||||
return nil, fmt.Errorf("input_field is required")
|
||||
}
|
||||
|
||||
fieldIDMap := make(map[string]*schemapb.FieldSchema)
|
||||
for _, field := range collSchema.Fields {
|
||||
fieldIDMap[field.Name] = field
|
||||
}
|
||||
|
||||
fieldIDs := []int64{}
|
||||
for _, fieldName := range inputField {
|
||||
field, ok := fieldIDMap[fieldName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("input_field %s not found", fieldName)
|
||||
}
|
||||
if field.DataType != schemapb.DataType_VarChar && field.DataType != schemapb.DataType_Text {
|
||||
return nil, fmt.Errorf("input_field %s is not a VarChar or Text field", fieldName)
|
||||
}
|
||||
|
||||
fieldIDs = append(fieldIDs, field.FieldID)
|
||||
}
|
||||
|
||||
// TODO: support other providers if have more providers
|
||||
provider, err := newZillizHighlightProvider(params, conf, extraInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SemanticHighlight{fieldIDs: fieldIDs, provider: provider, queries: queries}, nil
|
||||
}
|
||||
|
||||
func (highlight *SemanticHighlight) FieldIDs() []int64 {
|
||||
return highlight.fieldIDs
|
||||
}
|
||||
|
||||
func (highlight *SemanticHighlight) processOneQuery(ctx context.Context, query string, data []string, params map[string]string) ([][]string, error) {
|
||||
if len(data) == 0 {
|
||||
return [][]string{}, nil
|
||||
}
|
||||
highlights, err := highlight.provider.highlight(ctx, query, data, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(highlights) != len(data) {
|
||||
return nil, fmt.Errorf("Highlights size must equal to data size, but got highlights size [%d], data size [%d]", len(highlights), len(data))
|
||||
}
|
||||
return highlights, nil
|
||||
}
|
||||
|
||||
func (highlight *SemanticHighlight) Process(ctx context.Context, topks []int64, data []string, params map[string]string) ([][]string, error) {
|
||||
nq := len(topks)
|
||||
if len(highlight.queries) != nq {
|
||||
return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", nq, len(highlight.queries), highlight.queries)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return [][]string{}, nil
|
||||
}
|
||||
|
||||
highlights := make([][]string, 0, len(data))
|
||||
start := int64(0)
|
||||
|
||||
for i, query := range highlight.queries {
|
||||
size := topks[i]
|
||||
singleHighlights, err := highlight.processOneQuery(ctx, query, data[start:start+size], params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
highlights = append(highlights, singleHighlights...)
|
||||
start += size
|
||||
}
|
||||
return highlights, nil
|
||||
}
|
||||
545
internal/util/function/highlight/semantic_highlight_test.go
Normal file
545
internal/util/function/highlight/semantic_highlight_test.go
Normal file
@ -0,0 +1,545 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package highlight
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
|
||||
)
|
||||
|
||||
func TestSemanticHighlight(t *testing.T) {
|
||||
suite.Run(t, new(SemanticHighlightSuite))
|
||||
}
|
||||
|
||||
type SemanticHighlightSuite struct {
|
||||
suite.Suite
|
||||
schema *schemapb.CollectionSchema
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) SetupTest() {
|
||||
s.schema = &schemapb.CollectionSchema{
|
||||
Name: "test_collection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64},
|
||||
{FieldID: 101, Name: "title", DataType: schemapb.DataType_VarChar},
|
||||
{FieldID: 102, Name: "content", DataType: schemapb.DataType_Text},
|
||||
{FieldID: 103, Name: "description", DataType: schemapb.DataType_VarChar},
|
||||
{FieldID: 104, Name: "embedding", DataType: schemapb.DataType_FloatVector},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_Success() {
|
||||
queries := []string{"machine learning", "artificial intelligence"}
|
||||
inputFields := []string{"title", "content"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(highlight)
|
||||
s.Equal([]int64{101, 102}, highlight.FieldIDs())
|
||||
s.Equal(queries, highlight.queries)
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_MissingQueries() {
|
||||
inputFields := []string{"title"}
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlight)
|
||||
s.Contains(err.Error(), "queries is required")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_MissingInputFields() {
|
||||
queries := []string{"machine learning"}
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlight)
|
||||
s.Contains(err.Error(), "input_field is required")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidQueriesJSON() {
|
||||
inputFields := []string{"title"}
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: "invalid json"},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlight)
|
||||
s.Contains(err.Error(), "Parse queries failed")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidInputFieldsJSON() {
|
||||
queries := []string{"machine learning"}
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: "invalid json"},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlight)
|
||||
s.Contains(err.Error(), "Parse input_field failed")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_FieldNotFound() {
|
||||
queries := []string{"machine learning"}
|
||||
inputFields := []string{"nonexistent_field"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlight)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestNewSemanticHighlight_InvalidFieldType() {
|
||||
queries := []string{"machine learning"}
|
||||
inputFields := []string{"embedding"} // FloatVector, not VarChar or Text
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlight)
|
||||
s.Contains(err.Error(), "is not a VarChar or Text field")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestProcessOneQuery_Success() {
|
||||
queries := []string{"machine learning"}
|
||||
inputFields := []string{"title"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
expectedHighlights := [][]string{
|
||||
{"machine learning"},
|
||||
{"machine"},
|
||||
}
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
|
||||
return expectedHighlights, nil
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
data := []string{"Machine learning is a subset of AI", "Machine learning is powerful"}
|
||||
highlights, err := highlight.processOneQuery(ctx, "machine learning", data, nil)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(expectedHighlights, highlights)
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestProcessOneQuery_Error() {
|
||||
queries := []string{"test query"}
|
||||
inputFields := []string{"title"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
expectedError := errors.New("highlight service error")
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
|
||||
return nil, expectedError
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
data := []string{"test document"}
|
||||
highlights, err := highlight.processOneQuery(ctx, "test query", data, nil)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlights)
|
||||
s.Equal(expectedError, err)
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestProcess_Success() {
|
||||
queries := []string{"machine learning", "deep learning"}
|
||||
inputFields := []string{"title"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
expectedHighlights1 := [][]string{
|
||||
{"machine learning", "deep learning"},
|
||||
}
|
||||
expectedHighlights2 := [][]string{
|
||||
{"deep learning", "machine learning"},
|
||||
}
|
||||
|
||||
callCount := 0
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, query string, _ []string, _ map[string]string) ([][]string, error) {
|
||||
callCount++
|
||||
if query == "machine learning" {
|
||||
return expectedHighlights1, nil
|
||||
}
|
||||
return expectedHighlights2, nil
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
data := []string{"Machine learning document", "Deep learning document"}
|
||||
highlights, err := highlight.Process(ctx, []int64{1, 1}, data, nil)
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(highlights)
|
||||
s.Equal(2, callCount, "Should call highlight twice for two queries")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestProcess_NqMismatch() {
|
||||
queries := []string{"machine learning"}
|
||||
inputFields := []string{"title"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
data := []string{"test document"}
|
||||
highlights, err := highlight.Process(ctx, []int64{1, 1, 1}, data, nil) // nq=3 but queries has only 1
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlights)
|
||||
s.Contains(err.Error(), "nq must equal to queries size")
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestProcess_ProviderError() {
|
||||
queries := []string{"test query"}
|
||||
inputFields := []string{"title"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
expectedError := errors.New("provider error")
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
|
||||
return nil, expectedError
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
data := []string{"test document"}
|
||||
highlights, err := highlight.Process(ctx, []int64{1}, data, nil)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlights)
|
||||
s.Equal(expectedError, err)
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestProcess_EmptyData() {
|
||||
queries := []string{"test query", "test query 2", "test query 3"}
|
||||
inputFields := []string{"title"}
|
||||
|
||||
queriesJSON, _ := json.Marshal(queries)
|
||||
inputFieldsJSON, _ := json.Marshal(inputFields)
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, texts []string, _ map[string]string) ([][]string, error) {
|
||||
return [][]string{texts}, nil
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: queryKeyName, Value: string(queriesJSON)},
|
||||
{Key: inputFieldKeyName, Value: string(inputFieldsJSON)},
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}
|
||||
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
highlight, err := NewSemanticHighlight(s.schema, params, conf, extraInfo)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
data := []string{}
|
||||
highlights, err := highlight.Process(ctx, []int64{0, 0, 0}, data, nil)
|
||||
|
||||
s.NoError(err)
|
||||
s.NotNil(highlights)
|
||||
|
||||
data2 := []string{"test document"}
|
||||
|
||||
highlights2, err := highlight.Process(ctx, []int64{0, 1, 0}, data2, nil)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(1, len(highlights2))
|
||||
s.Equal([][]string{{"test document"}}, highlights2)
|
||||
}
|
||||
|
||||
func (s *SemanticHighlightSuite) TestBaseSemanticHighlightProvider_MaxBatch() {
|
||||
provider := &baseSemanticHighlightProvider{batchSize: 128}
|
||||
s.Equal(128, provider.maxBatch())
|
||||
|
||||
provider2 := &baseSemanticHighlightProvider{batchSize: 32}
|
||||
s.Equal(32, provider2.maxBatch())
|
||||
}
|
||||
@ -0,0 +1,78 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package highlight
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
|
||||
)
|
||||
|
||||
type zillizHighlightProvider struct {
|
||||
baseSemanticHighlightProvider
|
||||
client *zilliz.ZillizClient
|
||||
modelName string
|
||||
queries []string
|
||||
|
||||
modelParams map[string]string
|
||||
}
|
||||
|
||||
func newZillizHighlightProvider(params []*commonpb.KeyValuePair, conf map[string]string, extraInfo *models.ModelExtraInfo) (semanticHighlightProvider, error) {
|
||||
var modelDeploymentID string
|
||||
var err error
|
||||
maxBatch := 64
|
||||
modelParams := map[string]string{}
|
||||
for _, param := range params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case models.ModelDeploymentIDKey:
|
||||
modelDeploymentID = param.Value
|
||||
case models.MaxClientBatchSizeParamKey:
|
||||
if maxBatch, err = strconv.Atoi(param.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
modelParams[param.Key] = param.Value
|
||||
}
|
||||
}
|
||||
|
||||
c, err := zilliz.NewZilliClient(modelDeploymentID, extraInfo.ClusterID, extraInfo.DBName, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider := zillizHighlightProvider{
|
||||
baseSemanticHighlightProvider: baseSemanticHighlightProvider{batchSize: maxBatch},
|
||||
client: c,
|
||||
modelParams: modelParams,
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
func (h *zillizHighlightProvider) highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) {
|
||||
highlights, err := h.client.Highlight(ctx, query, texts, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return highlights, nil
|
||||
}
|
||||
@ -0,0 +1,240 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package highlight
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/zilliz"
|
||||
)
|
||||
|
||||
func TestZillizHighlightProvider(t *testing.T) {
|
||||
suite.Run(t, new(ZillizHighlightProviderSuite))
|
||||
}
|
||||
|
||||
type ZillizHighlightProviderSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *ZillizHighlightProviderSuite) TestNewZillizHighlightProvider_Success() {
|
||||
tests := []struct {
|
||||
name string
|
||||
params []*commonpb.KeyValuePair
|
||||
conf map[string]string
|
||||
extraInfo *models.ModelExtraInfo
|
||||
expectedBatch int
|
||||
expectedParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "basic configuration",
|
||||
params: []*commonpb.KeyValuePair{
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
},
|
||||
conf: map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
},
|
||||
extraInfo: &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
},
|
||||
expectedBatch: 64, // default batch size
|
||||
expectedParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "with custom batch size",
|
||||
params: []*commonpb.KeyValuePair{
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
{Key: models.MaxClientBatchSizeParamKey, Value: "32"},
|
||||
},
|
||||
conf: map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
},
|
||||
extraInfo: &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
},
|
||||
expectedBatch: 32,
|
||||
expectedParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "with additional model parameters",
|
||||
params: []*commonpb.KeyValuePair{
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
{Key: models.MaxClientBatchSizeParamKey, Value: "16"},
|
||||
{Key: "threshold", Value: "0.7"},
|
||||
},
|
||||
conf: map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
},
|
||||
extraInfo: &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
},
|
||||
expectedBatch: 16,
|
||||
expectedParams: map[string]string{
|
||||
"threshold": "0.7",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// Note: This test will fail with actual connection since we can't easily mock the global client manager
|
||||
// But we can test the parameter parsing logic by checking the error message
|
||||
provider, err := newZillizHighlightProvider(tt.params, tt.conf, tt.extraInfo)
|
||||
|
||||
// Since we can't easily mock the zilliz client creation, we expect a connection error
|
||||
// but we can verify that the parameters were parsed correctly by checking the error doesn't relate to parameter parsing
|
||||
if err != nil {
|
||||
// Connection errors are expected in unit tests
|
||||
s.Contains(err.Error(), "Connect model serving failed", "Expected connection error, got: %v", err)
|
||||
} else {
|
||||
// If somehow the connection succeeds, verify the provider was created correctly
|
||||
s.NotNil(provider)
|
||||
zillizProvider, ok := provider.(*zillizHighlightProvider)
|
||||
s.True(ok)
|
||||
s.Equal(tt.expectedBatch, zillizProvider.maxBatch())
|
||||
s.Equal(tt.expectedParams, zillizProvider.modelParams)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ZillizHighlightProviderSuite) TestNewZillizHighlightProvider_InvalidBatchSize() {
|
||||
tests := []struct {
|
||||
name string
|
||||
batchValue string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "invalid number format",
|
||||
batchValue: "not_a_number",
|
||||
expectedErr: "invalid syntax",
|
||||
},
|
||||
{
|
||||
name: "empty batch size",
|
||||
batchValue: "",
|
||||
expectedErr: "invalid syntax",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
{Key: models.MaxClientBatchSizeParamKey, Value: tt.batchValue},
|
||||
}
|
||||
conf := map[string]string{
|
||||
"endpoint": "localhost:8080",
|
||||
}
|
||||
extraInfo := &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
}
|
||||
|
||||
provider, err := newZillizHighlightProvider(params, conf, extraInfo)
|
||||
s.Error(err)
|
||||
s.Nil(provider)
|
||||
s.Contains(err.Error(), tt.expectedErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ZillizHighlightProviderSuite) TestZillizHighlightProvider_Highlight_Success() {
|
||||
// Set up expected behavior
|
||||
ctx := context.Background()
|
||||
query := "machine learning"
|
||||
texts := []string{
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Deep learning is a type of machine learning.",
|
||||
"Natural language processing uses machine learning techniques.",
|
||||
}
|
||||
params := map[string]string{"param1": "value1"}
|
||||
expectedHighlights := [][]string{
|
||||
{"Machine learning", "artificial intelligence"},
|
||||
{"Deep learning", "machine learning"},
|
||||
{"Natural language processing", "machine learning"},
|
||||
}
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
|
||||
return expectedHighlights, nil
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
provider, err := newZillizHighlightProvider([]*commonpb.KeyValuePair{
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}, map[string]string{"endpoint": "localhost:8080"}, &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NotNil(provider)
|
||||
|
||||
highlights, err := provider.highlight(ctx, query, texts, params)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(expectedHighlights, highlights)
|
||||
}
|
||||
|
||||
func (s *ZillizHighlightProviderSuite) TestZillizHighlightProvider_Highlight_Error() {
|
||||
// Set up expected behavior with error
|
||||
ctx := context.Background()
|
||||
query := "test query"
|
||||
texts := []string{"doc1", "doc2", "doc3"}
|
||||
params := map[string]string{"param1": "value1"}
|
||||
expectedError := errors.New("highlight service error")
|
||||
|
||||
mock1 := mockey.Mock(zilliz.NewZilliClient).To(func(_ string, _ string, _ string, _ map[string]string) (*zilliz.ZillizClient, error) {
|
||||
return &zilliz.ZillizClient{}, nil
|
||||
}).Build()
|
||||
defer mock1.UnPatch()
|
||||
|
||||
mock2 := mockey.Mock((*zilliz.ZillizClient).Highlight).To(func(_ *zilliz.ZillizClient, _ context.Context, _ string, _ []string, _ map[string]string) ([][]string, error) {
|
||||
return nil, expectedError
|
||||
}).Build()
|
||||
defer mock2.UnPatch()
|
||||
|
||||
provider, err := newZillizHighlightProvider([]*commonpb.KeyValuePair{
|
||||
{Key: models.ModelDeploymentIDKey, Value: "test-deployment"},
|
||||
}, map[string]string{"endpoint": "localhost:8080"}, &models.ModelExtraInfo{
|
||||
ClusterID: "test-cluster",
|
||||
DBName: "test-db",
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NotNil(provider)
|
||||
|
||||
// Test the highlight method
|
||||
highlights, err := provider.highlight(ctx, query, texts, params)
|
||||
|
||||
s.Error(err)
|
||||
s.Nil(highlights)
|
||||
s.Equal(expectedError, err)
|
||||
}
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/backoff"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
@ -33,6 +34,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/modelservicepb"
|
||||
)
|
||||
|
||||
@ -93,7 +95,10 @@ func (m *clientManager) GetConn(clientConf *clientConfig) (*grpc.ClientConn, err
|
||||
|
||||
if m.conn != nil {
|
||||
if clientConf.endpoint != m.config.endpoint {
|
||||
_ = m.conn.Close()
|
||||
err := m.conn.Close()
|
||||
if err != nil {
|
||||
log.Warn("Close connect failed", zap.String("endpoint", m.config.endpoint), zap.Error(err))
|
||||
}
|
||||
m.conn = nil
|
||||
} else {
|
||||
state := m.conn.GetState()
|
||||
@ -247,3 +252,22 @@ func (c *ZillizClient) Rerank(ctx context.Context, query string, texts []string,
|
||||
}
|
||||
return res.Scores, nil
|
||||
}
|
||||
|
||||
func (c *ZillizClient) Highlight(ctx context.Context, query string, texts []string, params map[string]string) ([][]string, error) {
|
||||
stub := modelservicepb.NewHighlightServiceClient(c.conn)
|
||||
req := &modelservicepb.HighlightRequest{
|
||||
Query: query,
|
||||
Documents: texts,
|
||||
Params: params,
|
||||
}
|
||||
ctx = c.setMeta(ctx)
|
||||
res, err := stub.Highlight(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
highlights := make([][]string, 0, len(res.GetResults()))
|
||||
for _, ret := range res.GetResults() {
|
||||
highlights = append(highlights, ret.GetSentences())
|
||||
}
|
||||
return highlights, nil
|
||||
}
|
||||
|
||||
@ -63,6 +63,19 @@ func (m *mockRerankServer) Rerank(ctx context.Context, req *modelservicepb.TextR
|
||||
return m.response, nil
|
||||
}
|
||||
|
||||
type mockHighlightServer struct {
|
||||
modelservicepb.UnimplementedHighlightServiceServer
|
||||
response *modelservicepb.HighlightResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockHighlightServer) Highlight(ctx context.Context, req *modelservicepb.HighlightRequest) (*modelservicepb.HighlightResponse, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.response, nil
|
||||
}
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -559,3 +572,103 @@ func TestZillizClient_Embedding_EmptyResponse(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, embeddings)
|
||||
}
|
||||
|
||||
func TestZillizClient_Highlight(t *testing.T) {
|
||||
// Setup mock server
|
||||
s, lis, dialer := setupMockServer(t)
|
||||
defer lis.Close()
|
||||
defer s.Stop()
|
||||
|
||||
mockServer := &mockHighlightServer{
|
||||
response: &modelservicepb.HighlightResponse{
|
||||
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
|
||||
Results: []*modelservicepb.HighlightResult{
|
||||
{
|
||||
Sentences: []string{"highlight1", "highlight2"},
|
||||
},
|
||||
{
|
||||
Sentences: []string{"highlight3", "highlight4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelservicepb.RegisterHighlightServiceServer(s, mockServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
fmt.Printf("Server exited with error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create connection
|
||||
conn, err := grpc.DialContext(
|
||||
context.Background(),
|
||||
"bufnet",
|
||||
grpc.WithContextDialer(dialer),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := &ZillizClient{
|
||||
modelDeploymentID: "test-deployment",
|
||||
clusterID: "test-cluster",
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
// Test successful highlight
|
||||
ctx := context.Background()
|
||||
query := "test query"
|
||||
texts := []string{"doc1", "doc2", "doc3"}
|
||||
params := map[string]string{"param1": "value1"}
|
||||
highlights, err := client.Highlight(ctx, query, texts, params)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, [][]string{{"highlight1", "highlight2"}, {"highlight3", "highlight4"}}, highlights)
|
||||
}
|
||||
|
||||
func TestZillizClient_Highlight_Error(t *testing.T) {
|
||||
// Setup mock server with error
|
||||
s, lis, dialer := setupMockServer(t)
|
||||
defer lis.Close()
|
||||
defer s.Stop()
|
||||
|
||||
mockServer := &mockHighlightServer{
|
||||
err: assert.AnError,
|
||||
}
|
||||
|
||||
modelservicepb.RegisterHighlightServiceServer(s, mockServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
fmt.Printf("Server exited with error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create connection
|
||||
conn, err := grpc.DialContext(
|
||||
context.Background(),
|
||||
"bufnet",
|
||||
grpc.WithContextDialer(dialer),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := &ZillizClient{
|
||||
modelDeploymentID: "test-deployment",
|
||||
clusterID: "test-cluster",
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
// Test highlight with error
|
||||
ctx := context.Background()
|
||||
query := "test query"
|
||||
texts := []string{"doc1", "doc2", "doc3"}
|
||||
params := map[string]string{"param1": "value1"}
|
||||
highlights, err := client.Highlight(ctx, query, texts, params)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, highlights)
|
||||
}
|
||||
|
||||
@ -11,6 +11,10 @@ service RerankService {
|
||||
rpc Rerank(TextRerankRequest) returns (TextRerankResponse);
|
||||
}
|
||||
|
||||
service HighlightService {
|
||||
rpc Highlight(HighlightRequest) returns (HighlightResponse);
|
||||
}
|
||||
|
||||
message Status {
|
||||
int32 code = 1;
|
||||
string msg = 2;
|
||||
@ -68,3 +72,20 @@ message TextRerankResponse {
|
||||
map<string, string> extra_info = 2;
|
||||
repeated float scores = 3;
|
||||
}
|
||||
|
||||
message HighlightRequest {
|
||||
string model = 1;
|
||||
string query = 2;
|
||||
repeated string documents = 3;
|
||||
map<string, string> params = 4;
|
||||
}
|
||||
|
||||
message HighlightResult {
|
||||
repeated string sentences = 1;
|
||||
}
|
||||
|
||||
message HighlightResponse {
|
||||
Status status = 1;
|
||||
map<string, string> extra_info = 2;
|
||||
repeated HighlightResult results = 3;
|
||||
}
|
||||
|
||||
@ -620,6 +620,187 @@ func (x *TextRerankResponse) GetScores() []float32 {
|
||||
return nil
|
||||
}
|
||||
|
||||
type HighlightRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Model string `protobuf:"bytes,1,opt,name=model,proto3" json:"model,omitempty"`
|
||||
Query string `protobuf:"bytes,2,opt,name=query,proto3" json:"query,omitempty"`
|
||||
Documents []string `protobuf:"bytes,3,rep,name=documents,proto3" json:"documents,omitempty"`
|
||||
Params map[string]string `protobuf:"bytes,4,rep,name=params,proto3" json:"params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
}
|
||||
|
||||
func (x *HighlightRequest) Reset() {
|
||||
*x = HighlightRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_model_service_proto_msgTypes[9]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *HighlightRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*HighlightRequest) ProtoMessage() {}
|
||||
|
||||
func (x *HighlightRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_model_service_proto_msgTypes[9]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use HighlightRequest.ProtoReflect.Descriptor instead.
|
||||
func (*HighlightRequest) Descriptor() ([]byte, []int) {
|
||||
return file_model_service_proto_rawDescGZIP(), []int{9}
|
||||
}
|
||||
|
||||
func (x *HighlightRequest) GetModel() string {
|
||||
if x != nil {
|
||||
return x.Model
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *HighlightRequest) GetQuery() string {
|
||||
if x != nil {
|
||||
return x.Query
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *HighlightRequest) GetDocuments() []string {
|
||||
if x != nil {
|
||||
return x.Documents
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *HighlightRequest) GetParams() map[string]string {
|
||||
if x != nil {
|
||||
return x.Params
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type HighlightResult struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Sentences []string `protobuf:"bytes,1,rep,name=sentences,proto3" json:"sentences,omitempty"`
|
||||
}
|
||||
|
||||
func (x *HighlightResult) Reset() {
|
||||
*x = HighlightResult{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_model_service_proto_msgTypes[10]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *HighlightResult) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*HighlightResult) ProtoMessage() {}
|
||||
|
||||
func (x *HighlightResult) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_model_service_proto_msgTypes[10]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use HighlightResult.ProtoReflect.Descriptor instead.
|
||||
func (*HighlightResult) Descriptor() ([]byte, []int) {
|
||||
return file_model_service_proto_rawDescGZIP(), []int{10}
|
||||
}
|
||||
|
||||
func (x *HighlightResult) GetSentences() []string {
|
||||
if x != nil {
|
||||
return x.Sentences
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type HighlightResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Status *Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"`
|
||||
ExtraInfo map[string]string `protobuf:"bytes,2,rep,name=extra_info,json=extraInfo,proto3" json:"extra_info,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
Results []*HighlightResult `protobuf:"bytes,3,rep,name=results,proto3" json:"results,omitempty"`
|
||||
}
|
||||
|
||||
func (x *HighlightResponse) Reset() {
|
||||
*x = HighlightResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_model_service_proto_msgTypes[11]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *HighlightResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*HighlightResponse) ProtoMessage() {}
|
||||
|
||||
func (x *HighlightResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_model_service_proto_msgTypes[11]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use HighlightResponse.ProtoReflect.Descriptor instead.
|
||||
func (*HighlightResponse) Descriptor() ([]byte, []int) {
|
||||
return file_model_service_proto_rawDescGZIP(), []int{11}
|
||||
}
|
||||
|
||||
func (x *HighlightResponse) GetStatus() *Status {
|
||||
if x != nil {
|
||||
return x.Status
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *HighlightResponse) GetExtraInfo() map[string]string {
|
||||
if x != nil {
|
||||
return x.ExtraInfo
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *HighlightResponse) GetResults() []*HighlightResult {
|
||||
if x != nil {
|
||||
return x.Results
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_model_service_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_model_service_proto_rawDesc = []byte{
|
||||
@ -729,27 +910,72 @@ var file_model_service_proto_rawDesc = []byte{
|
||||
0x78, 0x74, 0x72, 0x61, 0x49, 0x6e, 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a,
|
||||
0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12,
|
||||
0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
||||
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0x86, 0x01, 0x0a, 0x14, 0x54, 0x65,
|
||||
0x78, 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69,
|
||||
0x63, 0x65, 0x12, 0x6e, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12,
|
||||
0x2f, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d,
|
||||
0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74,
|
||||
0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x30, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
|
||||
0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78,
|
||||
0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x32, 0x76, 0x0a, 0x0d, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76,
|
||||
0x69, 0x63, 0x65, 0x12, 0x65, 0x0a, 0x06, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x12, 0x2c, 0x2e,
|
||||
0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64,
|
||||
0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65,
|
||||
0x72, 0x61, 0x6e, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x6d, 0x69,
|
||||
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xe8, 0x01, 0x0a, 0x10, 0x48, 0x69,
|
||||
0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14,
|
||||
0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d,
|
||||
0x6f, 0x64, 0x65, 0x6c, 0x12, 0x14, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x6f,
|
||||
0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x64,
|
||||
0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x4f, 0x0a, 0x06, 0x70, 0x61, 0x72, 0x61,
|
||||
0x6d, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
|
||||
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72,
|
||||
0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72,
|
||||
0x79, 0x52, 0x06, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a, 0x39, 0x0a, 0x0b, 0x50, 0x61, 0x72,
|
||||
0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61,
|
||||
0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
|
||||
0x3a, 0x02, 0x38, 0x01, 0x22, 0x2f, 0x0a, 0x0f, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68,
|
||||
0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x65, 0x6e, 0x74, 0x65,
|
||||
0x6e, 0x63, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x6e, 0x74,
|
||||
0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0xae, 0x02, 0x0a, 0x11, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69,
|
||||
0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x06, 0x73,
|
||||
0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x6d, 0x69,
|
||||
0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c,
|
||||
0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61,
|
||||
0x6e, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x39, 0x5a, 0x37, 0x67, 0x69,
|
||||
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d,
|
||||
0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32,
|
||||
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76,
|
||||
0x69, 0x63, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06,
|
||||
0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x5a, 0x0a, 0x0a, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f,
|
||||
0x69, 0x6e, 0x66, 0x6f, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x3b, 0x2e, 0x6d, 0x69, 0x6c,
|
||||
0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x45, 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e,
|
||||
0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x6e,
|
||||
0x66, 0x6f, 0x12, 0x44, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x03, 0x20,
|
||||
0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e,
|
||||
0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52,
|
||||
0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x1a, 0x3c, 0x0a, 0x0e, 0x45, 0x78, 0x74, 0x72,
|
||||
0x61, 0x49, 0x6e, 0x66, 0x6f, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65,
|
||||
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05,
|
||||
0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c,
|
||||
0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0x86, 0x01, 0x0a, 0x14, 0x54, 0x65, 0x78, 0x74, 0x45,
|
||||
0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
|
||||
0x6e, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x2f, 0x2e, 0x6d,
|
||||
0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65,
|
||||
0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x45, 0x6d, 0x62,
|
||||
0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30, 0x2e,
|
||||
0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64,
|
||||
0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x45, 0x6d,
|
||||
0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32,
|
||||
0x76, 0x0a, 0x0d, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
|
||||
0x12, 0x65, 0x0a, 0x06, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x12, 0x2c, 0x2e, 0x6d, 0x69, 0x6c,
|
||||
0x76, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, 0x6e,
|
||||
0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2d, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
|
||||
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72,
|
||||
0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x72, 0x61, 0x6e, 0x6b, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x7a, 0x0a, 0x10, 0x48, 0x69, 0x67, 0x68, 0x6c,
|
||||
0x69, 0x67, 0x68, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x66, 0x0a, 0x09, 0x48,
|
||||
0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x12, 0x2b, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75,
|
||||
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72,
|
||||
0x76, 0x69, 0x63, 0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2c, 0x2e, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2e, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63,
|
||||
0x65, 0x2e, 0x48, 0x69, 0x67, 0x68, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x42, 0x39, 0x5a, 0x37, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
|
||||
0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76,
|
||||
0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f,
|
||||
0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x70, 0x62, 0x62, 0x06,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@ -765,7 +991,7 @@ func file_model_service_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_model_service_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_model_service_proto_msgTypes = make([]protoimpl.MessageInfo, 13)
|
||||
var file_model_service_proto_msgTypes = make([]protoimpl.MessageInfo, 18)
|
||||
var file_model_service_proto_goTypes = []interface{}{
|
||||
(DenseVector_DType)(0), // 0: milvus.proto.modelservice.DenseVector.DType
|
||||
(*Status)(nil), // 1: milvus.proto.modelservice.Status
|
||||
@ -777,33 +1003,44 @@ var file_model_service_proto_goTypes = []interface{}{
|
||||
(*TextEmbeddingResponse)(nil), // 7: milvus.proto.modelservice.TextEmbeddingResponse
|
||||
(*TextRerankRequest)(nil), // 8: milvus.proto.modelservice.TextRerankRequest
|
||||
(*TextRerankResponse)(nil), // 9: milvus.proto.modelservice.TextRerankResponse
|
||||
nil, // 10: milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
|
||||
nil, // 11: milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
|
||||
nil, // 12: milvus.proto.modelservice.TextRerankRequest.ParamsEntry
|
||||
nil, // 13: milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
|
||||
(*HighlightRequest)(nil), // 10: milvus.proto.modelservice.HighlightRequest
|
||||
(*HighlightResult)(nil), // 11: milvus.proto.modelservice.HighlightResult
|
||||
(*HighlightResponse)(nil), // 12: milvus.proto.modelservice.HighlightResponse
|
||||
nil, // 13: milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
|
||||
nil, // 14: milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
|
||||
nil, // 15: milvus.proto.modelservice.TextRerankRequest.ParamsEntry
|
||||
nil, // 16: milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
|
||||
nil, // 17: milvus.proto.modelservice.HighlightRequest.ParamsEntry
|
||||
nil, // 18: milvus.proto.modelservice.HighlightResponse.ExtraInfoEntry
|
||||
}
|
||||
var file_model_service_proto_depIdxs = []int32{
|
||||
0, // 0: milvus.proto.modelservice.DenseVector.dtype:type_name -> milvus.proto.modelservice.DenseVector.DType
|
||||
2, // 1: milvus.proto.modelservice.MultiVector.token_vectors:type_name -> milvus.proto.modelservice.DenseVector
|
||||
10, // 2: milvus.proto.modelservice.TextEmbeddingRequest.params:type_name -> milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
|
||||
13, // 2: milvus.proto.modelservice.TextEmbeddingRequest.params:type_name -> milvus.proto.modelservice.TextEmbeddingRequest.ParamsEntry
|
||||
2, // 3: milvus.proto.modelservice.EmbeddingResult.dense:type_name -> milvus.proto.modelservice.DenseVector
|
||||
3, // 4: milvus.proto.modelservice.EmbeddingResult.sparse:type_name -> milvus.proto.modelservice.SparseVector
|
||||
4, // 5: milvus.proto.modelservice.EmbeddingResult.multi_vector:type_name -> milvus.proto.modelservice.MultiVector
|
||||
1, // 6: milvus.proto.modelservice.TextEmbeddingResponse.status:type_name -> milvus.proto.modelservice.Status
|
||||
11, // 7: milvus.proto.modelservice.TextEmbeddingResponse.extra_info:type_name -> milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
|
||||
14, // 7: milvus.proto.modelservice.TextEmbeddingResponse.extra_info:type_name -> milvus.proto.modelservice.TextEmbeddingResponse.ExtraInfoEntry
|
||||
6, // 8: milvus.proto.modelservice.TextEmbeddingResponse.results:type_name -> milvus.proto.modelservice.EmbeddingResult
|
||||
12, // 9: milvus.proto.modelservice.TextRerankRequest.params:type_name -> milvus.proto.modelservice.TextRerankRequest.ParamsEntry
|
||||
15, // 9: milvus.proto.modelservice.TextRerankRequest.params:type_name -> milvus.proto.modelservice.TextRerankRequest.ParamsEntry
|
||||
1, // 10: milvus.proto.modelservice.TextRerankResponse.status:type_name -> milvus.proto.modelservice.Status
|
||||
13, // 11: milvus.proto.modelservice.TextRerankResponse.extra_info:type_name -> milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
|
||||
5, // 12: milvus.proto.modelservice.TextEmbeddingService.Embedding:input_type -> milvus.proto.modelservice.TextEmbeddingRequest
|
||||
8, // 13: milvus.proto.modelservice.RerankService.Rerank:input_type -> milvus.proto.modelservice.TextRerankRequest
|
||||
7, // 14: milvus.proto.modelservice.TextEmbeddingService.Embedding:output_type -> milvus.proto.modelservice.TextEmbeddingResponse
|
||||
9, // 15: milvus.proto.modelservice.RerankService.Rerank:output_type -> milvus.proto.modelservice.TextRerankResponse
|
||||
14, // [14:16] is the sub-list for method output_type
|
||||
12, // [12:14] is the sub-list for method input_type
|
||||
12, // [12:12] is the sub-list for extension type_name
|
||||
12, // [12:12] is the sub-list for extension extendee
|
||||
0, // [0:12] is the sub-list for field type_name
|
||||
16, // 11: milvus.proto.modelservice.TextRerankResponse.extra_info:type_name -> milvus.proto.modelservice.TextRerankResponse.ExtraInfoEntry
|
||||
17, // 12: milvus.proto.modelservice.HighlightRequest.params:type_name -> milvus.proto.modelservice.HighlightRequest.ParamsEntry
|
||||
1, // 13: milvus.proto.modelservice.HighlightResponse.status:type_name -> milvus.proto.modelservice.Status
|
||||
18, // 14: milvus.proto.modelservice.HighlightResponse.extra_info:type_name -> milvus.proto.modelservice.HighlightResponse.ExtraInfoEntry
|
||||
11, // 15: milvus.proto.modelservice.HighlightResponse.results:type_name -> milvus.proto.modelservice.HighlightResult
|
||||
5, // 16: milvus.proto.modelservice.TextEmbeddingService.Embedding:input_type -> milvus.proto.modelservice.TextEmbeddingRequest
|
||||
8, // 17: milvus.proto.modelservice.RerankService.Rerank:input_type -> milvus.proto.modelservice.TextRerankRequest
|
||||
10, // 18: milvus.proto.modelservice.HighlightService.Highlight:input_type -> milvus.proto.modelservice.HighlightRequest
|
||||
7, // 19: milvus.proto.modelservice.TextEmbeddingService.Embedding:output_type -> milvus.proto.modelservice.TextEmbeddingResponse
|
||||
9, // 20: milvus.proto.modelservice.RerankService.Rerank:output_type -> milvus.proto.modelservice.TextRerankResponse
|
||||
12, // 21: milvus.proto.modelservice.HighlightService.Highlight:output_type -> milvus.proto.modelservice.HighlightResponse
|
||||
19, // [19:22] is the sub-list for method output_type
|
||||
16, // [16:19] is the sub-list for method input_type
|
||||
16, // [16:16] is the sub-list for extension type_name
|
||||
16, // [16:16] is the sub-list for extension extendee
|
||||
0, // [0:16] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_model_service_proto_init() }
|
||||
@ -920,6 +1157,42 @@ func file_model_service_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_model_service_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*HighlightRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_model_service_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*HighlightResult); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_model_service_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*HighlightResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
@ -927,9 +1200,9 @@ func file_model_service_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_model_service_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 13,
|
||||
NumMessages: 18,
|
||||
NumExtensions: 0,
|
||||
NumServices: 2,
|
||||
NumServices: 3,
|
||||
},
|
||||
GoTypes: file_model_service_proto_goTypes,
|
||||
DependencyIndexes: file_model_service_proto_depIdxs,
|
||||
|
||||
@ -193,3 +193,91 @@ var RerankService_ServiceDesc = grpc.ServiceDesc{
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "model_service.proto",
|
||||
}
|
||||
|
||||
const (
|
||||
HighlightService_Highlight_FullMethodName = "/milvus.proto.modelservice.HighlightService/Highlight"
|
||||
)
|
||||
|
||||
// HighlightServiceClient is the client API for HighlightService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type HighlightServiceClient interface {
|
||||
Highlight(ctx context.Context, in *HighlightRequest, opts ...grpc.CallOption) (*HighlightResponse, error)
|
||||
}
|
||||
|
||||
type highlightServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewHighlightServiceClient(cc grpc.ClientConnInterface) HighlightServiceClient {
|
||||
return &highlightServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *highlightServiceClient) Highlight(ctx context.Context, in *HighlightRequest, opts ...grpc.CallOption) (*HighlightResponse, error) {
|
||||
out := new(HighlightResponse)
|
||||
err := c.cc.Invoke(ctx, HighlightService_Highlight_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// HighlightServiceServer is the server API for HighlightService service.
|
||||
// All implementations should embed UnimplementedHighlightServiceServer
|
||||
// for forward compatibility
|
||||
type HighlightServiceServer interface {
|
||||
Highlight(context.Context, *HighlightRequest) (*HighlightResponse, error)
|
||||
}
|
||||
|
||||
// UnimplementedHighlightServiceServer should be embedded to have forward compatible implementations.
|
||||
type UnimplementedHighlightServiceServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedHighlightServiceServer) Highlight(context.Context, *HighlightRequest) (*HighlightResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Highlight not implemented")
|
||||
}
|
||||
|
||||
// UnsafeHighlightServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to HighlightServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeHighlightServiceServer interface {
|
||||
mustEmbedUnimplementedHighlightServiceServer()
|
||||
}
|
||||
|
||||
func RegisterHighlightServiceServer(s grpc.ServiceRegistrar, srv HighlightServiceServer) {
|
||||
s.RegisterService(&HighlightService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _HighlightService_Highlight_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(HighlightRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(HighlightServiceServer).Highlight(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: HighlightService_Highlight_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(HighlightServiceServer).Highlight(ctx, req.(*HighlightRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// HighlightService_ServiceDesc is the grpc.ServiceDesc for HighlightService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var HighlightService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "milvus.proto.modelservice.HighlightService",
|
||||
HandlerType: (*HighlightServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Highlight",
|
||||
Handler: _HighlightService_Highlight_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "model_service.proto",
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user