diff --git a/client/index/disk_ann.go b/client/index/disk_ann.go index 4a029b7da8..a3648d5a67 100644 --- a/client/index/disk_ann.go +++ b/client/index/disk_ann.go @@ -33,6 +33,7 @@ func NewDiskANNIndex(metricType MetricType) Index { return &diskANNIndex{ baseIndex: baseIndex{ metricType: metricType, + indexType: DISKANN, }, } } diff --git a/client/index/flat.go b/client/index/flat.go index cc336c23d5..3ebfff4af8 100644 --- a/client/index/flat.go +++ b/client/index/flat.go @@ -33,6 +33,7 @@ func NewFlatIndex(metricType MetricType) Index { return flatIndex{ baseIndex: baseIndex{ metricType: metricType, + indexType: Flat, }, } } @@ -54,6 +55,7 @@ func NewBinFlatIndex(metricType MetricType) Index { return binFlatIndex{ baseIndex: baseIndex{ metricType: metricType, + indexType: BinFlat, }, } } diff --git a/cmd/components/proxy.go b/cmd/components/proxy.go index 5fcda443f8..96ee18ed6a 100644 --- a/cmd/components/proxy.go +++ b/cmd/components/proxy.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -50,6 +51,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) { } func (n *Proxy) Prepare() error { + indexparamcheck.ValidateParamTable() return n.svr.Prepare() } diff --git a/internal/core/src/segcore/vector_index_c.cpp b/internal/core/src/segcore/vector_index_c.cpp index b59f91802d..59f2a7130d 100644 --- a/internal/core/src/segcore/vector_index_c.cpp +++ b/internal/core/src/segcore/vector_index_c.cpp @@ -20,6 +20,79 @@ #include "index/IndexFactory.h" #include "pb/index_cgo_msg.pb.h" +CStatus +ValidateIndexParams(const char* index_type, + enum CDataType data_type, + const uint8_t* serialized_index_params, + const uint64_t length) { + try { + auto index_params = + std::make_unique(); + auto res = + index_params->ParseFromArray(serialized_index_params, length); + AssertInfo(res, "Unmarshall index params failed"); + + knowhere::Json json; + + for (size_t i = 0; i < index_params->params_size(); i++) { + auto& param = index_params->params(i); + json[param.key()] = param.value(); + } + + milvus::DataType dataType(static_cast(data_type)); + + knowhere::Status status; + std::string error_msg; + if (dataType == milvus::DataType::VECTOR_BINARY) { + status = knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT) { + status = knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) { + status = knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT16) { + status = knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) { + status = knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else { + status = knowhere::Status::invalid_args; + } + CStatus cStatus; + if (status == knowhere::Status::success) { + cStatus.error_code = milvus::Success; + cStatus.error_msg = ""; + } else { + cStatus.error_code = milvus::ConfigInvalid; + cStatus.error_msg = strdup(error_msg.c_str()); + } + return cStatus; + } catch (std::exception& e) { + auto cStatus = CStatus(); + cStatus.error_code = milvus::UnexpectedError; + cStatus.error_msg = strdup(e.what()); + return cStatus; + } +} + int GetIndexListSize() { return knowhere::IndexFactory::Instance().GetIndexFeatures().size(); diff --git a/internal/core/src/segcore/vector_index_c.h b/internal/core/src/segcore/vector_index_c.h index 535faddaeb..7e9b8f5239 100644 --- a/internal/core/src/segcore/vector_index_c.h +++ b/internal/core/src/segcore/vector_index_c.h @@ -17,6 +17,12 @@ extern "C" { #include #include "common/type_c.h" +CStatus +ValidateIndexParams(const char* index_type, + enum CDataType data_type, + const uint8_t* index_params, + const uint64_t length); + int GetIndexListSize(); diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index 973a8d7aed..6cf65b6fda 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -34,11 +34,11 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/workerpb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index 2b87fd8717..4b754f8679 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -28,11 +28,11 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index ce11d49486..9ffbdc7e6d 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -42,9 +42,9 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -620,13 +620,13 @@ func TestServer_AlterIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Healthy) t.Run("mmap_unsupported", func(t *testing.T) { - indexParams[0].Value = indexparamcheck.IndexRaftCagra + indexParams[0].Value = "GPU_CAGRA" resp, err := s.AlterIndex(ctx, req) assert.NoError(t, err) assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid) - indexParams[0].Value = indexparamcheck.IndexFaissIvfFlat + indexParams[0].Value = "IVF_FLAT" }) t.Run("param_value_invalied", func(t *testing.T) { diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index a5aed08ee0..f904a708de 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -28,13 +28,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" @@ -475,25 +475,18 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro if err := fillDimension(field, indexParams); err != nil { return err } - } else { - // used only for checker, should be deleted after checking - indexParams[IsSparseKey] = "true" } - if err := checker.CheckValidDataType(field); err != nil { + if err := checker.CheckValidDataType(indexType, field); err != nil { log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String())) return err } - if err := checker.CheckTrain(indexParams); err != nil { + if err := checker.CheckTrain(field.DataType, indexParams); err != nil { log.Info("create index with invalid parameters", zap.Error(err)) return err } - if isSparse { - delete(indexParams, IsSparseKey) - } - return nil } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 6bbcc8c075..8d056fadb0 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -35,9 +35,9 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 1d2ccb6f04..939c2c9f6a 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/hookutil" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -48,7 +49,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/segments/index_attr_cache.go b/internal/querynodev2/segments/index_attr_cache.go index df1e0e0647..fd7b55843c 100644 --- a/internal/querynodev2/segments/index_attr_cache.go +++ b/internal/querynodev2/segments/index_attr_cache.go @@ -29,11 +29,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/querynodev2/segments/index_attr_cache_test.go b/internal/querynodev2/segments/index_attr_cache_test.go index 2b88e001c7..dc4a11a70c 100644 --- a/internal/querynodev2/segments/index_attr_cache_test.go +++ b/internal/querynodev2/segments/index_attr_cache_test.go @@ -24,8 +24,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 4b651e12cc..0d04e54e93 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -52,12 +52,12 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/cgo" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index bf36e0d012..79f8206ae1 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -33,11 +33,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/contextutil" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" diff --git a/internal/querynodev2/segments/utils.go b/internal/querynodev2/segments/utils.go index d3276ad8dd..d9ebb961ff 100644 --- a/internal/querynodev2/segments/utils.go +++ b/internal/querynodev2/segments/utils.go @@ -29,12 +29,12 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/contextutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/pkg/util/indexparamcheck/auto_index_checker.go b/internal/util/indexparamcheck/auto_index_checker.go similarity index 59% rename from pkg/util/indexparamcheck/auto_index_checker.go rename to internal/util/indexparamcheck/auto_index_checker.go index cc83f196d2..f56a2887b1 100644 --- a/pkg/util/indexparamcheck/auto_index_checker.go +++ b/internal/util/indexparamcheck/auto_index_checker.go @@ -9,11 +9,11 @@ type AUTOINDEXChecker struct { baseChecker } -func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error { +func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { return nil } -func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *AUTOINDEXChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { return nil } diff --git a/pkg/util/indexparamcheck/base_checker.go b/internal/util/indexparamcheck/base_checker.go similarity index 68% rename from pkg/util/indexparamcheck/base_checker.go rename to internal/util/indexparamcheck/base_checker.go index 6ea600ba40..ed52d320dd 100644 --- a/pkg/util/indexparamcheck/base_checker.go +++ b/internal/util/indexparamcheck/base_checker.go @@ -19,29 +19,17 @@ package indexparamcheck import ( "fmt" "math" - "strings" "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type baseChecker struct{} -func (c baseChecker) CheckTrain(params map[string]string) error { - // vector dimension should be checked on collection creation. this is just some basic check - isSparse := false - if val, exist := params[common.IsSparseKey]; exist { - val = strings.ToLower(val) - if val != "true" && val != "false" { - return fmt.Errorf("invalid is_sparse value: %s, must be true or false", val) - } - if val == "true" { - isSparse = true - } - } - if isSparse { +func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if typeutil.IsSparseFloatVectorType(dataType) { if !CheckStrByValues(params, Metric, SparseMetrics) { return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics) } @@ -55,13 +43,13 @@ func (c baseChecker) CheckTrain(params map[string]string) error { } // CheckValidDataType check whether the field data type is supported for the index type -func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { return nil } -func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string, dType schemapb.DataType) {} +func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {} -func (c baseChecker) StaticCheck(params map[string]string) error { +func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return errors.New("unsupported index type") } diff --git a/pkg/util/indexparamcheck/base_checker_test.go b/internal/util/indexparamcheck/base_checker_test.go similarity index 83% rename from pkg/util/indexparamcheck/base_checker_test.go rename to internal/util/indexparamcheck/base_checker_test.go index 59a0969d18..99333f5d1f 100644 --- a/pkg/util/indexparamcheck/base_checker_test.go +++ b/internal/util/indexparamcheck/base_checker_test.go @@ -21,7 +21,7 @@ func Test_baseChecker_CheckTrain(t *testing.T) { } sparseParamsWithoutDim := map[string]string{ Metric: metric.IP, - common.IsSparseKey: "tRue", + common.IsSparseKey: "True", } sparseParamsWrongMetric := map[string]string{ Metric: metric.L2, @@ -42,9 +42,15 @@ func Test_baseChecker_CheckTrain(t *testing.T) { {badSparseParams, false}, } - c := newBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "HNSW" + var err error + if test.params[common.IsSparseKey] == "True" { + err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params) + } else { + err = c.CheckTrain(schemapb.DataType_FloatVector, test.params) + } if test.errIsNil { assert.NoError(t, err) } else { @@ -115,7 +121,7 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { c := newBaseChecker() for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("FLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { @@ -126,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { func Test_baseChecker_StaticCheck(t *testing.T) { // TODO - assert.Error(t, newBaseChecker().StaticCheck(nil)) + assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil)) } diff --git a/internal/util/indexparamcheck/bin_flat_checker.go b/internal/util/indexparamcheck/bin_flat_checker.go new file mode 100644 index 0000000000..0a4cfb9009 --- /dev/null +++ b/internal/util/indexparamcheck/bin_flat_checker.go @@ -0,0 +1,19 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type binFlatChecker struct { + binaryVectorBaseChecker +} + +func (c binFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.binaryVectorBaseChecker.CheckTrain(dataType, params) +} + +func (c binFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + return c.staticCheck(params) +} + +func newBinFlatChecker() IndexChecker { + return &binFlatChecker{} +} diff --git a/pkg/util/indexparamcheck/bin_flat_checker_test.go b/internal/util/indexparamcheck/bin_flat_checker_test.go similarity index 88% rename from pkg/util/indexparamcheck/bin_flat_checker_test.go rename to internal/util/indexparamcheck/bin_flat_checker_test.go index 9cf4f39394..9f747afb65 100644 --- a/pkg/util/indexparamcheck/bin_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_flat_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -64,9 +65,10 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) { {p7, true}, } - c := newBinFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "BINFLAT" + err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -134,10 +136,10 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) { }, } - c := newBinFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BINFLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go b/internal/util/indexparamcheck/bin_ivf_flat_checker.go similarity index 55% rename from pkg/util/indexparamcheck/bin_ivf_flat_checker.go rename to internal/util/indexparamcheck/bin_ivf_flat_checker.go index c36bc41c1c..df419f2628 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker.go @@ -2,13 +2,15 @@ package indexparamcheck import ( "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) type binIVFFlatChecker struct { binaryVectorBaseChecker } -func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { +func (c binIVFFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckStrByValues(params, Metric, BinIvfMetrics) { return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics) } @@ -20,12 +22,12 @@ func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { return nil } -func (c binIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil { +func (c binIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.binaryVectorBaseChecker.CheckTrain(dataType, params); err != nil { return err } - return c.StaticCheck(params) + return c.StaticCheck(schemapb.DataType_BinaryVector, params) } func newBinIVFFlatChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go similarity index 91% rename from pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go rename to internal/util/indexparamcheck/bin_ivf_flat_checker_test.go index 77bda3bb01..cf82b97d84 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -115,9 +116,10 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newBinIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "BIN_IVF_FLAT" + err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -185,10 +187,10 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) { }, } - c := newBinIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT") for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BIN_IVF_FLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker.go b/internal/util/indexparamcheck/binary_vector_base_checker.go similarity index 72% rename from pkg/util/indexparamcheck/binary_vector_base_checker.go rename to internal/util/indexparamcheck/binary_vector_base_checker.go index e73bd8b62e..91a04b2d26 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker.go +++ b/internal/util/indexparamcheck/binary_vector_base_checker.go @@ -19,22 +19,22 @@ func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error { return nil } -func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { +func (c binaryVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.baseChecker.CheckTrain(dataType, params); err != nil { return err } return c.staticCheck(params) } -func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c binaryVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if field.GetDataType() != schemapb.DataType_BinaryVector { return fmt.Errorf("binary vector is only supported") } return nil } -func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go b/internal/util/indexparamcheck/binary_vector_base_checker_test.go similarity index 92% rename from pkg/util/indexparamcheck/binary_vector_base_checker_test.go rename to internal/util/indexparamcheck/binary_vector_base_checker_test.go index b52648f793..85942a3fc1 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/binary_vector_base_checker_test.go @@ -67,10 +67,10 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) { }, } - c := newBinaryVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for _, test := range cases { fieldSchema := &schemapb.FieldSchema{DataType: test.dType} - err := c.CheckValidDataType(fieldSchema) + err := c.CheckValidDataType("BINFLAT", fieldSchema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/bitmap_checker_test.go b/internal/util/indexparamcheck/bitmap_checker_test.go new file mode 100644 index 0000000000..09180fdbb5 --- /dev/null +++ b/internal/util/indexparamcheck/bitmap_checker_test.go @@ -0,0 +1,33 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_BitmapIndexChecker(t *testing.T) { + c := newBITMAPChecker() + + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true})) +} diff --git a/pkg/util/indexparamcheck/bitmap_index_checker.go b/internal/util/indexparamcheck/bitmap_index_checker.go similarity index 78% rename from pkg/util/indexparamcheck/bitmap_index_checker.go rename to internal/util/indexparamcheck/bitmap_index_checker.go index f19943a50e..37eba2f51d 100644 --- a/pkg/util/indexparamcheck/bitmap_index_checker.go +++ b/internal/util/indexparamcheck/bitmap_index_checker.go @@ -11,11 +11,11 @@ type BITMAPChecker struct { scalarIndexChecker } -func (c *BITMAPChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, params) } -func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if field.IsPrimaryKey { return fmt.Errorf("create bitmap index on primary key not supported") } diff --git a/pkg/util/indexparamcheck/cagra_checker.go b/internal/util/indexparamcheck/cagra_checker.go similarity index 84% rename from pkg/util/indexparamcheck/cagra_checker.go rename to internal/util/indexparamcheck/cagra_checker.go index 8f52a1605d..9d4a55214c 100644 --- a/pkg/util/indexparamcheck/cagra_checker.go +++ b/internal/util/indexparamcheck/cagra_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // diskannChecker checks if an diskann index can be built. @@ -10,8 +12,8 @@ type cagraChecker struct { floatVectorBaseChecker } -func (c *cagraChecker) CheckTrain(params map[string]string) error { - err := c.baseChecker.CheckTrain(params) +func (c *cagraChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + err := c.baseChecker.CheckTrain(dataType, params) if err != nil { return err } @@ -54,7 +56,7 @@ func (c *cagraChecker) CheckTrain(params map[string]string) error { return nil } -func (c cagraChecker) StaticCheck(params map[string]string) error { +func (c cagraChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return c.staticCheck(params) } diff --git a/pkg/util/indexparamcheck/cagra_checker_test.go b/internal/util/indexparamcheck/cagra_checker_test.go similarity index 88% rename from pkg/util/indexparamcheck/cagra_checker_test.go rename to internal/util/indexparamcheck/cagra_checker_test.go index 23a931a12e..4fc1127038 100644 --- a/pkg/util/indexparamcheck/cagra_checker_test.go +++ b/internal/util/indexparamcheck/cagra_checker_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -101,9 +103,13 @@ func Test_cagraChecker_CheckTrain(t *testing.T) { {p14, false}, } - c := newCagraChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_CAGRA") + if c == nil { + log.Error("can not get index checker instance, please enable GPU and rerun it") + return + } for _, test := range cases { - err := c.CheckTrain(test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go similarity index 64% rename from pkg/util/indexparamcheck/conf_adapter_mgr.go rename to internal/util/indexparamcheck/conf_adapter_mgr.go index f9957a95ee..a746f423ce 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -20,6 +20,8 @@ import ( "sync" "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/internal/util/vecindexmgr" ) type IndexCheckerMgr interface { @@ -34,36 +36,19 @@ type indexCheckerMgrImpl struct { func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) { mgr.once.Do(mgr.registerIndexChecker) - + // Unify the vector index checker + if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) { + return mgr.checkers[IndexVector], nil + } adapter, ok := mgr.checkers[indexType] if ok { return adapter, nil } - return nil, errors.New("Can not find conf adapter: " + indexType) + return nil, errors.New("Can not find index: " + indexType + " , please check") } func (mgr *indexCheckerMgrImpl) registerIndexChecker() { - mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker() - mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker() - mgr.checkers[IndexRaftCagra] = newCagraChecker() - mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker() - mgr.checkers[IndexFaissIDMap] = newFlatChecker() - mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() - mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() - mgr.checkers[IndexScaNN] = newScaNNChecker() - mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker() - mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker() - mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker() - mgr.checkers[IndexHNSW] = newHnswChecker() - mgr.checkers[IndexDISKANN] = newDiskannChecker() - mgr.checkers[IndexSparseInverted] = newSparseInvertedIndexChecker() - mgr.checkers[IndexFaissHNSW] = newFloatVectorBaseChecker() - mgr.checkers[IndexFaissHNSWPQ] = newFloatVectorBaseChecker() - mgr.checkers[IndexFaissHNSWSQ] = newFloatVectorBaseChecker() - mgr.checkers[IndexFaissHNSWPRQ] = newFloatVectorBaseChecker() - // WAND doesn't have more index params than sparse inverted index, thus - // using the same checker. - mgr.checkers[IndexSparseWand] = newSparseInvertedIndexChecker() + mgr.checkers[IndexVector] = newVecIndexChecker() mgr.checkers[IndexINVERTED] = newINVERTEDChecker() mgr.checkers[IndexSTLSORT] = newSTLSORTChecker() mgr.checkers["Asceneding"] = newSTLSORTChecker() diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go b/internal/util/indexparamcheck/conf_adapter_mgr_test.go similarity index 66% rename from pkg/util/indexparamcheck/conf_adapter_mgr_test.go rename to internal/util/indexparamcheck/conf_adapter_mgr_test.go index 6ab9469ee5..08242e924b 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr_test.go @@ -29,52 +29,52 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) { assert.NotEqual(t, nil, err) assert.Equal(t, nil, adapter) - adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) + adapter, err = adapterMgr.GetChecker("FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*flatChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) + adapter, err = adapterMgr.GetChecker("IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfBaseChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexScaNN) + adapter, err = adapterMgr.GetChecker("SCANN") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*scaNNChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) + adapter, err = adapterMgr.GetChecker("IVF_PQ") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfPQChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8) + adapter, err = adapterMgr.GetChecker("IVF_SQ8") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfSQChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap) + adapter, err = adapterMgr.GetChecker("BIN_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*binFlatChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat) + adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*binIVFFlatChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexHNSW) + adapter, err = adapterMgr.GetChecker("HNSW") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*hnswChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) } @@ -89,52 +89,52 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { assert.NotEqual(t, nil, err) assert.Equal(t, nil, adapter) - adapter, err = adapterMgr.GetChecker(IndexFaissIDMap) + adapter, err = adapterMgr.GetChecker("FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*flatChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat) + adapter, err = adapterMgr.GetChecker("IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfBaseChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexScaNN) + adapter, err = adapterMgr.GetChecker("SCANN") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*scaNNChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) + adapter, err = adapterMgr.GetChecker("IVF_PQ") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfPQChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8) + adapter, err = adapterMgr.GetChecker("IVF_SQ8") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfSQChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap) + adapter, err = adapterMgr.GetChecker("BIN_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*binFlatChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat) + adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*binIVFFlatChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) - adapter, err = adapterMgr.GetChecker(IndexHNSW) + adapter, err = adapterMgr.GetChecker("HNSW") assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*hnswChecker) + _, ok = adapter.(*vecIndexChecker) assert.Equal(t, true, ok) } @@ -146,7 +146,7 @@ func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - adapter, err := mgr.GetChecker(IndexHNSW) + adapter, err := mgr.GetChecker("HNSW") assert.NoError(t, err) assert.NotNil(t, adapter) }() diff --git a/pkg/util/indexparamcheck/constraints.go b/internal/util/indexparamcheck/constraints.go similarity index 97% rename from pkg/util/indexparamcheck/constraints.go rename to internal/util/indexparamcheck/constraints.go index 14d374e53c..d1044f37fa 100644 --- a/pkg/util/indexparamcheck/constraints.go +++ b/internal/util/indexparamcheck/constraints.go @@ -65,7 +65,7 @@ var ( CagraBuildAlgoTypes = []string{CargaBuildAlgoIVFPQ, CargaBuildAlgoNNDESCENT} supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const - SparseMetrics = []string{metric.IP} // const + SparseMetrics = []string{metric.IP, metric.BM25} // const ) const ( diff --git a/pkg/util/indexparamcheck/diskann_checker.go b/internal/util/indexparamcheck/diskann_checker.go similarity index 59% rename from pkg/util/indexparamcheck/diskann_checker.go rename to internal/util/indexparamcheck/diskann_checker.go index 3f2401851e..323859b6f0 100644 --- a/pkg/util/indexparamcheck/diskann_checker.go +++ b/internal/util/indexparamcheck/diskann_checker.go @@ -1,11 +1,13 @@ package indexparamcheck +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + // diskannChecker checks if an diskann index can be built. type diskannChecker struct { floatVectorBaseChecker } -func (c diskannChecker) StaticCheck(params map[string]string) error { +func (c diskannChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { return c.staticCheck(params) } diff --git a/pkg/util/indexparamcheck/diskann_checker_test.go b/internal/util/indexparamcheck/diskann_checker_test.go similarity index 88% rename from pkg/util/indexparamcheck/diskann_checker_test.go rename to internal/util/indexparamcheck/diskann_checker_test.go index 4fcfdbf019..50d692ea34 100644 --- a/pkg/util/indexparamcheck/diskann_checker_test.go +++ b/internal/util/indexparamcheck/diskann_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -72,9 +73,10 @@ func Test_diskannChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newDiskannChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "DISKANN" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -142,9 +144,9 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) { }, } - c := newDiskannChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("DISKANN", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/flat_checker.go b/internal/util/indexparamcheck/flat_checker.go similarity index 52% rename from pkg/util/indexparamcheck/flat_checker.go rename to internal/util/indexparamcheck/flat_checker.go index d98db44920..8fe6d59f25 100644 --- a/pkg/util/indexparamcheck/flat_checker.go +++ b/internal/util/indexparamcheck/flat_checker.go @@ -1,10 +1,12 @@ package indexparamcheck +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + type flatChecker struct { floatVectorBaseChecker } -func (c flatChecker) StaticCheck(m map[string]string) error { +func (c flatChecker) StaticCheck(dataType schemapb.DataType, m map[string]string) error { return c.staticCheck(m) } diff --git a/pkg/util/indexparamcheck/flat_checker_test.go b/internal/util/indexparamcheck/flat_checker_test.go similarity index 79% rename from pkg/util/indexparamcheck/flat_checker_test.go rename to internal/util/indexparamcheck/flat_checker_test.go index c22432bc6f..e6991c571e 100644 --- a/pkg/util/indexparamcheck/flat_checker_test.go +++ b/internal/util/indexparamcheck/flat_checker_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -52,9 +54,10 @@ func Test_flatChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "FLAT" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -89,9 +92,10 @@ func Test_flatChecker_StaticCheck(t *testing.T) { }, } - c := newFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT") for _, test := range cases { - err := c.StaticCheck(test.params) + test.params[common.IndexTypeKey] = "FLAT" + err := c.StaticCheck(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/float_vector_base_checker.go b/internal/util/indexparamcheck/float_vector_base_checker.go similarity index 69% rename from pkg/util/indexparamcheck/float_vector_base_checker.go rename to internal/util/indexparamcheck/float_vector_base_checker.go index 710dfb3a18..b95e9f3b16 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker.go +++ b/internal/util/indexparamcheck/float_vector_base_checker.go @@ -20,22 +20,22 @@ func (c floatVectorBaseChecker) staticCheck(params map[string]string) error { return nil } -func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { +func (c floatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.baseChecker.CheckTrain(dataType, params); err != nil { return err } return c.staticCheck(params) } -func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c floatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsDenseFloatVectorType(field.GetDataType()) { return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector") } return nil } -func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/float_vector_base_checker_test.go b/internal/util/indexparamcheck/float_vector_base_checker_test.go similarity index 88% rename from pkg/util/indexparamcheck/float_vector_base_checker_test.go rename to internal/util/indexparamcheck/float_vector_base_checker_test.go index 7eb0a97d36..7a1bf3c6d2 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/float_vector_base_checker_test.go @@ -63,13 +63,13 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) { }, { dType: schemapb.DataType_BinaryVector, - errIsNil: false, + errIsNil: true, }, } - c := newFloatVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/hnsw_checker.go b/internal/util/indexparamcheck/hnsw_checker.go similarity index 72% rename from pkg/util/indexparamcheck/hnsw_checker.go rename to internal/util/indexparamcheck/hnsw_checker.go index b5f9e1f2b7..fbc5ef2bb2 100644 --- a/pkg/util/indexparamcheck/hnsw_checker.go +++ b/internal/util/indexparamcheck/hnsw_checker.go @@ -12,7 +12,7 @@ type hnswChecker struct { baseChecker } -func (c hnswChecker) StaticCheck(params map[string]string) error { +func (c hnswChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) { return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) } @@ -25,21 +25,21 @@ func (c hnswChecker) StaticCheck(params map[string]string) error { return nil } -func (c hnswChecker) CheckTrain(params map[string]string) error { - if err := c.StaticCheck(params); err != nil { +func (c hnswChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { return err } - return c.baseChecker.CheckTrain(params) + return c.baseChecker.CheckTrain(dataType, params) } -func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c hnswChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsVectorType(field.GetDataType()) { return fmt.Errorf("can't build hnsw in not vector type") } return nil } -func (c hnswChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c hnswChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { if typeutil.IsDenseFloatVectorType(dType) { setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) } else if typeutil.IsSparseFloatVectorType(dType) { diff --git a/pkg/util/indexparamcheck/hnsw_checker_test.go b/internal/util/indexparamcheck/hnsw_checker_test.go similarity index 87% rename from pkg/util/indexparamcheck/hnsw_checker_test.go rename to internal/util/indexparamcheck/hnsw_checker_test.go index b911812540..0ed41375ad 100644 --- a/pkg/util/indexparamcheck/hnsw_checker_test.go +++ b/internal/util/indexparamcheck/hnsw_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -88,13 +89,19 @@ func Test_hnswChecker_CheckTrain(t *testing.T) { {p3, true}, {p4, true}, {p5, true}, - {p6, false}, - {p7, false}, + {p6, true}, + {p7, true}, } - c := newHnswChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "HNSW" + var err error + if CheckStrByValues(test.params, common.MetricTypeKey, BinaryVectorMetrics) { + err = c.CheckTrain(schemapb.DataType_BinaryVector, test.params) + } else { + err = c.CheckTrain(schemapb.DataType_FloatVector, test.params) + } if test.errIsNil { assert.NoError(t, err) } else { @@ -162,9 +169,9 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) { }, } - c := newHnswChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { @@ -200,14 +207,14 @@ func Test_hnswChecker_SetDefaultMetricType(t *testing.T) { }, } - c := newHnswChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW") for _, test := range cases { p := map[string]string{ DIM: strconv.Itoa(128), HNSWM: strconv.Itoa(16), EFConstruction: strconv.Itoa(200), } - c.SetDefaultMetricTypeIfNotExist(p, test.dType) + c.SetDefaultMetricTypeIfNotExist(test.dType, p) assert.Equal(t, p[Metric], test.metricType) } } diff --git a/internal/util/indexparamcheck/hybrid_checker_test.go b/internal/util/indexparamcheck/hybrid_checker_test.go new file mode 100644 index 0000000000..4d09f7aca1 --- /dev/null +++ b/internal/util/indexparamcheck/hybrid_checker_test.go @@ -0,0 +1,37 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_HybridIndexChecker(t *testing.T) { + c := newHYBRIDChecker() + + assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{"bitmap_cardinality_limit": "100"})) + + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) + assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{})) + assert.Error(t, c.CheckTrain(schemapb.DataType_Float, map[string]string{"bitmap_cardinality_limit": "0"})) + assert.Error(t, c.CheckTrain(schemapb.DataType_Double, map[string]string{"bitmap_cardinality_limit": "2000"})) +} diff --git a/pkg/util/indexparamcheck/hybrid_index_checker.go b/internal/util/indexparamcheck/hybrid_index_checker.go similarity index 82% rename from pkg/util/indexparamcheck/hybrid_index_checker.go rename to internal/util/indexparamcheck/hybrid_index_checker.go index 9493bccd91..08dfe03943 100644 --- a/pkg/util/indexparamcheck/hybrid_index_checker.go +++ b/internal/util/indexparamcheck/hybrid_index_checker.go @@ -12,15 +12,15 @@ type HYBRIDChecker struct { scalarIndexChecker } -func (c *HYBRIDChecker) CheckTrain(params map[string]string) error { +func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) { return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d", MaxBitmapCardinalityLimit) } - return c.scalarIndexChecker.CheckTrain(params) + return c.scalarIndexChecker.CheckTrain(dataType, params) } -func (c *HYBRIDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { mainType := field.GetDataType() elemType := field.GetElementType() if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) && diff --git a/pkg/util/indexparamcheck/index_checker.go b/internal/util/indexparamcheck/index_checker.go similarity index 77% rename from pkg/util/indexparamcheck/index_checker.go rename to internal/util/indexparamcheck/index_checker.go index 1c11280898..610ddffc2c 100644 --- a/pkg/util/indexparamcheck/index_checker.go +++ b/internal/util/indexparamcheck/index_checker.go @@ -21,8 +21,8 @@ import ( ) type IndexChecker interface { - CheckTrain(map[string]string) error - CheckValidDataType(field *schemapb.FieldSchema) error - SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType) - StaticCheck(map[string]string) error + CheckTrain(schemapb.DataType, map[string]string) error + CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error + SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string) + StaticCheck(schemapb.DataType, map[string]string) error } diff --git a/pkg/util/indexparamcheck/index_checker_test.go b/internal/util/indexparamcheck/index_checker_test.go similarity index 100% rename from pkg/util/indexparamcheck/index_checker_test.go rename to internal/util/indexparamcheck/index_checker_test.go diff --git a/pkg/util/indexparamcheck/index_type.go b/internal/util/indexparamcheck/index_type.go similarity index 60% rename from pkg/util/indexparamcheck/index_type.go rename to internal/util/indexparamcheck/index_type.go index fa634bc1d4..5bc78f8ec2 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/internal/util/indexparamcheck/index_type.go @@ -15,6 +15,7 @@ import ( "fmt" "strconv" + "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" ) @@ -23,31 +24,7 @@ type IndexType = string // IndexType definitions const ( - // vector index - IndexGpuBF IndexType = "GPU_BRUTE_FORCE" - IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT" - IndexRaftIvfPQ IndexType = "GPU_IVF_PQ" - IndexRaftCagra IndexType = "GPU_CAGRA" - IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE" - IndexFaissIDMap IndexType = "FLAT" // no index is built. - IndexFaissIvfFlat IndexType = "IVF_FLAT" - IndexFaissIvfPQ IndexType = "IVF_PQ" - IndexScaNN IndexType = "SCANN" - IndexFaissIvfSQ8 IndexType = "IVF_SQ8" - IndexFaissBinIDMap IndexType = "BIN_FLAT" - IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT" - IndexHNSW IndexType = "HNSW" - IndexDISKANN IndexType = "DISKANN" - IndexSparseInverted IndexType = "SPARSE_INVERTED_INDEX" - IndexSparseWand IndexType = "SPARSE_WAND" - // For temporary use, will be removed in the future. - // 1. All Index related param check will be moved to Knowhere recently. - // 2. FAISS_HNSW_xxx will be rename to HNSW_xxx after QA test. We keep the original name for comparison purpose. - // TODO: @liliu-z @foxspy - IndexFaissHNSW IndexType = "FAISS_HNSW_FLAT" - IndexFaissHNSWPQ IndexType = "FAISS_HNSW_PQ" - IndexFaissHNSWSQ IndexType = "FAISS_HNSW_SQ" - IndexFaissHNSWPRQ IndexType = "FAISS_HNSW_PRQ" + IndexVector IndexType = "VECINDEX" // scalar index IndexSTLSORT IndexType = "STL_SORT" @@ -66,28 +43,12 @@ func IsScalarIndexType(indexType IndexType) bool { } func IsGpuIndex(indexType IndexType) bool { - return indexType == IndexGpuBF || - indexType == IndexRaftIvfFlat || - indexType == IndexRaftIvfPQ || - indexType == IndexRaftCagra + return vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(indexType) } // IsVectorMmapIndex check if the vector index can be mmaped func IsVectorMmapIndex(indexType IndexType) bool { - return indexType == IndexFaissIDMap || - indexType == IndexFaissIvfFlat || - indexType == IndexFaissIvfPQ || - indexType == IndexFaissIvfSQ8 || - indexType == IndexFaissBinIDMap || - indexType == IndexFaissBinIvfFlat || - indexType == IndexHNSW || - indexType == IndexFaissHNSW || - indexType == IndexFaissHNSWPQ || - indexType == IndexFaissHNSWSQ || - indexType == IndexFaissHNSWPRQ || - indexType == IndexScaNN || - indexType == IndexSparseInverted || - indexType == IndexSparseWand + return vecindexmgr.GetVecIndexMgrInstance().IsMMapSupported(indexType) } func IsOffsetCacheSupported(indexType IndexType) bool { @@ -95,7 +56,7 @@ func IsOffsetCacheSupported(indexType IndexType) bool { } func IsDiskIndex(indexType IndexType) bool { - return indexType == IndexDISKANN + return vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) } func IsScalarMmapIndex(indexType IndexType) bool { diff --git a/pkg/util/indexparamcheck/index_type_test.go b/internal/util/indexparamcheck/index_type_test.go similarity index 100% rename from pkg/util/indexparamcheck/index_type_test.go rename to internal/util/indexparamcheck/index_type_test.go diff --git a/pkg/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go similarity index 69% rename from pkg/util/indexparamcheck/inverted_checker.go rename to internal/util/indexparamcheck/inverted_checker.go index 8d6893c100..83a0c65cdd 100644 --- a/pkg/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -12,11 +12,11 @@ type INVERTEDChecker struct { scalarIndexChecker } -func (c *INVERTEDChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, params) } -func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { dType := field.GetDataType() if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && !typeutil.IsArrayType(dType) { diff --git a/internal/util/indexparamcheck/inverted_checker_test.go b/internal/util/indexparamcheck/inverted_checker_test.go new file mode 100644 index 0000000000..68afa3903b --- /dev/null +++ b/internal/util/indexparamcheck/inverted_checker_test.go @@ -0,0 +1,25 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_INVERTEDIndexChecker(t *testing.T) { + c := newINVERTEDChecker() + + assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Array})) + + assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) +} diff --git a/internal/util/indexparamcheck/ivf_base_checker.go b/internal/util/indexparamcheck/ivf_base_checker.go new file mode 100644 index 0000000000..b3b8fdd53b --- /dev/null +++ b/internal/util/indexparamcheck/ivf_base_checker.go @@ -0,0 +1,28 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type ivfBaseChecker struct { + floatVectorBaseChecker +} + +func (c ivfBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { + return errOutOfRange(NLIST, MinNList, MaxNList) + } + + // skip check number of rows + + return c.floatVectorBaseChecker.staticCheck(params) +} + +func (c ivfBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { + return err + } + return c.floatVectorBaseChecker.CheckTrain(dataType, params) +} + +func newIVFBaseChecker() IndexChecker { + return &ivfBaseChecker{} +} diff --git a/pkg/util/indexparamcheck/ivf_base_checker_test.go b/internal/util/indexparamcheck/ivf_base_checker_test.go similarity index 88% rename from pkg/util/indexparamcheck/ivf_base_checker_test.go rename to internal/util/indexparamcheck/ivf_base_checker_test.go index 4a379038dd..dfe115f2eb 100644 --- a/pkg/util/indexparamcheck/ivf_base_checker_test.go +++ b/internal/util/indexparamcheck/ivf_base_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -70,9 +71,10 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newIVFBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "IVF_FLAT" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -140,9 +142,9 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) { }, } - c := newIVFBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_pq_checker.go b/internal/util/indexparamcheck/ivf_pq_checker.go similarity index 86% rename from pkg/util/indexparamcheck/ivf_pq_checker.go rename to internal/util/indexparamcheck/ivf_pq_checker.go index 4c35f193c4..c38486e1c9 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker.go +++ b/internal/util/indexparamcheck/ivf_pq_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // ivfPQChecker checks if a IVF_PQ index can be built. @@ -11,8 +13,8 @@ type ivfPQChecker struct { } // CheckTrain checks if ivf-pq index can be built with the specific index parameters. -func (c *ivfPQChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *ivfPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil { return err } diff --git a/pkg/util/indexparamcheck/ivf_pq_checker_test.go b/internal/util/indexparamcheck/ivf_pq_checker_test.go similarity index 92% rename from pkg/util/indexparamcheck/ivf_pq_checker_test.go rename to internal/util/indexparamcheck/ivf_pq_checker_test.go index 4a22d45542..7bb539671b 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_pq_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -141,9 +142,11 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) { {p7, false}, } + // c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ") c := newIVFPQChecker() for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "IVF_PQ" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -211,9 +214,9 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) { }, } - c := newIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_PQ", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker.go b/internal/util/indexparamcheck/ivf_sq_checker.go similarity index 78% rename from pkg/util/indexparamcheck/ivf_sq_checker.go rename to internal/util/indexparamcheck/ivf_sq_checker.go index fc1a2204f5..7597277002 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker.go +++ b/internal/util/indexparamcheck/ivf_sq_checker.go @@ -2,6 +2,8 @@ package indexparamcheck import ( "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // ivfSQChecker checks if a IVF_SQ index can be built. @@ -22,11 +24,11 @@ func (c *ivfSQChecker) checkNBits(params map[string]string) error { } // CheckTrain returns true if the index can be built with the specific index parameters. -func (c *ivfSQChecker) CheckTrain(params map[string]string) error { +func (c *ivfSQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { if err := c.checkNBits(params); err != nil { return err } - return c.ivfBaseChecker.CheckTrain(params) + return c.ivfBaseChecker.CheckTrain(dataType, params) } func newIVFSQChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker_test.go b/internal/util/indexparamcheck/ivf_sq_checker_test.go similarity index 90% rename from pkg/util/indexparamcheck/ivf_sq_checker_test.go rename to internal/util/indexparamcheck/ivf_sq_checker_test.go index 9478623fe8..ee37eea78d 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_sq_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -78,7 +79,6 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) { }{ {validParams, true}, {validParamsWithNBits, true}, - {paramsWithInvalidNBits, false}, {invalidIVFParamsMin(), false}, {invalidIVFParamsMax(), false}, {p1, true}, @@ -90,9 +90,10 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newIVFSQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "IVF_SQ" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -160,9 +161,9 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) { }, } - c := newIVFSQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ") for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("IVF_SQ8", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker.go b/internal/util/indexparamcheck/raft_brute_force_checker.go similarity index 61% rename from pkg/util/indexparamcheck/raft_brute_force_checker.go rename to internal/util/indexparamcheck/raft_brute_force_checker.go index 38872da7ec..68c0e482b2 100644 --- a/pkg/util/indexparamcheck/raft_brute_force_checker.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker.go @@ -1,14 +1,18 @@ package indexparamcheck -import "fmt" +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) type raftBruteForceChecker struct { floatVectorBaseChecker } // raftBrustForceChecker checks if a Brute_Force index can be built. -func (c raftBruteForceChecker) CheckTrain(params map[string]string) error { - if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil { +func (c raftBruteForceChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.floatVectorBaseChecker.CheckTrain(dataType, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go b/internal/util/indexparamcheck/raft_brute_force_checker_test.go similarity index 71% rename from pkg/util/indexparamcheck/raft_brute_force_checker_test.go rename to internal/util/indexparamcheck/raft_brute_force_checker_test.go index ce037bc4dc..6f5fd9a0ec 100644 --- a/pkg/util/indexparamcheck/raft_brute_force_checker_test.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker_test.go @@ -6,6 +6,9 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -52,9 +55,14 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newRaftBruteForceChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_BRUTE_FORCE") + if c == nil { + log.Error("can not get index checker instance, please enable GPU and rerun it") + return + } for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "GPU_BRUTE_FORCE" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go b/internal/util/indexparamcheck/raft_ivf_flat_checker.go similarity index 74% rename from pkg/util/indexparamcheck/raft_ivf_flat_checker.go rename to internal/util/indexparamcheck/raft_ivf_flat_checker.go index 9f11803e9b..8b48f6f28d 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker.go @@ -1,6 +1,10 @@ package indexparamcheck -import "fmt" +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) // raftIVFChecker checks if a RAFT_IVF_Flat index can be built. type raftIVFFlatChecker struct { @@ -8,8 +12,8 @@ type raftIVFFlatChecker struct { } // CheckTrain checks if ivf-flat index can be built with the specific index parameters. -func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *raftIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go similarity index 83% rename from pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go rename to internal/util/indexparamcheck/raft_ivf_flat_checker_test.go index 3d64f83039..d2debb90b7 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go @@ -7,6 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -84,9 +86,14 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) { {p9, false}, } - c := newRaftIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT") + if c == nil { + log.Error("can not get index checker instance, please enable GPU and rerun it") + return + } for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "GPU_IVF_FLAT" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -154,9 +161,13 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) { }, } - c := newRaftIVFFlatChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT") + if c == nil { + log.Error("can not get index checker instance, please enable GPU and rerun it") + return + } for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("GPU_IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go b/internal/util/indexparamcheck/raft_ivf_pq_checker.go similarity index 88% rename from pkg/util/indexparamcheck/raft_ivf_pq_checker.go rename to internal/util/indexparamcheck/raft_ivf_pq_checker.go index 2457619118..1ea35a8c7d 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built. @@ -11,8 +13,8 @@ type raftIVFPQChecker struct { } // CheckTrain checks if ivf-pq index can be built with the specific index parameters. -func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *raftIVFPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil { return err } if !CheckStrByValues(params, Metric, RaftMetrics) { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go similarity index 88% rename from pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go rename to internal/util/indexparamcheck/raft_ivf_pq_checker_test.go index 8c882900e9..fcbda59063 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -7,6 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -144,9 +146,14 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) { {p9, false}, } - c := newRaftIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ") + if c == nil { + log.Error("can not get index checker instance, please enable GPU and rerun it") + return + } for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "GPU_IVF_PQ" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -214,9 +221,13 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) { }, } - c := newRaftIVFPQChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ") + if c == nil { + log.Error("can not get index checker instance, please enable GPU and rerun it") + return + } for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("GPU_IVF_PQ", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/scalar_index_checker.go b/internal/util/indexparamcheck/scalar_index_checker.go new file mode 100644 index 0000000000..a1272ae388 --- /dev/null +++ b/internal/util/indexparamcheck/scalar_index_checker.go @@ -0,0 +1,11 @@ +package indexparamcheck + +import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + +type scalarIndexChecker struct { + baseChecker +} + +func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return nil +} diff --git a/pkg/util/indexparamcheck/scalar_index_checker_test.go b/internal/util/indexparamcheck/scalar_index_checker_test.go similarity index 53% rename from pkg/util/indexparamcheck/scalar_index_checker_test.go rename to internal/util/indexparamcheck/scalar_index_checker_test.go index eb3ae669e2..faf8ea2419 100644 --- a/pkg/util/indexparamcheck/scalar_index_checker_test.go +++ b/internal/util/indexparamcheck/scalar_index_checker_test.go @@ -4,9 +4,11 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func TestCheckIndexValid(t *testing.T) { scalarIndexChecker := &scalarIndexChecker{} - assert.NoError(t, scalarIndexChecker.CheckTrain(map[string]string{})) + assert.NoError(t, scalarIndexChecker.CheckTrain(schemapb.DataType_Bool, map[string]string{})) } diff --git a/pkg/util/indexparamcheck/scann_checker.go b/internal/util/indexparamcheck/scann_checker.go similarity index 78% rename from pkg/util/indexparamcheck/scann_checker.go rename to internal/util/indexparamcheck/scann_checker.go index eecf2ded64..94aa5258e4 100644 --- a/pkg/util/indexparamcheck/scann_checker.go +++ b/internal/util/indexparamcheck/scann_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" "strconv" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // scaNNChecker checks if a SCANN index can be built. @@ -11,8 +13,8 @@ type scaNNChecker struct { } // CheckTrain checks if SCANN index can be built with the specific index parameters. -func (c *scaNNChecker) CheckTrain(params map[string]string) error { - if err := c.ivfBaseChecker.CheckTrain(params); err != nil { +func (c *scaNNChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil { return err } diff --git a/pkg/util/indexparamcheck/scann_checker_test.go b/internal/util/indexparamcheck/scann_checker_test.go similarity index 91% rename from pkg/util/indexparamcheck/scann_checker_test.go rename to internal/util/indexparamcheck/scann_checker_test.go index 4f7014c6fd..2dd32f160c 100644 --- a/pkg/util/indexparamcheck/scann_checker_test.go +++ b/internal/util/indexparamcheck/scann_checker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" ) @@ -87,9 +88,10 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) { {p7, false}, } - c := newScaNNChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("SCANN") for _, test := range cases { - err := c.CheckTrain(test.params) + test.params[common.IndexTypeKey] = "SCANN" + err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -159,7 +161,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) { c := newScaNNChecker() for _, test := range cases { - err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) + err := c.CheckValidDataType("SCANN", &schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go b/internal/util/indexparamcheck/sparse_float_vector_base_checker.go similarity index 81% rename from pkg/util/indexparamcheck/sparse_float_vector_base_checker.go rename to internal/util/indexparamcheck/sparse_float_vector_base_checker.go index 9cd8921e31..6038f3f45c 100644 --- a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go +++ b/internal/util/indexparamcheck/sparse_float_vector_base_checker.go @@ -12,7 +12,7 @@ import ( // sparse vector don't check for dim, but baseChecker does, thus not including baseChecker type sparseFloatVectorBaseChecker struct{} -func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) error { +func (c sparseFloatVectorBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { if !CheckStrByValues(params, Metric, SparseMetrics) { return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics) } @@ -20,7 +20,7 @@ func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) erro return nil } -func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error { +func (c sparseFloatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { dropRatioBuildStr, exist := params[SparseDropRatioBuild] if exist { dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64) @@ -48,14 +48,14 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error return nil } -func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c sparseFloatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsSparseFloatVectorType(field.GetDataType()) { return fmt.Errorf("only sparse float vector is supported for the specified index tpye") } return nil } -func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) { +func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType) } diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go b/internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go similarity index 52% rename from pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go rename to internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go index 2fb558f4fb..c05dbdb69c 100644 --- a/pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go @@ -6,84 +6,95 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" ) func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) { validParams := map[string]string{ - Metric: "IP", + common.IndexTypeKey: "SPARSE_INVERTED_INDEX", + Metric: "IP", } invalidParams := map[string]string{ - Metric: "L2", + common.IndexTypeKey: "SPARSE_INVERTED_INDEX", + Metric: "L2", } - c := newSparseFloatVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX") t.Run("valid metric", func(t *testing.T) { - err := c.StaticCheck(validParams) + err := c.StaticCheck(schemapb.DataType_SparseFloatVector, validParams) assert.NoError(t, err) }) t.Run("invalid metric", func(t *testing.T) { - err := c.StaticCheck(invalidParams) + err := c.StaticCheck(schemapb.DataType_SparseFloatVector, invalidParams) assert.Error(t, err) }) } func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) { validParams := map[string]string{ + common.IndexTypeKey: "SPARSE_INVERTED_INDEX", + Metric: "IP", SparseDropRatioBuild: "0.5", BM25K1: "1.5", BM25B: "0.5", } invalidDropRatio := map[string]string{ + common.IndexTypeKey: "SPARSE_INVERTED_INDEX", + Metric: "IP", SparseDropRatioBuild: "1.5", } invalidBM25K1 := map[string]string{ - BM25K1: "3.5", + common.IndexTypeKey: "SPARSE_INVERTED_INDEX", + Metric: "IP", + BM25K1: "3.5", } invalidBM25B := map[string]string{ - BM25B: "1.5", + common.IndexTypeKey: "SPARSE_INVERTED_INDEX", + Metric: "IP", + BM25B: "1.5", } - c := newSparseFloatVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX") t.Run("valid params", func(t *testing.T) { - err := c.CheckTrain(validParams) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, validParams) assert.NoError(t, err) }) t.Run("invalid drop ratio", func(t *testing.T) { - err := c.CheckTrain(invalidDropRatio) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidDropRatio) assert.Error(t, err) }) t.Run("invalid BM25K1", func(t *testing.T) { - err := c.CheckTrain(invalidBM25K1) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25K1) assert.Error(t, err) }) t.Run("invalid BM25B", func(t *testing.T) { - err := c.CheckTrain(invalidBM25B) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25B) assert.Error(t, err) }) } func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) { - c := newSparseFloatVectorBaseChecker() + c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX") t.Run("valid data type", func(t *testing.T) { field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector} - err := c.CheckValidDataType(field) + err := c.CheckValidDataType("SPARSE_WAND", field) assert.NoError(t, err) }) t.Run("invalid data type", func(t *testing.T) { field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector} - err := c.CheckValidDataType(field) + err := c.CheckValidDataType("SPARSE_WAND", field) assert.Error(t, err) }) } diff --git a/pkg/util/indexparamcheck/sparse_inverted_index_checker.go b/internal/util/indexparamcheck/sparse_inverted_index_checker.go similarity index 100% rename from pkg/util/indexparamcheck/sparse_inverted_index_checker.go rename to internal/util/indexparamcheck/sparse_inverted_index_checker.go diff --git a/pkg/util/indexparamcheck/stl_sort_checker.go b/internal/util/indexparamcheck/stl_sort_checker.go similarity index 64% rename from pkg/util/indexparamcheck/stl_sort_checker.go rename to internal/util/indexparamcheck/stl_sort_checker.go index 4b3441ad6d..7681bea7a6 100644 --- a/pkg/util/indexparamcheck/stl_sort_checker.go +++ b/internal/util/indexparamcheck/stl_sort_checker.go @@ -12,11 +12,11 @@ type STLSORTChecker struct { scalarIndexChecker } -func (c *STLSORTChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, params) } -func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsArithmetic(field.GetDataType()) { return fmt.Errorf("STL_SORT are only supported on numeric field") } diff --git a/internal/util/indexparamcheck/stl_sort_checker_test.go b/internal/util/indexparamcheck/stl_sort_checker_test.go new file mode 100644 index 0000000000..7bb21f95d7 --- /dev/null +++ b/internal/util/indexparamcheck/stl_sort_checker_test.go @@ -0,0 +1,22 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_STLSORTIndexChecker(t *testing.T) { + c := newSTLSORTChecker() + + assert.NoError(t, c.CheckTrain(schemapb.DataType_Int64, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) +} diff --git a/pkg/util/indexparamcheck/trie_checker.go b/internal/util/indexparamcheck/trie_checker.go similarity index 64% rename from pkg/util/indexparamcheck/trie_checker.go rename to internal/util/indexparamcheck/trie_checker.go index 002014e420..5666e4d6a7 100644 --- a/pkg/util/indexparamcheck/trie_checker.go +++ b/internal/util/indexparamcheck/trie_checker.go @@ -12,11 +12,11 @@ type TRIEChecker struct { scalarIndexChecker } -func (c *TRIEChecker) CheckTrain(params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(params) +func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, params) } -func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error { +func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsStringType(field.GetDataType()) { return fmt.Errorf("TRIE are only supported on varchar field") } diff --git a/internal/util/indexparamcheck/trie_checker_test.go b/internal/util/indexparamcheck/trie_checker_test.go new file mode 100644 index 0000000000..fb81c90b2c --- /dev/null +++ b/internal/util/indexparamcheck/trie_checker_test.go @@ -0,0 +1,23 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_TrieIndexChecker(t *testing.T) { + c := newTRIEChecker() + + assert.NoError(t, c.CheckTrain(schemapb.DataType_VarChar, map[string]string{})) + + assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) + + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) +} diff --git a/pkg/util/indexparamcheck/utils.go b/internal/util/indexparamcheck/utils.go similarity index 64% rename from pkg/util/indexparamcheck/utils.go rename to internal/util/indexparamcheck/utils.go index adca93aeb3..c392b7d4e1 100644 --- a/pkg/util/indexparamcheck/utils.go +++ b/internal/util/indexparamcheck/utils.go @@ -20,7 +20,10 @@ import ( "fmt" "strconv" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // CheckIntByRange check if the data corresponding to the key is in the range of [min, max]. @@ -69,3 +72,30 @@ func setDefaultIfNotExist(params map[string]string, key string, defaultValue str params[key] = defaultValue } } + +func CheckAutoIndexHelper(key string, m map[string]string, dtype schemapb.DataType) { + indexType, ok := m[common.IndexTypeKey] + if !ok { + panic(fmt.Sprintf("%s invalid, index type not found", key)) + } + + checker, err := GetIndexCheckerMgrInstance().GetChecker(indexType) + if err != nil { + panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType)) + } + + if err := checker.StaticCheck(dtype, m); err != nil { + panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error())) + } +} + +func CheckAutoIndexConfig() { + autoIndexCfg := ¶mtable.Get().AutoIndexConfig + CheckAutoIndexHelper(autoIndexCfg.IndexParams.Key, autoIndexCfg.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + CheckAutoIndexHelper(autoIndexCfg.BinaryIndexParams.Key, autoIndexCfg.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector) + CheckAutoIndexHelper(autoIndexCfg.SparseIndexParams.Key, autoIndexCfg.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector) +} + +func ValidateParamTable() { + CheckAutoIndexConfig() +} diff --git a/internal/util/indexparamcheck/utils_test.go b/internal/util/indexparamcheck/utils_test.go new file mode 100644 index 0000000000..2135c58895 --- /dev/null +++ b/internal/util/indexparamcheck/utils_test.go @@ -0,0 +1,269 @@ +package indexparamcheck + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func Test_CheckIntByRange(t *testing.T) { + params := map[string]string{ + "1": strconv.Itoa(1), + "2": strconv.Itoa(2), + "3": strconv.Itoa(3), + "s1": "s1", + "s2": "s2", + "s3": "s3", + } + + cases := []struct { + params map[string]string + key string + min int + max int + want bool + }{ + {params, "1", 0, 4, true}, + {params, "2", 0, 4, true}, + {params, "3", 0, 4, true}, + {params, "1", 4, 5, false}, + {params, "2", 4, 5, false}, + {params, "3", 4, 5, false}, + {params, "4", 0, 4, false}, + {params, "5", 0, 4, false}, + {params, "6", 0, 4, false}, + {params, "s1", 0, 4, false}, + {params, "s2", 0, 4, false}, + {params, "s3", 0, 4, false}, + {params, "s4", 0, 4, false}, + {params, "s5", 0, 4, false}, + {params, "s6", 0, 4, false}, + } + + for _, test := range cases { + if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want { + t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want) + } + } +} + +func Test_CheckStrByValues(t *testing.T) { + params := map[string]string{ + "1": strconv.Itoa(1), + "2": strconv.Itoa(2), + "3": strconv.Itoa(3), + } + + cases := []struct { + params map[string]string + key string + container []string + want bool + }{ + {params, "1", []string{"1", "2", "3"}, true}, + {params, "2", []string{"1", "2", "3"}, true}, + {params, "3", []string{"1", "2", "3"}, true}, + {params, "1", []string{"4", "5", "6"}, false}, + {params, "2", []string{"4", "5", "6"}, false}, + {params, "3", []string{"4", "5", "6"}, false}, + {params, "1", []string{}, false}, + {params, "2", []string{}, false}, + {params, "3", []string{}, false}, + {params, "4", []string{"1", "2", "3"}, false}, + {params, "5", []string{"1", "2", "3"}, false}, + {params, "6", []string{"1", "2", "3"}, false}, + {params, "4", []string{"4", "5", "6"}, false}, + {params, "5", []string{"4", "5", "6"}, false}, + {params, "6", []string{"4", "5", "6"}, false}, + {params, "4", []string{}, false}, + {params, "5", []string{}, false}, + {params, "6", []string{}, false}, + } + + for _, test := range cases { + if got := CheckStrByValues(test.params, test.key, test.container); got != test.want { + t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want) + } + } +} + +func Test_CheckAutoIndex(t *testing.T) { + t.Run("index type not found", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"M": 30}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.Panics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + }) + }) + + t.Run("unsupported index type", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.Panics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + }) + }) + + t.Run("normal case, hnsw", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.COSINE, metricType) + }) + + t.Run("normal case, binary vector", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`) + p := ¶mtable.AutoIndexConfig{ + BinaryIndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.binary.build", + }, + } + p.BinaryIndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector) + }) + metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.HAMMING, metricType) + }) + + t.Run("normal case, sparse vector", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`) + p := ¶mtable.AutoIndexConfig{ + SparseIndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.sparse.build", + }, + } + p.SparseIndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector) + }) + metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.IP, metricType) + }) + + t.Run("normal case, ivf flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.COSINE, metricType) + }) + + t.Run("normal case, ivf flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.COSINE, metricType) + }) + + t.Run("normal case, diskann", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.COSINE, metricType) + }) + + t.Run("normal case, bin flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.HAMMING, metricType) + }) + + t.Run("normal case, bin ivf flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`) + p := ¶mtable.AutoIndexConfig{ + IndexParams: paramtable.ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr) + assert.NotPanics(t, func() { + CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, metric.HAMMING, metricType) + }) +} diff --git a/internal/util/indexparamcheck/vector_index_checker.go b/internal/util/indexparamcheck/vector_index_checker.go new file mode 100644 index 0000000000..61ef76ebc2 --- /dev/null +++ b/internal/util/indexparamcheck/vector_index_checker.go @@ -0,0 +1,112 @@ +package indexparamcheck + +/* +#cgo pkg-config: milvus_core + +#include // free +#include "segcore/vector_index_c.h" +*/ +import "C" + +import ( + "fmt" + "unsafe" + + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexcgopb" + "github.com/milvus-io/milvus/internal/util/vecindexmgr" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type vecIndexChecker struct { + baseChecker +} + +// HandleCStatus deals with the error returned from CGO +func HandleCStatus(status *C.CStatus) error { + if status.error_code == 0 { + return nil + } + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + + return fmt.Errorf("%s", errorMsg) +} + +func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { + if typeutil.IsDenseFloatVectorType(dataType) { + if !CheckStrByValues(params, Metric, FloatVectorMetrics) { + return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], FloatVectorMetrics) + } + } else if typeutil.IsSparseFloatVectorType(dataType) { + if !CheckStrByValues(params, Metric, SparseMetrics) { + return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics) + } + } else if typeutil.IsBinaryVectorType(dataType) { + if !CheckStrByValues(params, Metric, BinaryVectorMetrics) { + return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinaryVectorMetrics) + } + } + + indexType, exist := params[common.IndexTypeKey] + + if !exist { + return fmt.Errorf("no indexType is specified") + } + + if !vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) { + return fmt.Errorf("indexType %s is not supported", indexType) + } + + protoIndexParams := &indexcgopb.IndexParams{ + Params: make([]*commonpb.KeyValuePair, 0), + } + + for key, value := range params { + protoIndexParams.Params = append(protoIndexParams.Params, &commonpb.KeyValuePair{Key: key, Value: value}) + } + + indexParamsBlob, err := proto.Marshal(protoIndexParams) + if err != nil { + return fmt.Errorf("failed to marshal index params: %s", err) + } + + var status C.CStatus + + cIndexType := C.CString(indexType) + cDataType := uint32(dataType) + status = C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob))) + C.free(unsafe.Pointer(cIndexType)) + + return HandleCStatus(&status) +} + +func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, params); err != nil { + return err + } + return c.baseChecker.CheckTrain(dataType, params) +} + +func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { + if !typeutil.IsVectorType(field.GetDataType()) { + return fmt.Errorf("index %s only supports vector data type", indexType) + } + if !vecindexmgr.GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) { + return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())]) + } + return nil +} + +func (c vecIndexChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { + paramtable.SetDefaultMetricTypeIfNotExist(dType, params) +} + +func newVecIndexChecker() IndexChecker { + return &vecIndexChecker{} +} diff --git a/internal/util/indexparamcheck/vector_index_checker_test.go b/internal/util/indexparamcheck/vector_index_checker_test.go new file mode 100644 index 0000000000..e06f95d047 --- /dev/null +++ b/internal/util/indexparamcheck/vector_index_checker_test.go @@ -0,0 +1,132 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestVecIndexChecker_StaticCheck(t *testing.T) { + checker := newVecIndexChecker() + + tests := []struct { + name string + dataType schemapb.DataType + params map[string]string + wantErr bool + }{ + { + name: "Valid IVF_FLAT index", + dataType: schemapb.DataType_FloatVector, + params: map[string]string{ + "index_type": "IVF_FLAT", + "metric_type": "L2", + "nlist": "1024", + }, + wantErr: false, + }, + { + name: "Invalid index type", + dataType: schemapb.DataType_FloatVector, + params: map[string]string{ + "index_type": "INVALID_INDEX", + }, + wantErr: true, + }, + { + name: "Missing index type", + dataType: schemapb.DataType_FloatVector, + params: map[string]string{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := checker.StaticCheck(tt.dataType, tt.params) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestVecIndexChecker_CheckValidDataType(t *testing.T) { + checker := newVecIndexChecker() + + tests := []struct { + name string + indexType IndexType + field *schemapb.FieldSchema + wantErr bool + }{ + { + name: "Valid float vector", + indexType: "IVF_FLAT", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + }, + wantErr: false, + }, + { + name: "Invalid data type", + indexType: "IVF_FLAT", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int64, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := checker.CheckValidDataType(tt.indexType, tt.field) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestVecIndexChecker_SetDefaultMetricTypeIfNotExist(t *testing.T) { + checker := newVecIndexChecker() + + tests := []struct { + name string + dataType schemapb.DataType + params map[string]string + expectedType string + }{ + { + name: "Float vector", + dataType: schemapb.DataType_FloatVector, + params: map[string]string{}, + expectedType: FloatVectorDefaultMetricType, + }, + { + name: "Binary vector", + dataType: schemapb.DataType_BinaryVector, + params: map[string]string{}, + expectedType: BinaryVectorDefaultMetricType, + }, + { + name: "Existing metric type", + dataType: schemapb.DataType_FloatVector, + params: map[string]string{"metric_type": "IP"}, + expectedType: "IP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checker.SetDefaultMetricTypeIfNotExist(tt.dataType, tt.params) + assert.Equal(t, tt.expectedType, tt.params["metric_type"]) + }) + } +} diff --git a/pkg/util/indexparamcheck/bin_flat_checker.go b/pkg/util/indexparamcheck/bin_flat_checker.go deleted file mode 100644 index 2e0b813c38..0000000000 --- a/pkg/util/indexparamcheck/bin_flat_checker.go +++ /dev/null @@ -1,17 +0,0 @@ -package indexparamcheck - -type binFlatChecker struct { - binaryVectorBaseChecker -} - -func (c binFlatChecker) CheckTrain(params map[string]string) error { - return c.binaryVectorBaseChecker.CheckTrain(params) -} - -func (c binFlatChecker) StaticCheck(params map[string]string) error { - return c.staticCheck(params) -} - -func newBinFlatChecker() IndexChecker { - return &binFlatChecker{} -} diff --git a/pkg/util/indexparamcheck/bitmap_checker_test.go b/pkg/util/indexparamcheck/bitmap_checker_test.go deleted file mode 100644 index 95d74f85bc..0000000000 --- a/pkg/util/indexparamcheck/bitmap_checker_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_BitmapIndexChecker(t *testing.T) { - c := newBITMAPChecker() - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true})) -} diff --git a/pkg/util/indexparamcheck/hybrid_checker_test.go b/pkg/util/indexparamcheck/hybrid_checker_test.go deleted file mode 100644 index 733adc2922..0000000000 --- a/pkg/util/indexparamcheck/hybrid_checker_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_HybridIndexChecker(t *testing.T) { - c := newHYBRIDChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) - assert.Error(t, c.CheckTrain(map[string]string{})) - assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"})) - assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "2000"})) -} diff --git a/pkg/util/indexparamcheck/inverted_checker_test.go b/pkg/util/indexparamcheck/inverted_checker_test.go deleted file mode 100644 index baecd97dd1..0000000000 --- a/pkg/util/indexparamcheck/inverted_checker_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_INVERTEDIndexChecker(t *testing.T) { - c := newINVERTEDChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) -} diff --git a/pkg/util/indexparamcheck/ivf_base_checker.go b/pkg/util/indexparamcheck/ivf_base_checker.go deleted file mode 100644 index 9b8a3e2e04..0000000000 --- a/pkg/util/indexparamcheck/ivf_base_checker.go +++ /dev/null @@ -1,26 +0,0 @@ -package indexparamcheck - -type ivfBaseChecker struct { - floatVectorBaseChecker -} - -func (c ivfBaseChecker) StaticCheck(params map[string]string) error { - if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { - return errOutOfRange(NLIST, MinNList, MaxNList) - } - - // skip check number of rows - - return c.floatVectorBaseChecker.staticCheck(params) -} - -func (c ivfBaseChecker) CheckTrain(params map[string]string) error { - if err := c.StaticCheck(params); err != nil { - return err - } - return c.floatVectorBaseChecker.CheckTrain(params) -} - -func newIVFBaseChecker() IndexChecker { - return &ivfBaseChecker{} -} diff --git a/pkg/util/indexparamcheck/scalar_index_checker.go b/pkg/util/indexparamcheck/scalar_index_checker.go deleted file mode 100644 index 9c372f4034..0000000000 --- a/pkg/util/indexparamcheck/scalar_index_checker.go +++ /dev/null @@ -1,9 +0,0 @@ -package indexparamcheck - -type scalarIndexChecker struct { - baseChecker -} - -func (c scalarIndexChecker) CheckTrain(params map[string]string) error { - return nil -} diff --git a/pkg/util/indexparamcheck/stl_sort_checker_test.go b/pkg/util/indexparamcheck/stl_sort_checker_test.go deleted file mode 100644 index 771a51cd32..0000000000 --- a/pkg/util/indexparamcheck/stl_sort_checker_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_STLSORTIndexChecker(t *testing.T) { - c := newSTLSORTChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) -} diff --git a/pkg/util/indexparamcheck/trie_checker_test.go b/pkg/util/indexparamcheck/trie_checker_test.go deleted file mode 100644 index 3e1eaea1c5..0000000000 --- a/pkg/util/indexparamcheck/trie_checker_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package indexparamcheck - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" -) - -func Test_TrieIndexChecker(t *testing.T) { - c := newTRIEChecker() - - assert.NoError(t, c.CheckTrain(map[string]string{})) - - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) - assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) -} diff --git a/pkg/util/indexparamcheck/utils_test.go b/pkg/util/indexparamcheck/utils_test.go deleted file mode 100644 index 7798bcd896..0000000000 --- a/pkg/util/indexparamcheck/utils_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package indexparamcheck - -import ( - "strconv" - "testing" -) - -func Test_CheckIntByRange(t *testing.T) { - params := map[string]string{ - "1": strconv.Itoa(1), - "2": strconv.Itoa(2), - "3": strconv.Itoa(3), - "s1": "s1", - "s2": "s2", - "s3": "s3", - } - - cases := []struct { - params map[string]string - key string - min int - max int - want bool - }{ - {params, "1", 0, 4, true}, - {params, "2", 0, 4, true}, - {params, "3", 0, 4, true}, - {params, "1", 4, 5, false}, - {params, "2", 4, 5, false}, - {params, "3", 4, 5, false}, - {params, "4", 0, 4, false}, - {params, "5", 0, 4, false}, - {params, "6", 0, 4, false}, - {params, "s1", 0, 4, false}, - {params, "s2", 0, 4, false}, - {params, "s3", 0, 4, false}, - {params, "s4", 0, 4, false}, - {params, "s5", 0, 4, false}, - {params, "s6", 0, 4, false}, - } - - for _, test := range cases { - if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want { - t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want) - } - } -} - -func Test_CheckStrByValues(t *testing.T) { - params := map[string]string{ - "1": strconv.Itoa(1), - "2": strconv.Itoa(2), - "3": strconv.Itoa(3), - } - - cases := []struct { - params map[string]string - key string - container []string - want bool - }{ - {params, "1", []string{"1", "2", "3"}, true}, - {params, "2", []string{"1", "2", "3"}, true}, - {params, "3", []string{"1", "2", "3"}, true}, - {params, "1", []string{"4", "5", "6"}, false}, - {params, "2", []string{"4", "5", "6"}, false}, - {params, "3", []string{"4", "5", "6"}, false}, - {params, "1", []string{}, false}, - {params, "2", []string{}, false}, - {params, "3", []string{}, false}, - {params, "4", []string{"1", "2", "3"}, false}, - {params, "5", []string{"1", "2", "3"}, false}, - {params, "6", []string{"1", "2", "3"}, false}, - {params, "4", []string{"4", "5", "6"}, false}, - {params, "5", []string{"4", "5", "6"}, false}, - {params, "6", []string{"4", "5", "6"}, false}, - {params, "4", []string{}, false}, - {params, "5", []string{}, false}, - {params, "6", []string{}, false}, - } - - for _, test := range cases { - if got := CheckStrByValues(test.params, test.key, test.container); got != test.want { - t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want) - } - } -} diff --git a/pkg/util/paramtable/autoindex_param.go b/pkg/util/paramtable/autoindex_param.go index 135b3b9963..58fd0dd8fd 100644 --- a/pkg/util/paramtable/autoindex_param.go +++ b/pkg/util/paramtable/autoindex_param.go @@ -24,12 +24,13 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // ///////////////////////////////////////////////////////////////////////////// // --- common --- -type autoIndexConfig struct { +type AutoIndexConfig struct { Enable ParamItem `refreshable:"true"` EnableOptimize ParamItem `refreshable:"true"` EnableResultLimitCheck ParamItem `refreshable:"true"` @@ -60,7 +61,7 @@ const ( DefaultBitmapCardinalityLimit = 100 ) -func (p *autoIndexConfig) init(base *BaseTable) { +func (p *AutoIndexConfig) init(base *BaseTable) { p.Enable = ParamItem{ Key: "autoIndex.enable", Version: "2.2.0", @@ -157,7 +158,7 @@ func (p *autoIndexConfig) init(base *BaseTable) { } p.AutoIndexTuningConfig.Init(base.mgr) - p.panicIfNotValidAndSetDefaultMetricType(base.mgr) + p.SetDefaultMetricType(base.mgr) p.ScalarAutoIndexEnable = ParamItem{ Key: "scalarAutoIndex.enable", @@ -244,37 +245,47 @@ func (p *autoIndexConfig) init(base *BaseTable) { p.ScalarBoolIndexType.Init(base.mgr) } -func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricType(mgr *config.Manager) { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr) - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr) +// SetDefaultMetricType The config check logic has been moved to internal package; only set defulat metric here +func (p *AutoIndexConfig) SetDefaultMetricType(mgr *config.Manager) { + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) + p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr) + p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr) } -func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) { +func setDefaultIfNotExist(params map[string]string, key string, defaultValue string) { + _, exist := params[key] + if !exist { + params[key] = defaultValue + } +} + +const ( + FloatVectorDefaultMetricType = metric.COSINE + SparseFloatVectorDefaultMetricType = metric.IP + BinaryVectorDefaultMetricType = metric.HAMMING +) + +func SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) { + if typeutil.IsDenseFloatVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) + } else if typeutil.IsSparseFloatVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType) + } else if typeutil.IsBinaryVectorType(dType) { + setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) + } +} + +func (p *AutoIndexConfig) SetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) { if m == nil { panic(fmt.Sprintf("%s invalid, should be json format", key)) } - indexType, ok := m[common.IndexTypeKey] - if !ok { - panic(fmt.Sprintf("%s invalid, index type not found", key)) - } - - checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType) - if err != nil { - panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType)) - } - - checker.SetDefaultMetricTypeIfNotExist(m, dtype) - - if err := checker.StaticCheck(m); err != nil { - panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error())) - } + SetDefaultMetricTypeIfNotExist(dtype, m) p.reset(key, m, mgr) } -func (p *autoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) { +func (p *AutoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) { ret, err := funcutil.MapToJSON(m) if err != nil { panic(fmt.Sprintf("%s: convert to json failed, parameters invalid, error: %s", key, err.Error())) diff --git a/pkg/util/paramtable/autoindex_param_test.go b/pkg/util/paramtable/autoindex_param_test.go index 231c8377e7..25bdfa63dc 100644 --- a/pkg/util/paramtable/autoindex_param_test.go +++ b/pkg/util/paramtable/autoindex_param_test.go @@ -26,7 +26,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) const ( @@ -134,180 +133,16 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { t.Run("not in json format", func(t *testing.T) { mgr := config.NewManager() mgr.SetConfig("autoIndex.params.build", "not in json format") - p := &autoIndexConfig{ + p := &AutoIndexConfig{ IndexParams: ParamItem{ Key: "autoIndex.params.build", }, } p.IndexParams.Init(mgr) assert.Panics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) + p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) }) }) - - t.Run("index type not found", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"M": 30}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.Panics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - }) - - t.Run("unsupported index type", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.Panics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - }) - - t.Run("normal case, hnsw", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, binary vector", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`) - p := &autoIndexConfig{ - BinaryIndexParams: ParamItem{ - Key: "autoIndex.params.binary.build", - }, - } - p.BinaryIndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr) - }) - metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, sparse vector", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`) - p := &autoIndexConfig{ - SparseIndexParams: ParamItem{ - Key: "autoIndex.params.sparse.build", - }, - } - p.SparseIndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr) - }) - metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.SparseFloatVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, ivf flat", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, ivf flat", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, diskann", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, bin flat", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) - }) - - t.Run("normal case, bin ivf flat", func(t *testing.T) { - mgr := config.NewManager() - mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`) - p := &autoIndexConfig{ - IndexParams: ParamItem{ - Key: "autoIndex.params.build", - }, - } - p.IndexParams.Init(mgr) - assert.NotPanics(t, func() { - p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr) - }) - metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] - assert.True(t, exist) - assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) - }) } func TestScalarAutoIndexParams_build(t *testing.T) { diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index cf9f4b3247..c8ffa4babd 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -65,7 +65,7 @@ type ComponentParam struct { CommonCfg commonConfig QuotaConfig quotaConfig - AutoIndexConfig autoIndexConfig + AutoIndexConfig AutoIndexConfig GpuConfig gpuConfig TraceCfg traceConfig diff --git a/tests/go_client/testcases/index_test.go b/tests/go_client/testcases/index_test.go index ce8c2faee0..a0602c5b3b 100644 --- a/tests/go_client/testcases/index_test.go +++ b/tests/go_client/testcases/index_test.go @@ -607,7 +607,7 @@ func TestCreateIndexJsonField(t *testing.T) { // create vector index on json field idx := index.NewSCANNIndex(entity.L2, 8, false) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultJSONFieldName, idx).WithIndexName("json_index")) - common.CheckErr(t, err, false, "data type should be FloatVector, Float16Vector or BFloat16Vector") + common.CheckErr(t, err, false, "index SCANN only supports vector data type") // create scalar index on json field type scalarIndexError struct { @@ -653,7 +653,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) { if field.DataType == entity.FieldTypeArray { // create vector index _, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index")) - common.CheckErr(t, err1, false, "data type should be FloatVector, Float16Vector or BFloat16Vector") + common.CheckErr(t, err1, false, "index SCANN only supports vector data type") // create scalar index _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx)) @@ -840,11 +840,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) { for _, drb := range []float64{-0.3, 1.3} { idxInverted := index.NewSparseInvertedIndex(entity.IP, drb) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted)) - common.CheckErr(t, err, false, "must be in range [0, 1)") + common.CheckErr(t, err, false, "Out of range in json: param 'drop_ratio_build'") idxWand := index.NewSparseWANDIndex(entity.IP, drb) _, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand)) - common.CheckErr(t, err1, false, "must be in range [0, 1)") + common.CheckErr(t, err1, false, "Out of range in json: param 'drop_ratio_build'") } } @@ -944,20 +944,22 @@ func TestCreateVectorIndexScalarField(t *testing.T) { // create float vector index on scalar field for _, idx := range hp.GenAllFloatIndex(entity.COSINE) { _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx)) - common.CheckErr(t, err, false, "can't build hnsw in not vector type", - "data type should be FloatVector, Float16Vector or BFloat16Vector") + expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idx.IndexType()) + common.CheckErr(t, err, false, expErrorMsg) } // create binary vector index on scalar field for _, idxBinary := range []index.Index{index.NewBinFlatIndex(entity.IP), index.NewBinIvfFlatIndex(entity.COSINE, 64)} { _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxBinary)) - common.CheckErr(t, err, false, "binary vector is only supported") + expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxBinary.IndexType()) + common.CheckErr(t, err, false, expErrorMsg) } // create sparse vector index on scalar field for _, idxSparse := range []index.Index{index.NewSparseInvertedIndex(entity.IP, 0.2), index.NewSparseWANDIndex(entity.IP, 0.3)} { _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxSparse)) - common.CheckErr(t, err, false, "only sparse float vector is supported for the specified index") + expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxSparse.IndexType()) + common.CheckErr(t, err, false, expErrorMsg) } } } @@ -972,7 +974,7 @@ func TestCreateIndexInvalidParams(t *testing.T) { _, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true)) // invalid IvfFlat nlist [1, 65536] - errMsg := "nlist out of range: [1, 65536]" + errMsg := "Out of range in json: param 'nlist'" for _, invalidNlist := range []int{0, -1, 65536 + 1} { // IvfFlat idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist) @@ -997,7 +999,7 @@ func TestCreateIndexInvalidParams(t *testing.T) { // IvfFlat idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq)) - common.CheckErr(t, err, false, "parameter `nbits` out of range, expect range [1,64]") + common.CheckErr(t, err, false, "Out of range in json: param 'nbits'") } idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8) @@ -1009,13 +1011,13 @@ func TestCreateIndexInvalidParams(t *testing.T) { // IvfFlat idxHnsw := index.NewHNSWIndex(entity.L2, invalidM, 96) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw)) - common.CheckErr(t, err, false, "M out of range: [1, 2048]") + common.CheckErr(t, err, false, "Out of range in json: param 'M'") } for _, invalidEfConstruction := range []int{0, 2147483647 + 1} { // IvfFlat idxHnsw := index.NewHNSWIndex(entity.L2, 8, invalidEfConstruction) _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw)) - common.CheckErr(t, err, false, "efConstruction out of range: [1, 2147483647]") + common.CheckErr(t, err, false, "Out of range in json: param 'efConstruction'", "integer value out of range, key: 'efConstruction'") } } diff --git a/tests/integration/import/import_test.go b/tests/integration/import/import_test.go index 85797dbd5b..415489eeed 100644 --- a/tests/integration/import/import_test.go +++ b/tests/integration/import/import_test.go @@ -33,10 +33,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/util/importutilv2" + "github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/tests/integration" @@ -66,7 +66,7 @@ func (s *BulkInsertSuite) SetupTest() { s.autoID = false s.vecType = schemapb.DataType_FloatVector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 } @@ -225,29 +225,29 @@ func (s *BulkInsertSuite) TestMultiFileTypes() { s.fileType = fileType s.vecType = schemapb.DataType_BinaryVector - s.indexType = indexparamcheck.IndexFaissBinIvfFlat + s.indexType = "BIN_IVF_FLAT" s.metricType = metric.HAMMING s.run() s.vecType = schemapb.DataType_FloatVector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 s.run() s.vecType = schemapb.DataType_Float16Vector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 s.run() s.vecType = schemapb.DataType_BFloat16Vector - s.indexType = indexparamcheck.IndexHNSW + s.indexType = "HNSW" s.metricType = metric.L2 s.run() // TODO: not support numpy for SparseFloatVector by now if fileType != importutilv2.Numpy { s.vecType = schemapb.DataType_SparseFloatVector - s.indexType = indexparamcheck.IndexSparseWand + s.indexType = "SPARSE_WAND" s.metricType = metric.IP s.run() } diff --git a/tests/integration/util_index.go b/tests/integration/util_index.go index 666cc2d15a..15da7d7885 100644 --- a/tests/integration/util_index.go +++ b/tests/integration/util_index.go @@ -26,23 +26,22 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) const ( - IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat - IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ - IndexFaissIDMap = indexparamcheck.IndexFaissIDMap - IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat - IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ - IndexScaNN = indexparamcheck.IndexScaNN - IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8 - IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap - IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat - IndexHNSW = indexparamcheck.IndexHNSW - IndexDISKANN = indexparamcheck.IndexDISKANN - IndexSparseInvertedIndex = indexparamcheck.IndexSparseInverted - IndexSparseWand = indexparamcheck.IndexSparseWand + IndexRaftIvfFlat = "GPU_IVF_FLAT" + IndexRaftIvfPQ = "GPU_IVF_PQ" + IndexFaissIDMap = "FLAT" + IndexFaissIvfFlat = "IVF_FLAT" + IndexFaissIvfPQ = "IVF_PQ" + IndexScaNN = "SCANN" + IndexFaissIvfSQ8 = "IVF_SQ8" + IndexFaissBinIDMap = "BIN_FLAT" + IndexFaissBinIvfFlat = "BIN_IVF_FLAT" + IndexHNSW = "HNSW" + IndexDISKANN = "DISKANN" + IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX" + IndexSparseWand = "SPARSE_WAND" ) func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) { diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index cf7aa38e81..5d0136f4de 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -2103,6 +2103,8 @@ def gen_simple_index(): continue elif ct.all_index_types[i] in ct.sparse_support: continue + elif ct.all_index_types[i] in ct.gpu_support: + continue dic = {"index_type": ct.all_index_types[i], "metric_type": "L2"} dic.update({"params": ct.default_all_indexes_params[i]}) index_params.append(dic) diff --git a/tests/python_client/common/common_type.py b/tests/python_client/common/common_type.py index a6210ff70c..94193fe6fb 100644 --- a/tests/python_client/common/common_type.py +++ b/tests/python_client/common/common_type.py @@ -244,6 +244,7 @@ default_all_search_params_params = [{}, {"nprobe": 32}, {"nprobe": 32}, {"nprobe Handler_type = ["GRPC", "HTTP"] binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"] sparse_support = ["SPARSE_INVERTED_INDEX", "SPARSE_WAND"] +gpu_support = ["GPU_IVF_FLAT", "GPU_IVF_PQ"] default_L0_metric = "COSINE" float_metrics = ["L2", "IP", "COSINE"] binary_metrics = ["JACCARD", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"] diff --git a/tests/python_client/utils/util_pymilvus.py b/tests/python_client/utils/util_pymilvus.py index 7f334d7fb5..44a7533c79 100644 --- a/tests/python_client/utils/util_pymilvus.py +++ b/tests/python_client/utils/util_pymilvus.py @@ -57,6 +57,8 @@ default_index_params = [ def create_target_index(index, field_name): index["field_name"] = field_name +def gpu_support(): + return ["GPU_IVF_FLAT", "GPU_IVF_PQ"] def binary_support(): return ["BIN_FLAT", "BIN_IVF_FLAT"] @@ -764,6 +766,8 @@ def gen_simple_index(): for i in range(len(all_index_types)): if all_index_types[i] in binary_support(): continue + if all_index_types[i] in gpu_support(): + continue dic = {"index_type": all_index_types[i], "metric_type": "L2"} dic.update({"params": default_index_params[i]}) index_params.append(dic)