mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
enhance: add an unify vector index config checker (#36844)
issue: #34298 Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
This commit is contained in:
parent
eeb67a3845
commit
d7b2ffe5aa
@ -33,6 +33,7 @@ func NewDiskANNIndex(metricType MetricType) Index {
|
|||||||
return &diskANNIndex{
|
return &diskANNIndex{
|
||||||
baseIndex: baseIndex{
|
baseIndex: baseIndex{
|
||||||
metricType: metricType,
|
metricType: metricType,
|
||||||
|
indexType: DISKANN,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,6 +33,7 @@ func NewFlatIndex(metricType MetricType) Index {
|
|||||||
return flatIndex{
|
return flatIndex{
|
||||||
baseIndex: baseIndex{
|
baseIndex: baseIndex{
|
||||||
metricType: metricType,
|
metricType: metricType,
|
||||||
|
indexType: Flat,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -54,6 +55,7 @@ func NewBinFlatIndex(metricType MetricType) Index {
|
|||||||
return binFlatIndex{
|
return binFlatIndex{
|
||||||
baseIndex: baseIndex{
|
baseIndex: baseIndex{
|
||||||
metricType: metricType,
|
metricType: metricType,
|
||||||
|
indexType: BinFlat,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy"
|
grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy"
|
||||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
@ -50,6 +51,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *Proxy) Prepare() error {
|
func (n *Proxy) Prepare() error {
|
||||||
|
indexparamcheck.ValidateParamTable()
|
||||||
return n.svr.Prepare()
|
return n.svr.Prepare()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -20,6 +20,79 @@
|
|||||||
#include "index/IndexFactory.h"
|
#include "index/IndexFactory.h"
|
||||||
#include "pb/index_cgo_msg.pb.h"
|
#include "pb/index_cgo_msg.pb.h"
|
||||||
|
|
||||||
|
CStatus
|
||||||
|
ValidateIndexParams(const char* index_type,
|
||||||
|
enum CDataType data_type,
|
||||||
|
const uint8_t* serialized_index_params,
|
||||||
|
const uint64_t length) {
|
||||||
|
try {
|
||||||
|
auto index_params =
|
||||||
|
std::make_unique<milvus::proto::indexcgo::IndexParams>();
|
||||||
|
auto res =
|
||||||
|
index_params->ParseFromArray(serialized_index_params, length);
|
||||||
|
AssertInfo(res, "Unmarshall index params failed");
|
||||||
|
|
||||||
|
knowhere::Json json;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < index_params->params_size(); i++) {
|
||||||
|
auto& param = index_params->params(i);
|
||||||
|
json[param.key()] = param.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
milvus::DataType dataType(static_cast<milvus::DataType>(data_type));
|
||||||
|
|
||||||
|
knowhere::Status status;
|
||||||
|
std::string error_msg;
|
||||||
|
if (dataType == milvus::DataType::VECTOR_BINARY) {
|
||||||
|
status = knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
|
||||||
|
index_type,
|
||||||
|
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||||
|
json,
|
||||||
|
error_msg);
|
||||||
|
} else if (dataType == milvus::DataType::VECTOR_FLOAT) {
|
||||||
|
status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
|
||||||
|
index_type,
|
||||||
|
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||||
|
json,
|
||||||
|
error_msg);
|
||||||
|
} else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
|
||||||
|
status = knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
|
||||||
|
index_type,
|
||||||
|
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||||
|
json,
|
||||||
|
error_msg);
|
||||||
|
} else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
|
||||||
|
status = knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
|
||||||
|
index_type,
|
||||||
|
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||||
|
json,
|
||||||
|
error_msg);
|
||||||
|
} else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
|
||||||
|
status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
|
||||||
|
index_type,
|
||||||
|
knowhere::Version::GetCurrentVersion().VersionNumber(),
|
||||||
|
json,
|
||||||
|
error_msg);
|
||||||
|
} else {
|
||||||
|
status = knowhere::Status::invalid_args;
|
||||||
|
}
|
||||||
|
CStatus cStatus;
|
||||||
|
if (status == knowhere::Status::success) {
|
||||||
|
cStatus.error_code = milvus::Success;
|
||||||
|
cStatus.error_msg = "";
|
||||||
|
} else {
|
||||||
|
cStatus.error_code = milvus::ConfigInvalid;
|
||||||
|
cStatus.error_msg = strdup(error_msg.c_str());
|
||||||
|
}
|
||||||
|
return cStatus;
|
||||||
|
} catch (std::exception& e) {
|
||||||
|
auto cStatus = CStatus();
|
||||||
|
cStatus.error_code = milvus::UnexpectedError;
|
||||||
|
cStatus.error_msg = strdup(e.what());
|
||||||
|
return cStatus;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
GetIndexListSize() {
|
GetIndexListSize() {
|
||||||
return knowhere::IndexFactory::Instance().GetIndexFeatures().size();
|
return knowhere::IndexFactory::Instance().GetIndexFeatures().size();
|
||||||
|
|||||||
@ -17,6 +17,12 @@ extern "C" {
|
|||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
#include "common/type_c.h"
|
#include "common/type_c.h"
|
||||||
|
|
||||||
|
CStatus
|
||||||
|
ValidateIndexParams(const char* index_type,
|
||||||
|
enum CDataType data_type,
|
||||||
|
const uint8_t* index_params,
|
||||||
|
const uint64_t length);
|
||||||
|
|
||||||
int
|
int
|
||||||
GetIndexListSize();
|
GetIndexListSize();
|
||||||
|
|
||||||
|
|||||||
@ -34,11 +34,11 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/metrics"
|
"github.com/milvus-io/milvus/pkg/metrics"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
|
|||||||
@ -28,11 +28,11 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/metrics"
|
"github.com/milvus-io/milvus/pkg/metrics"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metautil"
|
"github.com/milvus-io/milvus/pkg/util/metautil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
|
|||||||
@ -42,9 +42,9 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -620,13 +620,13 @@ func TestServer_AlterIndex(t *testing.T) {
|
|||||||
s.stateCode.Store(commonpb.StateCode_Healthy)
|
s.stateCode.Store(commonpb.StateCode_Healthy)
|
||||||
|
|
||||||
t.Run("mmap_unsupported", func(t *testing.T) {
|
t.Run("mmap_unsupported", func(t *testing.T) {
|
||||||
indexParams[0].Value = indexparamcheck.IndexRaftCagra
|
indexParams[0].Value = "GPU_CAGRA"
|
||||||
|
|
||||||
resp, err := s.AlterIndex(ctx, req)
|
resp, err := s.AlterIndex(ctx, req)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
|
assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
|
||||||
|
|
||||||
indexParams[0].Value = indexparamcheck.IndexFaissIvfFlat
|
indexParams[0].Value = "IVF_FLAT"
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("param_value_invalied", func(t *testing.T) {
|
t.Run("param_value_invalied", func(t *testing.T) {
|
||||||
|
|||||||
@ -28,13 +28,13 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
@ -475,25 +475,18 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
|
|||||||
if err := fillDimension(field, indexParams); err != nil {
|
if err := fillDimension(field, indexParams); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// used only for checker, should be deleted after checking
|
|
||||||
indexParams[IsSparseKey] = "true"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checker.CheckValidDataType(field); err != nil {
|
if err := checker.CheckValidDataType(indexType, field); err != nil {
|
||||||
log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String()))
|
log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String()))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checker.CheckTrain(indexParams); err != nil {
|
if err := checker.CheckTrain(field.DataType, indexParams); err != nil {
|
||||||
log.Info("create index with invalid parameters", zap.Error(err))
|
log.Info("create index with invalid parameters", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isSparse {
|
|
||||||
delete(indexParams, IsSparseKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -35,9 +35,9 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/mocks"
|
"github.com/milvus-io/milvus/internal/mocks"
|
||||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
|
|||||||
@ -40,6 +40,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
@ -48,7 +49,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
|
|||||||
@ -29,11 +29,11 @@ import (
|
|||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -24,8 +24,8 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -52,12 +52,12 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
|
"github.com/milvus-io/milvus/internal/querynodev2/segments/state"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
"github.com/milvus-io/milvus/internal/util/cgo"
|
"github.com/milvus-io/milvus/internal/util/cgo"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/metrics"
|
"github.com/milvus-io/milvus/pkg/metrics"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metautil"
|
"github.com/milvus-io/milvus/pkg/util/metautil"
|
||||||
|
|||||||
@ -33,11 +33,11 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/initcore"
|
"github.com/milvus-io/milvus/internal/util/initcore"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
|
|||||||
@ -29,12 +29,12 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/querycoordv2/params"
|
"github.com/milvus-io/milvus/internal/querycoordv2/params"
|
||||||
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
|
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
|
|||||||
@ -9,11 +9,11 @@ type AUTOINDEXChecker struct {
|
|||||||
baseChecker
|
baseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error {
|
func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c *AUTOINDEXChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19,29 +19,17 @@ package indexparamcheck
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type baseChecker struct{}
|
type baseChecker struct{}
|
||||||
|
|
||||||
func (c baseChecker) CheckTrain(params map[string]string) error {
|
func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
// vector dimension should be checked on collection creation. this is just some basic check
|
if typeutil.IsSparseFloatVectorType(dataType) {
|
||||||
isSparse := false
|
|
||||||
if val, exist := params[common.IsSparseKey]; exist {
|
|
||||||
val = strings.ToLower(val)
|
|
||||||
if val != "true" && val != "false" {
|
|
||||||
return fmt.Errorf("invalid is_sparse value: %s, must be true or false", val)
|
|
||||||
}
|
|
||||||
if val == "true" {
|
|
||||||
isSparse = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isSparse {
|
|
||||||
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
||||||
return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics)
|
return fmt.Errorf("metric type not found or not supported for sparse float vectors, supported: %v", SparseMetrics)
|
||||||
}
|
}
|
||||||
@ -55,13 +43,13 @@ func (c baseChecker) CheckTrain(params map[string]string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckValidDataType check whether the field data type is supported for the index type
|
// CheckValidDataType check whether the field data type is supported for the index type
|
||||||
func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string, dType schemapb.DataType) {}
|
func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {}
|
||||||
|
|
||||||
func (c baseChecker) StaticCheck(params map[string]string) error {
|
func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return errors.New("unsupported index type")
|
return errors.New("unsupported index type")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
sparseParamsWithoutDim := map[string]string{
|
sparseParamsWithoutDim := map[string]string{
|
||||||
Metric: metric.IP,
|
Metric: metric.IP,
|
||||||
common.IsSparseKey: "tRue",
|
common.IsSparseKey: "True",
|
||||||
}
|
}
|
||||||
sparseParamsWrongMetric := map[string]string{
|
sparseParamsWrongMetric := map[string]string{
|
||||||
Metric: metric.L2,
|
Metric: metric.L2,
|
||||||
@ -42,9 +42,15 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
|
|||||||
{badSparseParams, false},
|
{badSparseParams, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "HNSW"
|
||||||
|
var err error
|
||||||
|
if test.params[common.IsSparseKey] == "True" {
|
||||||
|
err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params)
|
||||||
|
} else {
|
||||||
|
err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
|
}
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -115,7 +121,7 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
|
|||||||
c := newBaseChecker()
|
c := newBaseChecker()
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||||
err := c.CheckValidDataType(fieldSchema)
|
err := c.CheckValidDataType("FLAT", fieldSchema)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -126,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
|
|||||||
|
|
||||||
func Test_baseChecker_StaticCheck(t *testing.T) {
|
func Test_baseChecker_StaticCheck(t *testing.T) {
|
||||||
// TODO
|
// TODO
|
||||||
assert.Error(t, newBaseChecker().StaticCheck(nil))
|
assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil))
|
||||||
}
|
}
|
||||||
19
internal/util/indexparamcheck/bin_flat_checker.go
Normal file
19
internal/util/indexparamcheck/bin_flat_checker.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
|
||||||
|
type binFlatChecker struct {
|
||||||
|
binaryVectorBaseChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c binFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
return c.binaryVectorBaseChecker.CheckTrain(dataType, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c binFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
return c.staticCheck(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinFlatChecker() IndexChecker {
|
||||||
|
return &binFlatChecker{}
|
||||||
|
}
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,9 +65,10 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, true},
|
{p7, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newBinFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "BINFLAT"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -134,10 +136,10 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newBinFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||||
err := c.CheckValidDataType(fieldSchema)
|
err := c.CheckValidDataType("BINFLAT", fieldSchema)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -2,13 +2,15 @@ package indexparamcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type binIVFFlatChecker struct {
|
type binIVFFlatChecker struct {
|
||||||
binaryVectorBaseChecker
|
binaryVectorBaseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c binIVFFlatChecker) StaticCheck(params map[string]string) error {
|
func (c binIVFFlatChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
|
if !CheckStrByValues(params, Metric, BinIvfMetrics) {
|
||||||
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics)
|
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinIvfMetrics)
|
||||||
}
|
}
|
||||||
@ -20,12 +22,12 @@ func (c binIVFFlatChecker) StaticCheck(params map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c binIVFFlatChecker) CheckTrain(params map[string]string) error {
|
func (c binIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil {
|
if err := c.binaryVectorBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.StaticCheck(params)
|
return c.StaticCheck(schemapb.DataType_BinaryVector, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBinIVFFlatChecker() IndexChecker {
|
func newBinIVFFlatChecker() IndexChecker {
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -115,9 +116,10 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newBinIVFFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "BIN_IVF_FLAT"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -185,10 +187,10 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newBinIVFFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||||
err := c.CheckValidDataType(fieldSchema)
|
err := c.CheckValidDataType("BIN_IVF_FLAT", fieldSchema)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -19,22 +19,22 @@ func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error {
|
func (c binaryVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.baseChecker.CheckTrain(params); err != nil {
|
if err := c.baseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.staticCheck(params)
|
return c.staticCheck(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c binaryVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if field.GetDataType() != schemapb.DataType_BinaryVector {
|
if field.GetDataType() != schemapb.DataType_BinaryVector {
|
||||||
return fmt.Errorf("binary vector is only supported")
|
return fmt.Errorf("binary vector is only supported")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||||
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
|
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,10 +67,10 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newBinaryVectorBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
fieldSchema := &schemapb.FieldSchema{DataType: test.dType}
|
||||||
err := c.CheckValidDataType(fieldSchema)
|
err := c.CheckValidDataType("BINFLAT", fieldSchema)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
33
internal/util/indexparamcheck/bitmap_checker_test.go
Normal file
33
internal/util/indexparamcheck/bitmap_checker_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_BitmapIndexChecker(t *testing.T) {
|
||||||
|
c := newBITMAPChecker()
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
||||||
|
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexBitmap, &schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true}))
|
||||||
|
}
|
||||||
@ -11,11 +11,11 @@ type BITMAPChecker struct {
|
|||||||
scalarIndexChecker
|
scalarIndexChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BITMAPChecker) CheckTrain(params map[string]string) error {
|
func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return c.scalarIndexChecker.CheckTrain(params)
|
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
return fmt.Errorf("create bitmap index on primary key not supported")
|
return fmt.Errorf("create bitmap index on primary key not supported")
|
||||||
}
|
}
|
||||||
@ -3,6 +3,8 @@ package indexparamcheck
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// diskannChecker checks if an diskann index can be built.
|
// diskannChecker checks if an diskann index can be built.
|
||||||
@ -10,8 +12,8 @@ type cagraChecker struct {
|
|||||||
floatVectorBaseChecker
|
floatVectorBaseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cagraChecker) CheckTrain(params map[string]string) error {
|
func (c *cagraChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
err := c.baseChecker.CheckTrain(params)
|
err := c.baseChecker.CheckTrain(dataType, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -54,7 +56,7 @@ func (c *cagraChecker) CheckTrain(params map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c cagraChecker) StaticCheck(params map[string]string) error {
|
func (c cagraChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return c.staticCheck(params)
|
return c.staticCheck(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6,6 +6,8 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -101,9 +103,13 @@ func Test_cagraChecker_CheckTrain(t *testing.T) {
|
|||||||
{p14, false},
|
{p14, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newCagraChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_CAGRA")
|
||||||
|
if c == nil {
|
||||||
|
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -20,6 +20,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IndexCheckerMgr interface {
|
type IndexCheckerMgr interface {
|
||||||
@ -34,36 +36,19 @@ type indexCheckerMgrImpl struct {
|
|||||||
|
|
||||||
func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) {
|
func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, error) {
|
||||||
mgr.once.Do(mgr.registerIndexChecker)
|
mgr.once.Do(mgr.registerIndexChecker)
|
||||||
|
// Unify the vector index checker
|
||||||
|
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) {
|
||||||
|
return mgr.checkers[IndexVector], nil
|
||||||
|
}
|
||||||
adapter, ok := mgr.checkers[indexType]
|
adapter, ok := mgr.checkers[indexType]
|
||||||
if ok {
|
if ok {
|
||||||
return adapter, nil
|
return adapter, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("Can not find conf adapter: " + indexType)
|
return nil, errors.New("Can not find index: " + indexType + " , please check")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
|
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
|
||||||
mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker()
|
mgr.checkers[IndexVector] = newVecIndexChecker()
|
||||||
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
|
|
||||||
mgr.checkers[IndexRaftCagra] = newCagraChecker()
|
|
||||||
mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker()
|
|
||||||
mgr.checkers[IndexFaissIDMap] = newFlatChecker()
|
|
||||||
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
|
|
||||||
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()
|
|
||||||
mgr.checkers[IndexScaNN] = newScaNNChecker()
|
|
||||||
mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker()
|
|
||||||
mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker()
|
|
||||||
mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker()
|
|
||||||
mgr.checkers[IndexHNSW] = newHnswChecker()
|
|
||||||
mgr.checkers[IndexDISKANN] = newDiskannChecker()
|
|
||||||
mgr.checkers[IndexSparseInverted] = newSparseInvertedIndexChecker()
|
|
||||||
mgr.checkers[IndexFaissHNSW] = newFloatVectorBaseChecker()
|
|
||||||
mgr.checkers[IndexFaissHNSWPQ] = newFloatVectorBaseChecker()
|
|
||||||
mgr.checkers[IndexFaissHNSWSQ] = newFloatVectorBaseChecker()
|
|
||||||
mgr.checkers[IndexFaissHNSWPRQ] = newFloatVectorBaseChecker()
|
|
||||||
// WAND doesn't have more index params than sparse inverted index, thus
|
|
||||||
// using the same checker.
|
|
||||||
mgr.checkers[IndexSparseWand] = newSparseInvertedIndexChecker()
|
|
||||||
mgr.checkers[IndexINVERTED] = newINVERTEDChecker()
|
mgr.checkers[IndexINVERTED] = newINVERTEDChecker()
|
||||||
mgr.checkers[IndexSTLSORT] = newSTLSORTChecker()
|
mgr.checkers[IndexSTLSORT] = newSTLSORTChecker()
|
||||||
mgr.checkers["Asceneding"] = newSTLSORTChecker()
|
mgr.checkers["Asceneding"] = newSTLSORTChecker()
|
||||||
@ -29,52 +29,52 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) {
|
|||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, nil, adapter)
|
assert.Equal(t, nil, adapter)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
adapter, err = adapterMgr.GetChecker("FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*flatChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
adapter, err = adapterMgr.GetChecker("IVF_FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*ivfBaseChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexScaNN)
|
adapter, err = adapterMgr.GetChecker("SCANN")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*scaNNChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
|
adapter, err = adapterMgr.GetChecker("IVF_PQ")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*ivfPQChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
|
adapter, err = adapterMgr.GetChecker("IVF_SQ8")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*ivfSQChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
|
adapter, err = adapterMgr.GetChecker("BIN_FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*binFlatChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
|
adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*binIVFFlatChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexHNSW)
|
adapter, err = adapterMgr.GetChecker("HNSW")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*hnswChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,52 +89,52 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) {
|
|||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, nil, adapter)
|
assert.Equal(t, nil, adapter)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIDMap)
|
adapter, err = adapterMgr.GetChecker("FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*flatChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfFlat)
|
adapter, err = adapterMgr.GetChecker("IVF_FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*ivfBaseChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexScaNN)
|
adapter, err = adapterMgr.GetChecker("SCANN")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*scaNNChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ)
|
adapter, err = adapterMgr.GetChecker("IVF_PQ")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*ivfPQChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissIvfSQ8)
|
adapter, err = adapterMgr.GetChecker("IVF_SQ8")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*ivfSQChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIDMap)
|
adapter, err = adapterMgr.GetChecker("BIN_FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*binFlatChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexFaissBinIvfFlat)
|
adapter, err = adapterMgr.GetChecker("BIN_IVF_FLAT")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*binIVFFlatChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
adapter, err = adapterMgr.GetChecker(IndexHNSW)
|
adapter, err = adapterMgr.GetChecker("HNSW")
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, nil, adapter)
|
assert.NotEqual(t, nil, adapter)
|
||||||
_, ok = adapter.(*hnswChecker)
|
_, ok = adapter.(*vecIndexChecker)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +146,7 @@ func TestConfAdapterMgrImpl_GetAdapter_multiple_threads(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
adapter, err := mgr.GetChecker(IndexHNSW)
|
adapter, err := mgr.GetChecker("HNSW")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, adapter)
|
assert.NotNil(t, adapter)
|
||||||
}()
|
}()
|
||||||
@ -65,7 +65,7 @@ var (
|
|||||||
CagraBuildAlgoTypes = []string{CargaBuildAlgoIVFPQ, CargaBuildAlgoNNDESCENT}
|
CagraBuildAlgoTypes = []string{CargaBuildAlgoIVFPQ, CargaBuildAlgoNNDESCENT}
|
||||||
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
|
||||||
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
|
||||||
SparseMetrics = []string{metric.IP} // const
|
SparseMetrics = []string{metric.IP, metric.BM25} // const
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -1,11 +1,13 @@
|
|||||||
package indexparamcheck
|
package indexparamcheck
|
||||||
|
|
||||||
|
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
|
||||||
// diskannChecker checks if an diskann index can be built.
|
// diskannChecker checks if an diskann index can be built.
|
||||||
type diskannChecker struct {
|
type diskannChecker struct {
|
||||||
floatVectorBaseChecker
|
floatVectorBaseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c diskannChecker) StaticCheck(params map[string]string) error {
|
func (c diskannChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return c.staticCheck(params)
|
return c.staticCheck(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -72,9 +73,10 @@ func Test_diskannChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newDiskannChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "DISKANN"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -142,9 +144,9 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newDiskannChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("DISKANN", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -1,10 +1,12 @@
|
|||||||
package indexparamcheck
|
package indexparamcheck
|
||||||
|
|
||||||
|
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
|
||||||
type flatChecker struct {
|
type flatChecker struct {
|
||||||
floatVectorBaseChecker
|
floatVectorBaseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c flatChecker) StaticCheck(m map[string]string) error {
|
func (c flatChecker) StaticCheck(dataType schemapb.DataType, m map[string]string) error {
|
||||||
return c.staticCheck(m)
|
return c.staticCheck(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6,6 +6,8 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -52,9 +54,10 @@ func Test_flatChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "FLAT"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -89,9 +92,10 @@ func Test_flatChecker_StaticCheck(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.StaticCheck(test.params)
|
test.params[common.IndexTypeKey] = "FLAT"
|
||||||
|
err := c.StaticCheck(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -20,22 +20,22 @@ func (c floatVectorBaseChecker) staticCheck(params map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error {
|
func (c floatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.baseChecker.CheckTrain(params); err != nil {
|
if err := c.baseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.staticCheck(params)
|
return c.staticCheck(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c floatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if !typeutil.IsDenseFloatVectorType(field.GetDataType()) {
|
if !typeutil.IsDenseFloatVectorType(field.GetDataType()) {
|
||||||
return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector")
|
return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||||
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,13 +63,13 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
dType: schemapb.DataType_BinaryVector,
|
dType: schemapb.DataType_BinaryVector,
|
||||||
errIsNil: false,
|
errIsNil: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newFloatVectorBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -12,7 +12,7 @@ type hnswChecker struct {
|
|||||||
baseChecker
|
baseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c hnswChecker) StaticCheck(params map[string]string) error {
|
func (c hnswChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
|
if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) {
|
||||||
return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction)
|
return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction)
|
||||||
}
|
}
|
||||||
@ -25,21 +25,21 @@ func (c hnswChecker) StaticCheck(params map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c hnswChecker) CheckTrain(params map[string]string) error {
|
func (c hnswChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.StaticCheck(params); err != nil {
|
if err := c.StaticCheck(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.baseChecker.CheckTrain(params)
|
return c.baseChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c hnswChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if !typeutil.IsVectorType(field.GetDataType()) {
|
if !typeutil.IsVectorType(field.GetDataType()) {
|
||||||
return fmt.Errorf("can't build hnsw in not vector type")
|
return fmt.Errorf("can't build hnsw in not vector type")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c hnswChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
func (c hnswChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||||
if typeutil.IsDenseFloatVectorType(dType) {
|
if typeutil.IsDenseFloatVectorType(dType) {
|
||||||
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
||||||
} else if typeutil.IsSparseFloatVectorType(dType) {
|
} else if typeutil.IsSparseFloatVectorType(dType) {
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,13 +89,19 @@ func Test_hnswChecker_CheckTrain(t *testing.T) {
|
|||||||
{p3, true},
|
{p3, true},
|
||||||
{p4, true},
|
{p4, true},
|
||||||
{p5, true},
|
{p5, true},
|
||||||
{p6, false},
|
{p6, true},
|
||||||
{p7, false},
|
{p7, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newHnswChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "HNSW"
|
||||||
|
var err error
|
||||||
|
if CheckStrByValues(test.params, common.MetricTypeKey, BinaryVectorMetrics) {
|
||||||
|
err = c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
|
||||||
|
} else {
|
||||||
|
err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
|
}
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -162,9 +169,9 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newHnswChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("HNSW", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -200,14 +207,14 @@ func Test_hnswChecker_SetDefaultMetricType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newHnswChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
p := map[string]string{
|
p := map[string]string{
|
||||||
DIM: strconv.Itoa(128),
|
DIM: strconv.Itoa(128),
|
||||||
HNSWM: strconv.Itoa(16),
|
HNSWM: strconv.Itoa(16),
|
||||||
EFConstruction: strconv.Itoa(200),
|
EFConstruction: strconv.Itoa(200),
|
||||||
}
|
}
|
||||||
c.SetDefaultMetricTypeIfNotExist(p, test.dType)
|
c.SetDefaultMetricTypeIfNotExist(test.dType, p)
|
||||||
assert.Equal(t, p[Metric], test.metricType)
|
assert.Equal(t, p[Metric], test.metricType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
37
internal/util/indexparamcheck/hybrid_checker_test.go
Normal file
37
internal/util/indexparamcheck/hybrid_checker_test.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_HybridIndexChecker(t *testing.T) {
|
||||||
|
c := newHYBRIDChecker()
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{"bitmap_cardinality_limit": "100"}))
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
||||||
|
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
||||||
|
assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{}))
|
||||||
|
assert.Error(t, c.CheckTrain(schemapb.DataType_Float, map[string]string{"bitmap_cardinality_limit": "0"}))
|
||||||
|
assert.Error(t, c.CheckTrain(schemapb.DataType_Double, map[string]string{"bitmap_cardinality_limit": "2000"}))
|
||||||
|
}
|
||||||
@ -12,15 +12,15 @@ type HYBRIDChecker struct {
|
|||||||
scalarIndexChecker
|
scalarIndexChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HYBRIDChecker) CheckTrain(params map[string]string) error {
|
func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) {
|
if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) {
|
||||||
return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d",
|
return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d",
|
||||||
MaxBitmapCardinalityLimit)
|
MaxBitmapCardinalityLimit)
|
||||||
}
|
}
|
||||||
return c.scalarIndexChecker.CheckTrain(params)
|
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HYBRIDChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
mainType := field.GetDataType()
|
mainType := field.GetDataType()
|
||||||
elemType := field.GetElementType()
|
elemType := field.GetElementType()
|
||||||
if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) &&
|
if !typeutil.IsBoolType(mainType) && !typeutil.IsIntegerType(mainType) &&
|
||||||
@ -21,8 +21,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type IndexChecker interface {
|
type IndexChecker interface {
|
||||||
CheckTrain(map[string]string) error
|
CheckTrain(schemapb.DataType, map[string]string) error
|
||||||
CheckValidDataType(field *schemapb.FieldSchema) error
|
CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error
|
||||||
SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType)
|
SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string)
|
||||||
StaticCheck(map[string]string) error
|
StaticCheck(schemapb.DataType, map[string]string) error
|
||||||
}
|
}
|
||||||
@ -15,6 +15,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,31 +24,7 @@ type IndexType = string
|
|||||||
|
|
||||||
// IndexType definitions
|
// IndexType definitions
|
||||||
const (
|
const (
|
||||||
// vector index
|
IndexVector IndexType = "VECINDEX"
|
||||||
IndexGpuBF IndexType = "GPU_BRUTE_FORCE"
|
|
||||||
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
|
|
||||||
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
|
|
||||||
IndexRaftCagra IndexType = "GPU_CAGRA"
|
|
||||||
IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE"
|
|
||||||
IndexFaissIDMap IndexType = "FLAT" // no index is built.
|
|
||||||
IndexFaissIvfFlat IndexType = "IVF_FLAT"
|
|
||||||
IndexFaissIvfPQ IndexType = "IVF_PQ"
|
|
||||||
IndexScaNN IndexType = "SCANN"
|
|
||||||
IndexFaissIvfSQ8 IndexType = "IVF_SQ8"
|
|
||||||
IndexFaissBinIDMap IndexType = "BIN_FLAT"
|
|
||||||
IndexFaissBinIvfFlat IndexType = "BIN_IVF_FLAT"
|
|
||||||
IndexHNSW IndexType = "HNSW"
|
|
||||||
IndexDISKANN IndexType = "DISKANN"
|
|
||||||
IndexSparseInverted IndexType = "SPARSE_INVERTED_INDEX"
|
|
||||||
IndexSparseWand IndexType = "SPARSE_WAND"
|
|
||||||
// For temporary use, will be removed in the future.
|
|
||||||
// 1. All Index related param check will be moved to Knowhere recently.
|
|
||||||
// 2. FAISS_HNSW_xxx will be rename to HNSW_xxx after QA test. We keep the original name for comparison purpose.
|
|
||||||
// TODO: @liliu-z @foxspy
|
|
||||||
IndexFaissHNSW IndexType = "FAISS_HNSW_FLAT"
|
|
||||||
IndexFaissHNSWPQ IndexType = "FAISS_HNSW_PQ"
|
|
||||||
IndexFaissHNSWSQ IndexType = "FAISS_HNSW_SQ"
|
|
||||||
IndexFaissHNSWPRQ IndexType = "FAISS_HNSW_PRQ"
|
|
||||||
|
|
||||||
// scalar index
|
// scalar index
|
||||||
IndexSTLSORT IndexType = "STL_SORT"
|
IndexSTLSORT IndexType = "STL_SORT"
|
||||||
@ -66,28 +43,12 @@ func IsScalarIndexType(indexType IndexType) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func IsGpuIndex(indexType IndexType) bool {
|
func IsGpuIndex(indexType IndexType) bool {
|
||||||
return indexType == IndexGpuBF ||
|
return vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(indexType)
|
||||||
indexType == IndexRaftIvfFlat ||
|
|
||||||
indexType == IndexRaftIvfPQ ||
|
|
||||||
indexType == IndexRaftCagra
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsVectorMmapIndex check if the vector index can be mmaped
|
// IsVectorMmapIndex check if the vector index can be mmaped
|
||||||
func IsVectorMmapIndex(indexType IndexType) bool {
|
func IsVectorMmapIndex(indexType IndexType) bool {
|
||||||
return indexType == IndexFaissIDMap ||
|
return vecindexmgr.GetVecIndexMgrInstance().IsMMapSupported(indexType)
|
||||||
indexType == IndexFaissIvfFlat ||
|
|
||||||
indexType == IndexFaissIvfPQ ||
|
|
||||||
indexType == IndexFaissIvfSQ8 ||
|
|
||||||
indexType == IndexFaissBinIDMap ||
|
|
||||||
indexType == IndexFaissBinIvfFlat ||
|
|
||||||
indexType == IndexHNSW ||
|
|
||||||
indexType == IndexFaissHNSW ||
|
|
||||||
indexType == IndexFaissHNSWPQ ||
|
|
||||||
indexType == IndexFaissHNSWSQ ||
|
|
||||||
indexType == IndexFaissHNSWPRQ ||
|
|
||||||
indexType == IndexScaNN ||
|
|
||||||
indexType == IndexSparseInverted ||
|
|
||||||
indexType == IndexSparseWand
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsOffsetCacheSupported(indexType IndexType) bool {
|
func IsOffsetCacheSupported(indexType IndexType) bool {
|
||||||
@ -95,7 +56,7 @@ func IsOffsetCacheSupported(indexType IndexType) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func IsDiskIndex(indexType IndexType) bool {
|
func IsDiskIndex(indexType IndexType) bool {
|
||||||
return indexType == IndexDISKANN
|
return vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsScalarMmapIndex(indexType IndexType) bool {
|
func IsScalarMmapIndex(indexType IndexType) bool {
|
||||||
@ -12,11 +12,11 @@ type INVERTEDChecker struct {
|
|||||||
scalarIndexChecker
|
scalarIndexChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *INVERTEDChecker) CheckTrain(params map[string]string) error {
|
func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return c.scalarIndexChecker.CheckTrain(params)
|
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
dType := field.GetDataType()
|
dType := field.GetDataType()
|
||||||
if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) &&
|
if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) &&
|
||||||
!typeutil.IsArrayType(dType) {
|
!typeutil.IsArrayType(dType) {
|
||||||
25
internal/util/indexparamcheck/inverted_checker_test.go
Normal file
25
internal/util/indexparamcheck/inverted_checker_test.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_INVERTEDIndexChecker(t *testing.T) {
|
||||||
|
c := newINVERTEDChecker()
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{}))
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_Array}))
|
||||||
|
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}))
|
||||||
|
}
|
||||||
28
internal/util/indexparamcheck/ivf_base_checker.go
Normal file
28
internal/util/indexparamcheck/ivf_base_checker.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
|
||||||
|
type ivfBaseChecker struct {
|
||||||
|
floatVectorBaseChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ivfBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
||||||
|
return errOutOfRange(NLIST, MinNList, MaxNList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip check number of rows
|
||||||
|
|
||||||
|
return c.floatVectorBaseChecker.staticCheck(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ivfBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
if err := c.StaticCheck(dataType, params); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.floatVectorBaseChecker.CheckTrain(dataType, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIVFBaseChecker() IndexChecker {
|
||||||
|
return &ivfBaseChecker{}
|
||||||
|
}
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,9 +71,10 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newIVFBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "IVF_FLAT"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -140,9 +142,9 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newIVFBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -3,6 +3,8 @@ package indexparamcheck
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ivfPQChecker checks if a IVF_PQ index can be built.
|
// ivfPQChecker checks if a IVF_PQ index can be built.
|
||||||
@ -11,8 +13,8 @@ type ivfPQChecker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||||
func (c *ivfPQChecker) CheckTrain(params map[string]string) error {
|
func (c *ivfPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,9 +142,11 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ")
|
||||||
c := newIVFPQChecker()
|
c := newIVFPQChecker()
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "IVF_PQ"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -211,9 +214,9 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newIVFPQChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("IVF_PQ", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -2,6 +2,8 @@ package indexparamcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ivfSQChecker checks if a IVF_SQ index can be built.
|
// ivfSQChecker checks if a IVF_SQ index can be built.
|
||||||
@ -22,11 +24,11 @@ func (c *ivfSQChecker) checkNBits(params map[string]string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckTrain returns true if the index can be built with the specific index parameters.
|
// CheckTrain returns true if the index can be built with the specific index parameters.
|
||||||
func (c *ivfSQChecker) CheckTrain(params map[string]string) error {
|
func (c *ivfSQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.checkNBits(params); err != nil {
|
if err := c.checkNBits(params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.ivfBaseChecker.CheckTrain(params)
|
return c.ivfBaseChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIVFSQChecker() IndexChecker {
|
func newIVFSQChecker() IndexChecker {
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,7 +79,6 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{validParams, true},
|
{validParams, true},
|
||||||
{validParamsWithNBits, true},
|
{validParamsWithNBits, true},
|
||||||
{paramsWithInvalidNBits, false},
|
|
||||||
{invalidIVFParamsMin(), false},
|
{invalidIVFParamsMin(), false},
|
||||||
{invalidIVFParamsMax(), false},
|
{invalidIVFParamsMax(), false},
|
||||||
{p1, true},
|
{p1, true},
|
||||||
@ -90,9 +90,10 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newIVFSQChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "IVF_SQ"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -160,9 +161,9 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newIVFSQChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("IVF_SQ8", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -1,14 +1,18 @@
|
|||||||
package indexparamcheck
|
package indexparamcheck
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
type raftBruteForceChecker struct {
|
type raftBruteForceChecker struct {
|
||||||
floatVectorBaseChecker
|
floatVectorBaseChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
// raftBrustForceChecker checks if a Brute_Force index can be built.
|
// raftBrustForceChecker checks if a Brute_Force index can be built.
|
||||||
func (c raftBruteForceChecker) CheckTrain(params map[string]string) error {
|
func (c raftBruteForceChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil {
|
if err := c.floatVectorBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
||||||
@ -6,6 +6,9 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -52,9 +55,14 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newRaftBruteForceChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_BRUTE_FORCE")
|
||||||
|
if c == nil {
|
||||||
|
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "GPU_BRUTE_FORCE"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -1,6 +1,10 @@
|
|||||||
package indexparamcheck
|
package indexparamcheck
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
// raftIVFChecker checks if a RAFT_IVF_Flat index can be built.
|
// raftIVFChecker checks if a RAFT_IVF_Flat index can be built.
|
||||||
type raftIVFFlatChecker struct {
|
type raftIVFFlatChecker struct {
|
||||||
@ -8,8 +12,8 @@ type raftIVFFlatChecker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckTrain checks if ivf-flat index can be built with the specific index parameters.
|
// CheckTrain checks if ivf-flat index can be built with the specific index parameters.
|
||||||
func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error {
|
func (c *raftIVFFlatChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
||||||
@ -7,6 +7,8 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,9 +86,14 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) {
|
|||||||
{p9, false},
|
{p9, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newRaftIVFFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT")
|
||||||
|
if c == nil {
|
||||||
|
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "GPU_IVF_FLAT"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -154,9 +161,13 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newRaftIVFFlatChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_FLAT")
|
||||||
|
if c == nil {
|
||||||
|
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("GPU_IVF_FLAT", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -3,6 +3,8 @@ package indexparamcheck
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built.
|
// raftIVFPQChecker checks if a RAFT_IVF_PQ index can be built.
|
||||||
@ -11,8 +13,8 @@ type raftIVFPQChecker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
// CheckTrain checks if ivf-pq index can be built with the specific index parameters.
|
||||||
func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
|
func (c *raftIVFPQChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
if !CheckStrByValues(params, Metric, RaftMetrics) {
|
||||||
@ -7,6 +7,8 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -144,9 +146,14 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
|
|||||||
{p9, false},
|
{p9, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newRaftIVFPQChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ")
|
||||||
|
if c == nil {
|
||||||
|
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "GPU_IVF_PQ"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -214,9 +221,13 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newRaftIVFPQChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("GPU_IVF_PQ")
|
||||||
|
if c == nil {
|
||||||
|
log.Error("can not get index checker instance, please enable GPU and rerun it")
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("GPU_IVF_PQ", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
11
internal/util/indexparamcheck/scalar_index_checker.go
Normal file
11
internal/util/indexparamcheck/scalar_index_checker.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
|
||||||
|
type scalarIndexChecker struct {
|
||||||
|
baseChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -4,9 +4,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCheckIndexValid(t *testing.T) {
|
func TestCheckIndexValid(t *testing.T) {
|
||||||
scalarIndexChecker := &scalarIndexChecker{}
|
scalarIndexChecker := &scalarIndexChecker{}
|
||||||
assert.NoError(t, scalarIndexChecker.CheckTrain(map[string]string{}))
|
assert.NoError(t, scalarIndexChecker.CheckTrain(schemapb.DataType_Bool, map[string]string{}))
|
||||||
}
|
}
|
||||||
@ -3,6 +3,8 @@ package indexparamcheck
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// scaNNChecker checks if a SCANN index can be built.
|
// scaNNChecker checks if a SCANN index can be built.
|
||||||
@ -11,8 +13,8 @@ type scaNNChecker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckTrain checks if SCANN index can be built with the specific index parameters.
|
// CheckTrain checks if SCANN index can be built with the specific index parameters.
|
||||||
func (c *scaNNChecker) CheckTrain(params map[string]string) error {
|
func (c *scaNNChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
|
if err := c.ivfBaseChecker.CheckTrain(dataType, params); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -87,9 +88,10 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) {
|
|||||||
{p7, false},
|
{p7, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newScaNNChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("SCANN")
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckTrain(test.params)
|
test.params[common.IndexTypeKey] = "SCANN"
|
||||||
|
err := c.CheckTrain(schemapb.DataType_FloatVector, test.params)
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -159,7 +161,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) {
|
|||||||
|
|
||||||
c := newScaNNChecker()
|
c := newScaNNChecker()
|
||||||
for _, test := range cases {
|
for _, test := range cases {
|
||||||
err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType})
|
err := c.CheckValidDataType("SCANN", &schemapb.FieldSchema{DataType: test.dType})
|
||||||
if test.errIsNil {
|
if test.errIsNil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -12,7 +12,7 @@ import (
|
|||||||
// sparse vector don't check for dim, but baseChecker does, thus not including baseChecker
|
// sparse vector don't check for dim, but baseChecker does, thus not including baseChecker
|
||||||
type sparseFloatVectorBaseChecker struct{}
|
type sparseFloatVectorBaseChecker struct{}
|
||||||
|
|
||||||
func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) error {
|
func (c sparseFloatVectorBaseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
||||||
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
|
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
|
||||||
}
|
}
|
||||||
@ -20,7 +20,7 @@ func (c sparseFloatVectorBaseChecker) StaticCheck(params map[string]string) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error {
|
func (c sparseFloatVectorBaseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
dropRatioBuildStr, exist := params[SparseDropRatioBuild]
|
dropRatioBuildStr, exist := params[SparseDropRatioBuild]
|
||||||
if exist {
|
if exist {
|
||||||
dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64)
|
dropRatioBuild, err := strconv.ParseFloat(dropRatioBuildStr, 64)
|
||||||
@ -48,14 +48,14 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c sparseFloatVectorBaseChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if !typeutil.IsSparseFloatVectorType(field.GetDataType()) {
|
if !typeutil.IsSparseFloatVectorType(field.GetDataType()) {
|
||||||
return fmt.Errorf("only sparse float vector is supported for the specified index tpye")
|
return fmt.Errorf("only sparse float vector is supported for the specified index tpye")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string, dType schemapb.DataType) {
|
func (c sparseFloatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||||
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
|
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6,84 +6,95 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
|
func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
|
||||||
validParams := map[string]string{
|
validParams := map[string]string{
|
||||||
|
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||||
Metric: "IP",
|
Metric: "IP",
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidParams := map[string]string{
|
invalidParams := map[string]string{
|
||||||
|
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||||
Metric: "L2",
|
Metric: "L2",
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newSparseFloatVectorBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
|
||||||
|
|
||||||
t.Run("valid metric", func(t *testing.T) {
|
t.Run("valid metric", func(t *testing.T) {
|
||||||
err := c.StaticCheck(validParams)
|
err := c.StaticCheck(schemapb.DataType_SparseFloatVector, validParams)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid metric", func(t *testing.T) {
|
t.Run("invalid metric", func(t *testing.T) {
|
||||||
err := c.StaticCheck(invalidParams)
|
err := c.StaticCheck(schemapb.DataType_SparseFloatVector, invalidParams)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
|
func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
|
||||||
validParams := map[string]string{
|
validParams := map[string]string{
|
||||||
|
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||||
|
Metric: "IP",
|
||||||
SparseDropRatioBuild: "0.5",
|
SparseDropRatioBuild: "0.5",
|
||||||
BM25K1: "1.5",
|
BM25K1: "1.5",
|
||||||
BM25B: "0.5",
|
BM25B: "0.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidDropRatio := map[string]string{
|
invalidDropRatio := map[string]string{
|
||||||
|
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||||
|
Metric: "IP",
|
||||||
SparseDropRatioBuild: "1.5",
|
SparseDropRatioBuild: "1.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidBM25K1 := map[string]string{
|
invalidBM25K1 := map[string]string{
|
||||||
|
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||||
|
Metric: "IP",
|
||||||
BM25K1: "3.5",
|
BM25K1: "3.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
invalidBM25B := map[string]string{
|
invalidBM25B := map[string]string{
|
||||||
|
common.IndexTypeKey: "SPARSE_INVERTED_INDEX",
|
||||||
|
Metric: "IP",
|
||||||
BM25B: "1.5",
|
BM25B: "1.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
c := newSparseFloatVectorBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
|
||||||
|
|
||||||
t.Run("valid params", func(t *testing.T) {
|
t.Run("valid params", func(t *testing.T) {
|
||||||
err := c.CheckTrain(validParams)
|
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, validParams)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid drop ratio", func(t *testing.T) {
|
t.Run("invalid drop ratio", func(t *testing.T) {
|
||||||
err := c.CheckTrain(invalidDropRatio)
|
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidDropRatio)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid BM25K1", func(t *testing.T) {
|
t.Run("invalid BM25K1", func(t *testing.T) {
|
||||||
err := c.CheckTrain(invalidBM25K1)
|
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25K1)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid BM25B", func(t *testing.T) {
|
t.Run("invalid BM25B", func(t *testing.T) {
|
||||||
err := c.CheckTrain(invalidBM25B)
|
err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25B)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
|
||||||
c := newSparseFloatVectorBaseChecker()
|
c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX")
|
||||||
|
|
||||||
t.Run("valid data type", func(t *testing.T) {
|
t.Run("valid data type", func(t *testing.T) {
|
||||||
field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector}
|
field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector}
|
||||||
err := c.CheckValidDataType(field)
|
err := c.CheckValidDataType("SPARSE_WAND", field)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid data type", func(t *testing.T) {
|
t.Run("invalid data type", func(t *testing.T) {
|
||||||
field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
|
field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
|
||||||
err := c.CheckValidDataType(field)
|
err := c.CheckValidDataType("SPARSE_WAND", field)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -12,11 +12,11 @@ type STLSORTChecker struct {
|
|||||||
scalarIndexChecker
|
scalarIndexChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STLSORTChecker) CheckTrain(params map[string]string) error {
|
func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return c.scalarIndexChecker.CheckTrain(params)
|
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if !typeutil.IsArithmetic(field.GetDataType()) {
|
if !typeutil.IsArithmetic(field.GetDataType()) {
|
||||||
return fmt.Errorf("STL_SORT are only supported on numeric field")
|
return fmt.Errorf("STL_SORT are only supported on numeric field")
|
||||||
}
|
}
|
||||||
22
internal/util/indexparamcheck/stl_sort_checker_test.go
Normal file
22
internal/util/indexparamcheck/stl_sort_checker_test.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_STLSORTIndexChecker(t *testing.T) {
|
||||||
|
c := newSTLSORTChecker()
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckTrain(schemapb.DataType_Int64, map[string]string{}))
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||||
|
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||||
|
}
|
||||||
@ -12,11 +12,11 @@ type TRIEChecker struct {
|
|||||||
scalarIndexChecker
|
scalarIndexChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TRIEChecker) CheckTrain(params map[string]string) error {
|
func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
return c.scalarIndexChecker.CheckTrain(params)
|
return c.scalarIndexChecker.CheckTrain(dataType, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error {
|
func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
if !typeutil.IsStringType(field.GetDataType()) {
|
if !typeutil.IsStringType(field.GetDataType()) {
|
||||||
return fmt.Errorf("TRIE are only supported on varchar field")
|
return fmt.Errorf("TRIE are only supported on varchar field")
|
||||||
}
|
}
|
||||||
23
internal/util/indexparamcheck/trie_checker_test.go
Normal file
23
internal/util/indexparamcheck/trie_checker_test.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_TrieIndexChecker(t *testing.T) {
|
||||||
|
c := newTRIEChecker()
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckTrain(schemapb.DataType_VarChar, map[string]string{}))
|
||||||
|
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
||||||
|
assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
||||||
|
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
||||||
|
assert.Error(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
||||||
|
}
|
||||||
@ -20,7 +20,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CheckIntByRange check if the data corresponding to the key is in the range of [min, max].
|
// CheckIntByRange check if the data corresponding to the key is in the range of [min, max].
|
||||||
@ -69,3 +72,30 @@ func setDefaultIfNotExist(params map[string]string, key string, defaultValue str
|
|||||||
params[key] = defaultValue
|
params[key] = defaultValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CheckAutoIndexHelper(key string, m map[string]string, dtype schemapb.DataType) {
|
||||||
|
indexType, ok := m[common.IndexTypeKey]
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("%s invalid, index type not found", key))
|
||||||
|
}
|
||||||
|
|
||||||
|
checker, err := GetIndexCheckerMgrInstance().GetChecker(indexType)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checker.StaticCheck(dtype, m); err != nil {
|
||||||
|
panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckAutoIndexConfig() {
|
||||||
|
autoIndexCfg := ¶mtable.Get().AutoIndexConfig
|
||||||
|
CheckAutoIndexHelper(autoIndexCfg.IndexParams.Key, autoIndexCfg.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
CheckAutoIndexHelper(autoIndexCfg.BinaryIndexParams.Key, autoIndexCfg.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||||
|
CheckAutoIndexHelper(autoIndexCfg.SparseIndexParams.Key, autoIndexCfg.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateParamTable() {
|
||||||
|
CheckAutoIndexConfig()
|
||||||
|
}
|
||||||
269
internal/util/indexparamcheck/utils_test.go
Normal file
269
internal/util/indexparamcheck/utils_test.go
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/config"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_CheckIntByRange(t *testing.T) {
|
||||||
|
params := map[string]string{
|
||||||
|
"1": strconv.Itoa(1),
|
||||||
|
"2": strconv.Itoa(2),
|
||||||
|
"3": strconv.Itoa(3),
|
||||||
|
"s1": "s1",
|
||||||
|
"s2": "s2",
|
||||||
|
"s3": "s3",
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
params map[string]string
|
||||||
|
key string
|
||||||
|
min int
|
||||||
|
max int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{params, "1", 0, 4, true},
|
||||||
|
{params, "2", 0, 4, true},
|
||||||
|
{params, "3", 0, 4, true},
|
||||||
|
{params, "1", 4, 5, false},
|
||||||
|
{params, "2", 4, 5, false},
|
||||||
|
{params, "3", 4, 5, false},
|
||||||
|
{params, "4", 0, 4, false},
|
||||||
|
{params, "5", 0, 4, false},
|
||||||
|
{params, "6", 0, 4, false},
|
||||||
|
{params, "s1", 0, 4, false},
|
||||||
|
{params, "s2", 0, 4, false},
|
||||||
|
{params, "s3", 0, 4, false},
|
||||||
|
{params, "s4", 0, 4, false},
|
||||||
|
{params, "s5", 0, 4, false},
|
||||||
|
{params, "s6", 0, 4, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range cases {
|
||||||
|
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
|
||||||
|
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_CheckStrByValues(t *testing.T) {
|
||||||
|
params := map[string]string{
|
||||||
|
"1": strconv.Itoa(1),
|
||||||
|
"2": strconv.Itoa(2),
|
||||||
|
"3": strconv.Itoa(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
params map[string]string
|
||||||
|
key string
|
||||||
|
container []string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{params, "1", []string{"1", "2", "3"}, true},
|
||||||
|
{params, "2", []string{"1", "2", "3"}, true},
|
||||||
|
{params, "3", []string{"1", "2", "3"}, true},
|
||||||
|
{params, "1", []string{"4", "5", "6"}, false},
|
||||||
|
{params, "2", []string{"4", "5", "6"}, false},
|
||||||
|
{params, "3", []string{"4", "5", "6"}, false},
|
||||||
|
{params, "1", []string{}, false},
|
||||||
|
{params, "2", []string{}, false},
|
||||||
|
{params, "3", []string{}, false},
|
||||||
|
{params, "4", []string{"1", "2", "3"}, false},
|
||||||
|
{params, "5", []string{"1", "2", "3"}, false},
|
||||||
|
{params, "6", []string{"1", "2", "3"}, false},
|
||||||
|
{params, "4", []string{"4", "5", "6"}, false},
|
||||||
|
{params, "5", []string{"4", "5", "6"}, false},
|
||||||
|
{params, "6", []string{"4", "5", "6"}, false},
|
||||||
|
{params, "4", []string{}, false},
|
||||||
|
{params, "5", []string{}, false},
|
||||||
|
{params, "6", []string{}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range cases {
|
||||||
|
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
|
||||||
|
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_CheckAutoIndex(t *testing.T) {
|
||||||
|
t.Run("index type not found", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"M": 30}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported index type", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, hnsw", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.COSINE, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, binary vector", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
BinaryIndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.binary.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.BinaryIndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.HAMMING, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, sparse vector", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
SparseIndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.sparse.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.SparseIndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.IP, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, ivf flat", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.COSINE, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, ivf flat", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.COSINE, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, diskann", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.COSINE, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, bin flat", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.HAMMING, metricType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case, bin ivf flat", func(t *testing.T) {
|
||||||
|
mgr := config.NewManager()
|
||||||
|
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`)
|
||||||
|
p := ¶mtable.AutoIndexConfig{
|
||||||
|
IndexParams: paramtable.ParamItem{
|
||||||
|
Key: "autoIndex.params.build",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p.IndexParams.Init(mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
CheckAutoIndexHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector)
|
||||||
|
})
|
||||||
|
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
||||||
|
assert.True(t, exist)
|
||||||
|
assert.Equal(t, metric.HAMMING, metricType)
|
||||||
|
})
|
||||||
|
}
|
||||||
112
internal/util/indexparamcheck/vector_index_checker.go
Normal file
112
internal/util/indexparamcheck/vector_index_checker.go
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo pkg-config: milvus_core
|
||||||
|
|
||||||
|
#include <stdlib.h> // free
|
||||||
|
#include "segcore/vector_index_c.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/indexcgopb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type vecIndexChecker struct {
|
||||||
|
baseChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleCStatus deals with the error returned from CGO
|
||||||
|
func HandleCStatus(status *C.CStatus) error {
|
||||||
|
if status.error_code == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
errorMsg := C.GoString(status.error_msg)
|
||||||
|
defer C.free(unsafe.Pointer(status.error_msg))
|
||||||
|
|
||||||
|
return fmt.Errorf("%s", errorMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
if typeutil.IsDenseFloatVectorType(dataType) {
|
||||||
|
if !CheckStrByValues(params, Metric, FloatVectorMetrics) {
|
||||||
|
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], FloatVectorMetrics)
|
||||||
|
}
|
||||||
|
} else if typeutil.IsSparseFloatVectorType(dataType) {
|
||||||
|
if !CheckStrByValues(params, Metric, SparseMetrics) {
|
||||||
|
return fmt.Errorf("metric type not found or not supported, supported: %v", SparseMetrics)
|
||||||
|
}
|
||||||
|
} else if typeutil.IsBinaryVectorType(dataType) {
|
||||||
|
if !CheckStrByValues(params, Metric, BinaryVectorMetrics) {
|
||||||
|
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], BinaryVectorMetrics)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
indexType, exist := params[common.IndexTypeKey]
|
||||||
|
|
||||||
|
if !exist {
|
||||||
|
return fmt.Errorf("no indexType is specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) {
|
||||||
|
return fmt.Errorf("indexType %s is not supported", indexType)
|
||||||
|
}
|
||||||
|
|
||||||
|
protoIndexParams := &indexcgopb.IndexParams{
|
||||||
|
Params: make([]*commonpb.KeyValuePair, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range params {
|
||||||
|
protoIndexParams.Params = append(protoIndexParams.Params, &commonpb.KeyValuePair{Key: key, Value: value})
|
||||||
|
}
|
||||||
|
|
||||||
|
indexParamsBlob, err := proto.Marshal(protoIndexParams)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal index params: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status C.CStatus
|
||||||
|
|
||||||
|
cIndexType := C.CString(indexType)
|
||||||
|
cDataType := uint32(dataType)
|
||||||
|
status = C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob)))
|
||||||
|
C.free(unsafe.Pointer(cIndexType))
|
||||||
|
|
||||||
|
return HandleCStatus(&status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
|
||||||
|
if err := c.StaticCheck(dataType, params); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.baseChecker.CheckTrain(dataType, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
|
||||||
|
if !typeutil.IsVectorType(field.GetDataType()) {
|
||||||
|
return fmt.Errorf("index %s only supports vector data type", indexType)
|
||||||
|
}
|
||||||
|
if !vecindexmgr.GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) {
|
||||||
|
return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c vecIndexChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||||
|
paramtable.SetDefaultMetricTypeIfNotExist(dType, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVecIndexChecker() IndexChecker {
|
||||||
|
return &vecIndexChecker{}
|
||||||
|
}
|
||||||
132
internal/util/indexparamcheck/vector_index_checker_test.go
Normal file
132
internal/util/indexparamcheck/vector_index_checker_test.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package indexparamcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVecIndexChecker_StaticCheck(t *testing.T) {
|
||||||
|
checker := newVecIndexChecker()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataType schemapb.DataType
|
||||||
|
params map[string]string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid IVF_FLAT index",
|
||||||
|
dataType: schemapb.DataType_FloatVector,
|
||||||
|
params: map[string]string{
|
||||||
|
"index_type": "IVF_FLAT",
|
||||||
|
"metric_type": "L2",
|
||||||
|
"nlist": "1024",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid index type",
|
||||||
|
dataType: schemapb.DataType_FloatVector,
|
||||||
|
params: map[string]string{
|
||||||
|
"index_type": "INVALID_INDEX",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing index type",
|
||||||
|
dataType: schemapb.DataType_FloatVector,
|
||||||
|
params: map[string]string{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := checker.StaticCheck(tt.dataType, tt.params)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVecIndexChecker_CheckValidDataType(t *testing.T) {
|
||||||
|
checker := newVecIndexChecker()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
indexType IndexType
|
||||||
|
field *schemapb.FieldSchema
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid float vector",
|
||||||
|
indexType: "IVF_FLAT",
|
||||||
|
field: &schemapb.FieldSchema{
|
||||||
|
DataType: schemapb.DataType_FloatVector,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid data type",
|
||||||
|
indexType: "IVF_FLAT",
|
||||||
|
field: &schemapb.FieldSchema{
|
||||||
|
DataType: schemapb.DataType_Int64,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := checker.CheckValidDataType(tt.indexType, tt.field)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVecIndexChecker_SetDefaultMetricTypeIfNotExist(t *testing.T) {
|
||||||
|
checker := newVecIndexChecker()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataType schemapb.DataType
|
||||||
|
params map[string]string
|
||||||
|
expectedType string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Float vector",
|
||||||
|
dataType: schemapb.DataType_FloatVector,
|
||||||
|
params: map[string]string{},
|
||||||
|
expectedType: FloatVectorDefaultMetricType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Binary vector",
|
||||||
|
dataType: schemapb.DataType_BinaryVector,
|
||||||
|
params: map[string]string{},
|
||||||
|
expectedType: BinaryVectorDefaultMetricType,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Existing metric type",
|
||||||
|
dataType: schemapb.DataType_FloatVector,
|
||||||
|
params: map[string]string{"metric_type": "IP"},
|
||||||
|
expectedType: "IP",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
checker.SetDefaultMetricTypeIfNotExist(tt.dataType, tt.params)
|
||||||
|
assert.Equal(t, tt.expectedType, tt.params["metric_type"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,17 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
type binFlatChecker struct {
|
|
||||||
binaryVectorBaseChecker
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c binFlatChecker) CheckTrain(params map[string]string) error {
|
|
||||||
return c.binaryVectorBaseChecker.CheckTrain(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c binFlatChecker) StaticCheck(params map[string]string) error {
|
|
||||||
return c.staticCheck(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBinFlatChecker() IndexChecker {
|
|
||||||
return &binFlatChecker{}
|
|
||||||
}
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_BitmapIndexChecker(t *testing.T) {
|
|
||||||
c := newBITMAPChecker()
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
|
||||||
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double, IsPrimaryKey: true}))
|
|
||||||
}
|
|
||||||
@ -1,37 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_HybridIndexChecker(t *testing.T) {
|
|
||||||
c := newHYBRIDChecker()
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"}))
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String}))
|
|
||||||
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double}))
|
|
||||||
assert.Error(t, c.CheckTrain(map[string]string{}))
|
|
||||||
assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"}))
|
|
||||||
assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "2000"}))
|
|
||||||
}
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_INVERTEDIndexChecker(t *testing.T) {
|
|
||||||
c := newINVERTEDChecker()
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckTrain(map[string]string{}))
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array}))
|
|
||||||
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}))
|
|
||||||
}
|
|
||||||
@ -1,26 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
type ivfBaseChecker struct {
|
|
||||||
floatVectorBaseChecker
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c ivfBaseChecker) StaticCheck(params map[string]string) error {
|
|
||||||
if !CheckIntByRange(params, NLIST, MinNList, MaxNList) {
|
|
||||||
return errOutOfRange(NLIST, MinNList, MaxNList)
|
|
||||||
}
|
|
||||||
|
|
||||||
// skip check number of rows
|
|
||||||
|
|
||||||
return c.floatVectorBaseChecker.staticCheck(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c ivfBaseChecker) CheckTrain(params map[string]string) error {
|
|
||||||
if err := c.StaticCheck(params); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.floatVectorBaseChecker.CheckTrain(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIVFBaseChecker() IndexChecker {
|
|
||||||
return &ivfBaseChecker{}
|
|
||||||
}
|
|
||||||
@ -1,9 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
type scalarIndexChecker struct {
|
|
||||||
baseChecker
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c scalarIndexChecker) CheckTrain(params map[string]string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_STLSORTIndexChecker(t *testing.T) {
|
|
||||||
c := newSTLSORTChecker()
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckTrain(map[string]string{}))
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
|
||||||
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
|
||||||
}
|
|
||||||
@ -1,23 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_TrieIndexChecker(t *testing.T) {
|
|
||||||
c := newTRIEChecker()
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckTrain(map[string]string{}))
|
|
||||||
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
|
|
||||||
assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String}))
|
|
||||||
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
|
|
||||||
assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
|
|
||||||
}
|
|
||||||
@ -1,87 +0,0 @@
|
|||||||
package indexparamcheck
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_CheckIntByRange(t *testing.T) {
|
|
||||||
params := map[string]string{
|
|
||||||
"1": strconv.Itoa(1),
|
|
||||||
"2": strconv.Itoa(2),
|
|
||||||
"3": strconv.Itoa(3),
|
|
||||||
"s1": "s1",
|
|
||||||
"s2": "s2",
|
|
||||||
"s3": "s3",
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
params map[string]string
|
|
||||||
key string
|
|
||||||
min int
|
|
||||||
max int
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{params, "1", 0, 4, true},
|
|
||||||
{params, "2", 0, 4, true},
|
|
||||||
{params, "3", 0, 4, true},
|
|
||||||
{params, "1", 4, 5, false},
|
|
||||||
{params, "2", 4, 5, false},
|
|
||||||
{params, "3", 4, 5, false},
|
|
||||||
{params, "4", 0, 4, false},
|
|
||||||
{params, "5", 0, 4, false},
|
|
||||||
{params, "6", 0, 4, false},
|
|
||||||
{params, "s1", 0, 4, false},
|
|
||||||
{params, "s2", 0, 4, false},
|
|
||||||
{params, "s3", 0, 4, false},
|
|
||||||
{params, "s4", 0, 4, false},
|
|
||||||
{params, "s5", 0, 4, false},
|
|
||||||
{params, "s6", 0, 4, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range cases {
|
|
||||||
if got := CheckIntByRange(test.params, test.key, test.min, test.max); got != test.want {
|
|
||||||
t.Errorf("CheckIntByRange(%v, %v, %v, %v) = %v", test.params, test.key, test.min, test.max, test.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_CheckStrByValues(t *testing.T) {
|
|
||||||
params := map[string]string{
|
|
||||||
"1": strconv.Itoa(1),
|
|
||||||
"2": strconv.Itoa(2),
|
|
||||||
"3": strconv.Itoa(3),
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
params map[string]string
|
|
||||||
key string
|
|
||||||
container []string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{params, "1", []string{"1", "2", "3"}, true},
|
|
||||||
{params, "2", []string{"1", "2", "3"}, true},
|
|
||||||
{params, "3", []string{"1", "2", "3"}, true},
|
|
||||||
{params, "1", []string{"4", "5", "6"}, false},
|
|
||||||
{params, "2", []string{"4", "5", "6"}, false},
|
|
||||||
{params, "3", []string{"4", "5", "6"}, false},
|
|
||||||
{params, "1", []string{}, false},
|
|
||||||
{params, "2", []string{}, false},
|
|
||||||
{params, "3", []string{}, false},
|
|
||||||
{params, "4", []string{"1", "2", "3"}, false},
|
|
||||||
{params, "5", []string{"1", "2", "3"}, false},
|
|
||||||
{params, "6", []string{"1", "2", "3"}, false},
|
|
||||||
{params, "4", []string{"4", "5", "6"}, false},
|
|
||||||
{params, "5", []string{"4", "5", "6"}, false},
|
|
||||||
{params, "6", []string{"4", "5", "6"}, false},
|
|
||||||
{params, "4", []string{}, false},
|
|
||||||
{params, "5", []string{}, false},
|
|
||||||
{params, "6", []string{}, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range cases {
|
|
||||||
if got := CheckStrByValues(test.params, test.key, test.container); got != test.want {
|
|
||||||
t.Errorf("CheckStrByValues(%v, %v, %v) = %v", test.params, test.key, test.container, test.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -24,12 +24,13 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/config"
|
"github.com/milvus-io/milvus/pkg/config"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// /////////////////////////////////////////////////////////////////////////////
|
// /////////////////////////////////////////////////////////////////////////////
|
||||||
// --- common ---
|
// --- common ---
|
||||||
type autoIndexConfig struct {
|
type AutoIndexConfig struct {
|
||||||
Enable ParamItem `refreshable:"true"`
|
Enable ParamItem `refreshable:"true"`
|
||||||
EnableOptimize ParamItem `refreshable:"true"`
|
EnableOptimize ParamItem `refreshable:"true"`
|
||||||
EnableResultLimitCheck ParamItem `refreshable:"true"`
|
EnableResultLimitCheck ParamItem `refreshable:"true"`
|
||||||
@ -60,7 +61,7 @@ const (
|
|||||||
DefaultBitmapCardinalityLimit = 100
|
DefaultBitmapCardinalityLimit = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *autoIndexConfig) init(base *BaseTable) {
|
func (p *AutoIndexConfig) init(base *BaseTable) {
|
||||||
p.Enable = ParamItem{
|
p.Enable = ParamItem{
|
||||||
Key: "autoIndex.enable",
|
Key: "autoIndex.enable",
|
||||||
Version: "2.2.0",
|
Version: "2.2.0",
|
||||||
@ -157,7 +158,7 @@ func (p *autoIndexConfig) init(base *BaseTable) {
|
|||||||
}
|
}
|
||||||
p.AutoIndexTuningConfig.Init(base.mgr)
|
p.AutoIndexTuningConfig.Init(base.mgr)
|
||||||
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricType(base.mgr)
|
p.SetDefaultMetricType(base.mgr)
|
||||||
|
|
||||||
p.ScalarAutoIndexEnable = ParamItem{
|
p.ScalarAutoIndexEnable = ParamItem{
|
||||||
Key: "scalarAutoIndex.enable",
|
Key: "scalarAutoIndex.enable",
|
||||||
@ -244,37 +245,47 @@ func (p *autoIndexConfig) init(base *BaseTable) {
|
|||||||
p.ScalarBoolIndexType.Init(base.mgr)
|
p.ScalarBoolIndexType.Init(base.mgr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricType(mgr *config.Manager) {
|
// SetDefaultMetricType The config check logic has been moved to internal package; only set defulat metric here
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
func (p *AutoIndexConfig) SetDefaultMetricType(mgr *config.Manager) {
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
p.SetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
||||||
|
p.SetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) {
|
func setDefaultIfNotExist(params map[string]string, key string, defaultValue string) {
|
||||||
|
_, exist := params[key]
|
||||||
|
if !exist {
|
||||||
|
params[key] = defaultValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
FloatVectorDefaultMetricType = metric.COSINE
|
||||||
|
SparseFloatVectorDefaultMetricType = metric.IP
|
||||||
|
BinaryVectorDefaultMetricType = metric.HAMMING
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, params map[string]string) {
|
||||||
|
if typeutil.IsDenseFloatVectorType(dType) {
|
||||||
|
setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType)
|
||||||
|
} else if typeutil.IsSparseFloatVectorType(dType) {
|
||||||
|
setDefaultIfNotExist(params, common.MetricTypeKey, SparseFloatVectorDefaultMetricType)
|
||||||
|
} else if typeutil.IsBinaryVectorType(dType) {
|
||||||
|
setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AutoIndexConfig) SetDefaultMetricTypeHelper(key string, m map[string]string, dtype schemapb.DataType, mgr *config.Manager) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
panic(fmt.Sprintf("%s invalid, should be json format", key))
|
panic(fmt.Sprintf("%s invalid, should be json format", key))
|
||||||
}
|
}
|
||||||
|
|
||||||
indexType, ok := m[common.IndexTypeKey]
|
SetDefaultMetricTypeIfNotExist(dtype, m)
|
||||||
if !ok {
|
|
||||||
panic(fmt.Sprintf("%s invalid, index type not found", key))
|
|
||||||
}
|
|
||||||
|
|
||||||
checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType))
|
|
||||||
}
|
|
||||||
|
|
||||||
checker.SetDefaultMetricTypeIfNotExist(m, dtype)
|
|
||||||
|
|
||||||
if err := checker.StaticCheck(m); err != nil {
|
|
||||||
panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
p.reset(key, m, mgr)
|
p.reset(key, m, mgr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *autoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) {
|
func (p *AutoIndexConfig) reset(key string, m map[string]string, mgr *config.Manager) {
|
||||||
ret, err := funcutil.MapToJSON(m)
|
ret, err := funcutil.MapToJSON(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("%s: convert to json failed, parameters invalid, error: %s", key, err.Error()))
|
panic(fmt.Sprintf("%s: convert to json failed, parameters invalid, error: %s", key, err.Error()))
|
||||||
|
|||||||
@ -26,7 +26,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/config"
|
"github.com/milvus-io/milvus/pkg/config"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -134,180 +133,16 @@ func Test_autoIndexConfig_panicIfNotValid(t *testing.T) {
|
|||||||
t.Run("not in json format", func(t *testing.T) {
|
t.Run("not in json format", func(t *testing.T) {
|
||||||
mgr := config.NewManager()
|
mgr := config.NewManager()
|
||||||
mgr.SetConfig("autoIndex.params.build", "not in json format")
|
mgr.SetConfig("autoIndex.params.build", "not in json format")
|
||||||
p := &autoIndexConfig{
|
p := &AutoIndexConfig{
|
||||||
IndexParams: ParamItem{
|
IndexParams: ParamItem{
|
||||||
Key: "autoIndex.params.build",
|
Key: "autoIndex.params.build",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
p.IndexParams.Init(mgr)
|
p.IndexParams.Init(mgr)
|
||||||
assert.Panics(t, func() {
|
assert.Panics(t, func() {
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
p.SetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("index type not found", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"M": 30}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.Panics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("unsupported index type", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.Panics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, hnsw", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, binary vector", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.binary.build", `{"nlist": 1024, "index_type": "BIN_IVF_FLAT"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
BinaryIndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.binary.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.BinaryIndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.BinaryIndexParams.Key, p.BinaryIndexParams.GetAsJSONMap(), schemapb.DataType_BinaryVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.BinaryIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, sparse vector", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.sparse.build", `{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
SparseIndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.sparse.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.SparseIndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.SparseIndexParams.Key, p.SparseIndexParams.GetAsJSONMap(), schemapb.DataType_SparseFloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.SparseIndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.SparseFloatVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, ivf flat", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, ivf flat", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, diskann", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, bin flat", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal case, bin ivf flat", func(t *testing.T) {
|
|
||||||
mgr := config.NewManager()
|
|
||||||
mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`)
|
|
||||||
p := &autoIndexConfig{
|
|
||||||
IndexParams: ParamItem{
|
|
||||||
Key: "autoIndex.params.build",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.IndexParams.Init(mgr)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
p.panicIfNotValidAndSetDefaultMetricTypeHelper(p.IndexParams.Key, p.IndexParams.GetAsJSONMap(), schemapb.DataType_FloatVector, mgr)
|
|
||||||
})
|
|
||||||
metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey]
|
|
||||||
assert.True(t, exist)
|
|
||||||
assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScalarAutoIndexParams_build(t *testing.T) {
|
func TestScalarAutoIndexParams_build(t *testing.T) {
|
||||||
|
|||||||
@ -65,7 +65,7 @@ type ComponentParam struct {
|
|||||||
|
|
||||||
CommonCfg commonConfig
|
CommonCfg commonConfig
|
||||||
QuotaConfig quotaConfig
|
QuotaConfig quotaConfig
|
||||||
AutoIndexConfig autoIndexConfig
|
AutoIndexConfig AutoIndexConfig
|
||||||
GpuConfig gpuConfig
|
GpuConfig gpuConfig
|
||||||
TraceCfg traceConfig
|
TraceCfg traceConfig
|
||||||
|
|
||||||
|
|||||||
@ -607,7 +607,7 @@ func TestCreateIndexJsonField(t *testing.T) {
|
|||||||
// create vector index on json field
|
// create vector index on json field
|
||||||
idx := index.NewSCANNIndex(entity.L2, 8, false)
|
idx := index.NewSCANNIndex(entity.L2, 8, false)
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultJSONFieldName, idx).WithIndexName("json_index"))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultJSONFieldName, idx).WithIndexName("json_index"))
|
||||||
common.CheckErr(t, err, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
|
common.CheckErr(t, err, false, "index SCANN only supports vector data type")
|
||||||
|
|
||||||
// create scalar index on json field
|
// create scalar index on json field
|
||||||
type scalarIndexError struct {
|
type scalarIndexError struct {
|
||||||
@ -653,7 +653,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) {
|
|||||||
if field.DataType == entity.FieldTypeArray {
|
if field.DataType == entity.FieldTypeArray {
|
||||||
// create vector index
|
// create vector index
|
||||||
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index"))
|
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index"))
|
||||||
common.CheckErr(t, err1, false, "data type should be FloatVector, Float16Vector or BFloat16Vector")
|
common.CheckErr(t, err1, false, "index SCANN only supports vector data type")
|
||||||
|
|
||||||
// create scalar index
|
// create scalar index
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx))
|
||||||
@ -840,11 +840,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
|
|||||||
for _, drb := range []float64{-0.3, 1.3} {
|
for _, drb := range []float64{-0.3, 1.3} {
|
||||||
idxInverted := index.NewSparseInvertedIndex(entity.IP, drb)
|
idxInverted := index.NewSparseInvertedIndex(entity.IP, drb)
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
|
||||||
common.CheckErr(t, err, false, "must be in range [0, 1)")
|
common.CheckErr(t, err, false, "Out of range in json: param 'drop_ratio_build'")
|
||||||
|
|
||||||
idxWand := index.NewSparseWANDIndex(entity.IP, drb)
|
idxWand := index.NewSparseWANDIndex(entity.IP, drb)
|
||||||
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
|
_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
|
||||||
common.CheckErr(t, err1, false, "must be in range [0, 1)")
|
common.CheckErr(t, err1, false, "Out of range in json: param 'drop_ratio_build'")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -944,20 +944,22 @@ func TestCreateVectorIndexScalarField(t *testing.T) {
|
|||||||
// create float vector index on scalar field
|
// create float vector index on scalar field
|
||||||
for _, idx := range hp.GenAllFloatIndex(entity.COSINE) {
|
for _, idx := range hp.GenAllFloatIndex(entity.COSINE) {
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx))
|
||||||
common.CheckErr(t, err, false, "can't build hnsw in not vector type",
|
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idx.IndexType())
|
||||||
"data type should be FloatVector, Float16Vector or BFloat16Vector")
|
common.CheckErr(t, err, false, expErrorMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create binary vector index on scalar field
|
// create binary vector index on scalar field
|
||||||
for _, idxBinary := range []index.Index{index.NewBinFlatIndex(entity.IP), index.NewBinIvfFlatIndex(entity.COSINE, 64)} {
|
for _, idxBinary := range []index.Index{index.NewBinFlatIndex(entity.IP), index.NewBinIvfFlatIndex(entity.COSINE, 64)} {
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxBinary))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxBinary))
|
||||||
common.CheckErr(t, err, false, "binary vector is only supported")
|
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxBinary.IndexType())
|
||||||
|
common.CheckErr(t, err, false, expErrorMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create sparse vector index on scalar field
|
// create sparse vector index on scalar field
|
||||||
for _, idxSparse := range []index.Index{index.NewSparseInvertedIndex(entity.IP, 0.2), index.NewSparseWANDIndex(entity.IP, 0.3)} {
|
for _, idxSparse := range []index.Index{index.NewSparseInvertedIndex(entity.IP, 0.2), index.NewSparseWANDIndex(entity.IP, 0.3)} {
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxSparse))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxSparse))
|
||||||
common.CheckErr(t, err, false, "only sparse float vector is supported for the specified index")
|
expErrorMsg := fmt.Sprintf("index %s only supports vector data type", idxSparse.IndexType())
|
||||||
|
common.CheckErr(t, err, false, expErrorMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -972,7 +974,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
|
|||||||
_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
|
_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
|
||||||
|
|
||||||
// invalid IvfFlat nlist [1, 65536]
|
// invalid IvfFlat nlist [1, 65536]
|
||||||
errMsg := "nlist out of range: [1, 65536]"
|
errMsg := "Out of range in json: param 'nlist'"
|
||||||
for _, invalidNlist := range []int{0, -1, 65536 + 1} {
|
for _, invalidNlist := range []int{0, -1, 65536 + 1} {
|
||||||
// IvfFlat
|
// IvfFlat
|
||||||
idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist)
|
idxIvfFlat := index.NewIvfFlatIndex(entity.L2, invalidNlist)
|
||||||
@ -997,7 +999,7 @@ func TestCreateIndexInvalidParams(t *testing.T) {
|
|||||||
// IvfFlat
|
// IvfFlat
|
||||||
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits)
|
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 8, invalidNBits)
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxIvfPq))
|
||||||
common.CheckErr(t, err, false, "parameter `nbits` out of range, expect range [1,64]")
|
common.CheckErr(t, err, false, "Out of range in json: param 'nbits'")
|
||||||
}
|
}
|
||||||
|
|
||||||
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8)
|
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 7, 8)
|
||||||
@ -1009,13 +1011,13 @@ func TestCreateIndexInvalidParams(t *testing.T) {
|
|||||||
// IvfFlat
|
// IvfFlat
|
||||||
idxHnsw := index.NewHNSWIndex(entity.L2, invalidM, 96)
|
idxHnsw := index.NewHNSWIndex(entity.L2, invalidM, 96)
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
|
||||||
common.CheckErr(t, err, false, "M out of range: [1, 2048]")
|
common.CheckErr(t, err, false, "Out of range in json: param 'M'")
|
||||||
}
|
}
|
||||||
for _, invalidEfConstruction := range []int{0, 2147483647 + 1} {
|
for _, invalidEfConstruction := range []int{0, 2147483647 + 1} {
|
||||||
// IvfFlat
|
// IvfFlat
|
||||||
idxHnsw := index.NewHNSWIndex(entity.L2, 8, invalidEfConstruction)
|
idxHnsw := index.NewHNSWIndex(entity.L2, 8, invalidEfConstruction)
|
||||||
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
|
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultFloatVecFieldName, idxHnsw))
|
||||||
common.CheckErr(t, err, false, "efConstruction out of range: [1, 2147483647]")
|
common.CheckErr(t, err, false, "Out of range in json: param 'efConstruction'", "integer value out of range, key: 'efConstruction'")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -33,10 +33,10 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/internal/util/importutilv2"
|
"github.com/milvus-io/milvus/internal/util/importutilv2"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/tests/integration"
|
"github.com/milvus-io/milvus/tests/integration"
|
||||||
@ -66,7 +66,7 @@ func (s *BulkInsertSuite) SetupTest() {
|
|||||||
s.autoID = false
|
s.autoID = false
|
||||||
|
|
||||||
s.vecType = schemapb.DataType_FloatVector
|
s.vecType = schemapb.DataType_FloatVector
|
||||||
s.indexType = indexparamcheck.IndexHNSW
|
s.indexType = "HNSW"
|
||||||
s.metricType = metric.L2
|
s.metricType = metric.L2
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,29 +225,29 @@ func (s *BulkInsertSuite) TestMultiFileTypes() {
|
|||||||
s.fileType = fileType
|
s.fileType = fileType
|
||||||
|
|
||||||
s.vecType = schemapb.DataType_BinaryVector
|
s.vecType = schemapb.DataType_BinaryVector
|
||||||
s.indexType = indexparamcheck.IndexFaissBinIvfFlat
|
s.indexType = "BIN_IVF_FLAT"
|
||||||
s.metricType = metric.HAMMING
|
s.metricType = metric.HAMMING
|
||||||
s.run()
|
s.run()
|
||||||
|
|
||||||
s.vecType = schemapb.DataType_FloatVector
|
s.vecType = schemapb.DataType_FloatVector
|
||||||
s.indexType = indexparamcheck.IndexHNSW
|
s.indexType = "HNSW"
|
||||||
s.metricType = metric.L2
|
s.metricType = metric.L2
|
||||||
s.run()
|
s.run()
|
||||||
|
|
||||||
s.vecType = schemapb.DataType_Float16Vector
|
s.vecType = schemapb.DataType_Float16Vector
|
||||||
s.indexType = indexparamcheck.IndexHNSW
|
s.indexType = "HNSW"
|
||||||
s.metricType = metric.L2
|
s.metricType = metric.L2
|
||||||
s.run()
|
s.run()
|
||||||
|
|
||||||
s.vecType = schemapb.DataType_BFloat16Vector
|
s.vecType = schemapb.DataType_BFloat16Vector
|
||||||
s.indexType = indexparamcheck.IndexHNSW
|
s.indexType = "HNSW"
|
||||||
s.metricType = metric.L2
|
s.metricType = metric.L2
|
||||||
s.run()
|
s.run()
|
||||||
|
|
||||||
// TODO: not support numpy for SparseFloatVector by now
|
// TODO: not support numpy for SparseFloatVector by now
|
||||||
if fileType != importutilv2.Numpy {
|
if fileType != importutilv2.Numpy {
|
||||||
s.vecType = schemapb.DataType_SparseFloatVector
|
s.vecType = schemapb.DataType_SparseFloatVector
|
||||||
s.indexType = indexparamcheck.IndexSparseWand
|
s.indexType = "SPARSE_WAND"
|
||||||
s.metricType = metric.IP
|
s.metricType = metric.IP
|
||||||
s.run()
|
s.run()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,23 +26,22 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
IndexRaftIvfFlat = indexparamcheck.IndexRaftIvfFlat
|
IndexRaftIvfFlat = "GPU_IVF_FLAT"
|
||||||
IndexRaftIvfPQ = indexparamcheck.IndexRaftIvfPQ
|
IndexRaftIvfPQ = "GPU_IVF_PQ"
|
||||||
IndexFaissIDMap = indexparamcheck.IndexFaissIDMap
|
IndexFaissIDMap = "FLAT"
|
||||||
IndexFaissIvfFlat = indexparamcheck.IndexFaissIvfFlat
|
IndexFaissIvfFlat = "IVF_FLAT"
|
||||||
IndexFaissIvfPQ = indexparamcheck.IndexFaissIvfPQ
|
IndexFaissIvfPQ = "IVF_PQ"
|
||||||
IndexScaNN = indexparamcheck.IndexScaNN
|
IndexScaNN = "SCANN"
|
||||||
IndexFaissIvfSQ8 = indexparamcheck.IndexFaissIvfSQ8
|
IndexFaissIvfSQ8 = "IVF_SQ8"
|
||||||
IndexFaissBinIDMap = indexparamcheck.IndexFaissBinIDMap
|
IndexFaissBinIDMap = "BIN_FLAT"
|
||||||
IndexFaissBinIvfFlat = indexparamcheck.IndexFaissBinIvfFlat
|
IndexFaissBinIvfFlat = "BIN_IVF_FLAT"
|
||||||
IndexHNSW = indexparamcheck.IndexHNSW
|
IndexHNSW = "HNSW"
|
||||||
IndexDISKANN = indexparamcheck.IndexDISKANN
|
IndexDISKANN = "DISKANN"
|
||||||
IndexSparseInvertedIndex = indexparamcheck.IndexSparseInverted
|
IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX"
|
||||||
IndexSparseWand = indexparamcheck.IndexSparseWand
|
IndexSparseWand = "SPARSE_WAND"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
|
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
|
||||||
|
|||||||
@ -2103,6 +2103,8 @@ def gen_simple_index():
|
|||||||
continue
|
continue
|
||||||
elif ct.all_index_types[i] in ct.sparse_support:
|
elif ct.all_index_types[i] in ct.sparse_support:
|
||||||
continue
|
continue
|
||||||
|
elif ct.all_index_types[i] in ct.gpu_support:
|
||||||
|
continue
|
||||||
dic = {"index_type": ct.all_index_types[i], "metric_type": "L2"}
|
dic = {"index_type": ct.all_index_types[i], "metric_type": "L2"}
|
||||||
dic.update({"params": ct.default_all_indexes_params[i]})
|
dic.update({"params": ct.default_all_indexes_params[i]})
|
||||||
index_params.append(dic)
|
index_params.append(dic)
|
||||||
|
|||||||
@ -244,6 +244,7 @@ default_all_search_params_params = [{}, {"nprobe": 32}, {"nprobe": 32}, {"nprobe
|
|||||||
Handler_type = ["GRPC", "HTTP"]
|
Handler_type = ["GRPC", "HTTP"]
|
||||||
binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"]
|
binary_support = ["BIN_FLAT", "BIN_IVF_FLAT"]
|
||||||
sparse_support = ["SPARSE_INVERTED_INDEX", "SPARSE_WAND"]
|
sparse_support = ["SPARSE_INVERTED_INDEX", "SPARSE_WAND"]
|
||||||
|
gpu_support = ["GPU_IVF_FLAT", "GPU_IVF_PQ"]
|
||||||
default_L0_metric = "COSINE"
|
default_L0_metric = "COSINE"
|
||||||
float_metrics = ["L2", "IP", "COSINE"]
|
float_metrics = ["L2", "IP", "COSINE"]
|
||||||
binary_metrics = ["JACCARD", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
|
binary_metrics = ["JACCARD", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"]
|
||||||
|
|||||||
@ -57,6 +57,8 @@ default_index_params = [
|
|||||||
def create_target_index(index, field_name):
|
def create_target_index(index, field_name):
|
||||||
index["field_name"] = field_name
|
index["field_name"] = field_name
|
||||||
|
|
||||||
|
def gpu_support():
|
||||||
|
return ["GPU_IVF_FLAT", "GPU_IVF_PQ"]
|
||||||
|
|
||||||
def binary_support():
|
def binary_support():
|
||||||
return ["BIN_FLAT", "BIN_IVF_FLAT"]
|
return ["BIN_FLAT", "BIN_IVF_FLAT"]
|
||||||
@ -764,6 +766,8 @@ def gen_simple_index():
|
|||||||
for i in range(len(all_index_types)):
|
for i in range(len(all_index_types)):
|
||||||
if all_index_types[i] in binary_support():
|
if all_index_types[i] in binary_support():
|
||||||
continue
|
continue
|
||||||
|
if all_index_types[i] in gpu_support():
|
||||||
|
continue
|
||||||
dic = {"index_type": all_index_types[i], "metric_type": "L2"}
|
dic = {"index_type": all_index_types[i], "metric_type": "L2"}
|
||||||
dic.update({"params": default_index_params[i]})
|
dic.update({"params": default_index_params[i]})
|
||||||
index_params.append(dic)
|
index_params.append(dic)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user