diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 2fb203d0e9..a27f184b6c 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -648,10 +648,13 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s resultDataIdx = sIdx maxScore = sScore } else if sScore == maxScore { - sID := typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx) - tmpID := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx) - - if typeutil.ComparePK(sID, tmpID) { + if subSearchIdx == -1 { + // A bad case happens where Knowhere returns distance/score == +/-maxFloat32 + // by mistake. + log.Error("a bad score is returned, something is wrong here!", zap.Float32("score", sScore)) + } else if typeutil.ComparePK( + typeutil.GetPK(subSearchResultData[i].GetIds(), sIdx), + typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)) { subSearchIdx = i resultDataIdx = sIdx maxScore = sScore diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 089461730e..9b3af1d251 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "strconv" "strings" "testing" @@ -1270,6 +1271,68 @@ func TestTaskSearch_selectHighestScoreIndex(t *testing.T) { } }) + t.Run("Integer ID with bad score", func(t *testing.T) { + type args struct { + subSearchResultData []*schemapb.SearchResultData + subSearchNqOffset [][]int64 + cursors []int64 + topk int64 + nq int64 + } + tests := []struct { + description string + args args + + expectedIdx []int + expectedDataIdx []int + }{ + { + description: "reduce 2 subSearchResultData", + args: args{ + subSearchResultData: []*schemapb.SearchResultData{ + { + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 8, 5, 3, 1}, + }, + }, + }, + Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32}, + Topks: []int64{2, 2, 2}, + }, + { + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{12, 10, 7, 6, 4, 2}, + }, + }, + }, + Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32}, + Topks: []int64{2, 2, 2}, + }, + }, + subSearchNqOffset: [][]int64{{0, 2, 4}, {0, 2, 4}}, + cursors: []int64{0, 0}, + topk: 2, + nq: 3, + }, + expectedIdx: []int{-1, -1, -1}, + expectedDataIdx: []int{-1, -1, -1}, + }, + } + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + for nqNum := int64(0); nqNum < test.args.nq; nqNum++ { + idx, dataIdx := selectHighestScoreIndex(test.args.subSearchResultData, test.args.subSearchNqOffset, test.args.cursors, nqNum) + assert.Equal(t, test.expectedIdx[nqNum], idx) + assert.Equal(t, test.expectedDataIdx[nqNum], int(dataIdx)) + } + }) + } + }) + t.Run("String ID", func(t *testing.T) { type args struct { subSearchResultData []*schemapb.SearchResultData diff --git a/internal/querynode/result.go b/internal/querynode/result.go index 41b6f2f7ce..59c9fed5fa 100644 --- a/internal/querynode/result.go +++ b/internal/querynode/result.go @@ -198,10 +198,13 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffset maxDistance = distance resultDataIdx = idx } else if distance == maxDistance { - sID := typeutil.GetPK(dataArray[i].GetIds(), idx) - tmpID := typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx) - - if typeutil.ComparePK(sID, tmpID) { + if sel == -1 { + // A bad case happens where knowhere returns distance == +/-maxFloat32 + // by mistake. + log.Error("a bad distance is found, something is wrong here!", zap.Float32("score", distance)) + } else if typeutil.ComparePK( + typeutil.GetPK(dataArray[i].GetIds(), idx), + typeutil.GetPK(dataArray[sel].GetIds(), resultDataIdx)) { sel = i maxDistance = distance resultDataIdx = idx diff --git a/internal/querynode/result_test.go b/internal/querynode/result_test.go index 4f726813b0..2f86b430b8 100644 --- a/internal/querynode/result_test.go +++ b/internal/querynode/result_test.go @@ -18,6 +18,7 @@ package querynode import ( "context" + "math" "testing" "github.com/stretchr/testify/assert" @@ -424,51 +425,104 @@ func TestResult_selectSearchResultData_int(t *testing.T) { nq int64 qi int64 } - tests := []struct { - name string - args args - want int - }{ - { - args: args{ - dataArray: []*schemapb.SearchResultData{ - { - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{11, 9, 7, 5, 3, 1}, + t.Run("Integer ID", func(t *testing.T) { + tests := []struct { + name string + args args + want int + }{ + { + args: args{ + dataArray: []*schemapb.SearchResultData{ + { + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 7, 5, 3, 1}, + }, }, }, + Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1}, + Topks: []int64{2, 2, 2}, }, - Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1}, - Topks: []int64{2, 2, 2}, - }, - { - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{12, 10, 8, 6, 4, 2}, + { + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{12, 10, 8, 6, 4, 2}, + }, }, }, + Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2}, + Topks: []int64{2, 2, 2}, }, - Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2}, - Topks: []int64{2, 2, 2}, }, + resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}}, + offsets: []int64{0, 1}, + topk: 2, + nq: 3, + qi: 0, }, - resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}}, - offsets: []int64{0, 1}, - topk: 2, - nq: 3, - qi: 0, + want: 0, }, - want: 0, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want { - t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want) - } - }) - } + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want { + t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want) + } + }) + } + }) + + t.Run("Integer ID with bad score", func(t *testing.T) { + tests := []struct { + name string + args args + want int + }{ + { + args: args{ + dataArray: []*schemapb.SearchResultData{ + { + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 7, 5, 3, 1}, + }, + }, + }, + Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32}, + Topks: []int64{2, 2, 2}, + }, + { + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{12, 10, 8, 6, 4, 2}, + }, + }, + }, + Scores: []float32{-math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32, -math.MaxFloat32}, + Topks: []int64{2, 2, 2}, + }, + }, + resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}}, + offsets: []int64{0, 1}, + topk: 2, + nq: 3, + qi: 0, + }, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want { + t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want) + } + }) + } + }) + }