From a98c79b6a61390c28ada1aa998dacfddb8bd5663 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Thu, 18 May 2023 11:07:26 +0800 Subject: [PATCH] Fix flat index can be created with invalid metric type (#24180) Signed-off-by: longjiquan --- pkg/util/indexparamcheck/conf_adapter_mgr.go | 2 +- .../indexparamcheck/conf_adapter_mgr_test.go | 4 +- pkg/util/indexparamcheck/flat_checker.go | 9 +++ pkg/util/indexparamcheck/flat_checker_test.go | 68 +++++++++++++++++++ 4 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 pkg/util/indexparamcheck/flat_checker.go create mode 100644 pkg/util/indexparamcheck/flat_checker_test.go diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/pkg/util/indexparamcheck/conf_adapter_mgr.go index a3dcb719e1..5b9d5e491b 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr.go @@ -45,7 +45,7 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro func (mgr *indexCheckerMgrImpl) registerIndexChecker() { mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker() mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker() - mgr.checkers[IndexFaissIDMap] = newBaseChecker() + mgr.checkers[IndexFaissIDMap] = newFlatChecker() mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker() diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go b/pkg/util/indexparamcheck/conf_adapter_mgr_test.go index a6dd583e36..3d801c9fd3 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr_test.go @@ -32,7 +32,7 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) { adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*baseChecker) + _, ok = adapter.(*flatChecker) assert.Equal(t, true, ok) adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) @@ -86,7 +86,7 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*baseChecker) + _, ok = adapter.(*flatChecker) assert.Equal(t, true, ok) adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) diff --git a/pkg/util/indexparamcheck/flat_checker.go b/pkg/util/indexparamcheck/flat_checker.go new file mode 100644 index 0000000000..eea107df02 --- /dev/null +++ b/pkg/util/indexparamcheck/flat_checker.go @@ -0,0 +1,9 @@ +package indexparamcheck + +type flatChecker struct { + floatVectorBaseChecker +} + +func newFlatChecker() IndexChecker { + return &flatChecker{} +} diff --git a/pkg/util/indexparamcheck/flat_checker_test.go b/pkg/util/indexparamcheck/flat_checker_test.go new file mode 100644 index 0000000000..77fd8ca825 --- /dev/null +++ b/pkg/util/indexparamcheck/flat_checker_test.go @@ -0,0 +1,68 @@ +package indexparamcheck + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_flatChecker_CheckTrain(t *testing.T) { + + p1 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: L2, + } + p2 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: IP, + } + p3 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: COSINE, + } + + p4 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: HAMMING, + } + p5 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: JACCARD, + } + p6 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: TANIMOTO, + } + p7 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: SUBSTRUCTURE, + } + p8 := map[string]string{ + DIM: strconv.Itoa(128), + Metric: SUPERSTRUCTURE, + } + cases := []struct { + params map[string]string + errIsNil bool + }{ + {p1, true}, + {p2, true}, + {p3, true}, + {p4, false}, + {p5, false}, + {p6, false}, + {p7, false}, + {p8, false}, + } + + c := newFlatChecker() + for _, test := range cases { + err := c.CheckTrain(test.params) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +}