From 8d0cc4226c3cb69b71c61857c2fc00aeb9d25e64 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Mon, 31 Oct 2022 10:13:34 +0800 Subject: [PATCH] Fix IVF_SQ nbits check (#20183) Signed-off-by: longjiquan Signed-off-by: longjiquan --- internal/util/indexparamcheck/conf_adapter.go | 14 +++++++++++- .../util/indexparamcheck/conf_adapter_test.go | 22 ++++++++++++++----- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/internal/util/indexparamcheck/conf_adapter.go b/internal/util/indexparamcheck/conf_adapter.go index 2dcbd5765f..b1bbb3a793 100644 --- a/internal/util/indexparamcheck/conf_adapter.go +++ b/internal/util/indexparamcheck/conf_adapter.go @@ -254,9 +254,21 @@ type IVFSQConfAdapter struct { IVFConfAdapter } +func (adapter *IVFSQConfAdapter) checkNBits(params map[string]string) bool { + // cgo will set this key to DefaultNBits (8), which is the only value Milvus supports. + _, exist := params[NBITS] + if exist { + // 8 is the only supported nbits. + return CheckIntByRange(params, NBITS, DefaultNBits, DefaultNBits) + } + return true +} + // CheckTrain returns true if the index can be built with the specific index parameters. func (adapter *IVFSQConfAdapter) CheckTrain(params map[string]string) bool { - params[NBITS] = strconv.Itoa(DefaultNBits) + if !adapter.checkNBits(params) { + return false + } return adapter.IVFConfAdapter.CheckTrain(params) } diff --git a/internal/util/indexparamcheck/conf_adapter_test.go b/internal/util/indexparamcheck/conf_adapter_test.go index d5649af576..73aedd90c0 100644 --- a/internal/util/indexparamcheck/conf_adapter_test.go +++ b/internal/util/indexparamcheck/conf_adapter_test.go @@ -168,17 +168,29 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) { } func TestIVFSQConfAdapter_CheckTrain(t *testing.T) { - validParams := map[string]string{ - DIM: strconv.Itoa(128), - NLIST: strconv.Itoa(100), - NBITS: strconv.Itoa(8), - Metric: L2, + getValidParams := func(withNBits bool) map[string]string { + validParams := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(100), + NBITS: strconv.Itoa(8), + Metric: L2, + } + if withNBits { + validParams[NBITS] = strconv.Itoa(DefaultNBits) + } + return validParams } + validParams := getValidParams(false) + validParamsWithNBits := getValidParams(true) + paramsWithInvalidNBits := getValidParams(false) + paramsWithInvalidNBits[NBITS] = strconv.Itoa(DefaultNBits + 1) cases := []struct { params map[string]string want bool }{ {validParams, true}, + {validParamsWithNBits, true}, + {paramsWithInvalidNBits, false}, {invalidIVFParamsMin(), false}, {invalidIVFParamsMax(), false}, }