From 4ed11d9775832f51cff3ff56aa57452fc2cd1423 Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Wed, 25 Nov 2020 16:24:57 +0800 Subject: [PATCH] Reopen segment test assertion Signed-off-by: bigsheeper --- configs/advanced/proxy.yaml | 4 +- internal/core/cmake/FindClangTools.cmake | 22 ---- .../StructuredIndexFlat-inl.h | 3 +- .../StructuredIndexSort-inl.h | 3 +- .../src/query/visitors/ExecExprVisitor.cpp | 8 +- internal/proxy/grpc_service.go | 4 +- internal/proxy/paramtable.go | 12 ++ internal/proxy/task.go | 86 +++++++++++++ internal/proxy/validate_util.go | 118 ++++++++++++++++++ internal/proxy/validate_util_test.go | 84 +++++++++++++ internal/reader/segment_test.go | 4 +- scripts/core_build.sh | 2 +- 12 files changed, 312 insertions(+), 38 deletions(-) create mode 100644 internal/proxy/validate_util.go create mode 100644 internal/proxy/validate_util_test.go diff --git a/configs/advanced/proxy.yaml b/configs/advanced/proxy.yaml index 71cb3006c2..cc98ad4d85 100644 --- a/configs/advanced/proxy.yaml +++ b/configs/advanced/proxy.yaml @@ -25,4 +25,6 @@ proxy: pulsarBufSize: 1024 # pulsar chan buffer size timeTick: - bufSize: 512 \ No newline at end of file + bufSize: 512 + + maxNameLength: 255 diff --git a/internal/core/cmake/FindClangTools.cmake b/internal/core/cmake/FindClangTools.cmake index 64a33764d8..6541010075 100644 --- a/internal/core/cmake/FindClangTools.cmake +++ b/internal/core/cmake/FindClangTools.cmake @@ -33,17 +33,6 @@ find_program(CLANG_TIDY_BIN NAMES clang-tidy-10 - clang-tidy-9 - clang-tidy-8 - clang-tidy-7.0 - clang-tidy-7 - clang-tidy-6.0 - clang-tidy-5.0 - clang-tidy-4.0 - clang-tidy-3.9 - clang-tidy-3.8 - clang-tidy-3.7 - clang-tidy-3.6 clang-tidy PATHS ${ClangTools_PATH} $ENV{CLANG_TOOLS_PATH} /usr/local/bin /usr/bin NO_DEFAULT_PATH @@ -90,17 +79,6 @@ else() find_program(CLANG_FORMAT_BIN NAMES clang-format-10 - clang-format-9 - clang-format-8 - clang-format-7.0 - clang-format-7 - clang-format-6.0 - clang-format-5.0 - clang-format-4.0 - clang-format-3.9 - clang-format-3.8 - clang-format-3.7 - clang-format-3.6 clang-format PATHS ${ClangTools_PATH} $ENV{CLANG_TOOLS_PATH} /usr/local/bin /usr/bin NO_DEFAULT_PATH diff --git a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h index bee0bbdf98..761fbb8933 100644 --- a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h +++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexFlat-inl.h @@ -65,7 +65,8 @@ StructuredIndexFlat::NotIn(const size_t n, const T* values) { if (!is_built_) { build(); } - TargetBitmapPtr bitset = std::make_unique(data_.size(), true); + TargetBitmapPtr bitset = std::make_unique(data_.size()); + bitset->set(); for (size_t i = 0; i < n; ++i) { for (const auto& index : data_) { if (index->a_ == *(values + i)) { diff --git a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h index 9f7647a40b..0ac7a2b764 100644 --- a/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h +++ b/internal/core/src/index/knowhere/knowhere/index/structured_index_simple/StructuredIndexSort-inl.h @@ -120,7 +120,8 @@ StructuredIndexSort::NotIn(const size_t n, const T* values) { if (!is_built_) { build(); } - TargetBitmapPtr bitset = std::make_unique(data_.size(), true); + TargetBitmapPtr bitset = std::make_unique(data_.size()); + bitset->set(); for (size_t i = 0; i < n; ++i) { auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure(*(values + i))); auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure(*(values + i))); diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index f8d1a0ef0f..04bcd51fb3 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -130,13 +130,7 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType { } case OpType::NotEqual: { - auto index_func = [val](Index* index) { - // Note: index->NotIn() is buggy, investigating - // this is a workaround - auto res = index->In(1, &val); - *res = ~std::move(*res); - return res; - }; + auto index_func = [val](Index* index) { return index->NotIn(1, &val); }; return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); }); } diff --git a/internal/proxy/grpc_service.go b/internal/proxy/grpc_service.go index ba0f92dc39..dbccf5ac44 100644 --- a/internal/proxy/grpc_service.go +++ b/internal/proxy/grpc_service.go @@ -82,9 +82,8 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc Schema: &commonpb.Blob{}, }, masterClient: p.masterClient, + schema: req, } - schemaBytes, _ := proto.Marshal(req) - cct.CreateCollectionRequest.Schema.Value = schemaBytes var cancel func() cct.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval) defer cancel() @@ -125,6 +124,7 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu }, queryMsgStream: p.queryMsgStream, resultBuf: make(chan []*internalpb.SearchResult), + query: req, } var cancel func() qt.ctx, cancel = context.WithTimeout(ctx, reqTimeoutInterval) diff --git a/internal/proxy/paramtable.go b/internal/proxy/paramtable.go index 9010f78ef6..ba60abe634 100644 --- a/internal/proxy/paramtable.go +++ b/internal/proxy/paramtable.go @@ -430,3 +430,15 @@ func (pt *ParamTable) searchResultChannelNames() []string { } return channels } + +func (pt *ParamTable) MaxNameLength() int64 { + str, err := pt.Load("proxy.maxNameLength") + if err != nil { + panic(err) + } + maxNameLength, err := strconv.ParseInt(str, 10, 64) + if err != nil { + panic(err) + } + return maxNameLength +} diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 538fd05326..e2f83e50b9 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -5,11 +5,13 @@ import ( "errors" "log" + "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/allocator" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "github.com/zilliztech/milvus-distributed/internal/proto/masterpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) @@ -64,6 +66,15 @@ func (it *InsertTask) Type() internalpb.MsgType { } func (it *InsertTask) PreExecute() error { + collectionName := it.BaseInsertTask.CollectionName + if err := ValidateCollectionName(collectionName); err != nil { + return err + } + partitionTag := it.BaseInsertTask.PartitionTag + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } + return nil } @@ -120,6 +131,7 @@ type CreateCollectionTask struct { masterClient masterpb.MasterClient result *commonpb.Status ctx context.Context + schema *schemapb.CollectionSchema } func (cct *CreateCollectionTask) ID() UniqueID { @@ -147,10 +159,24 @@ func (cct *CreateCollectionTask) SetTs(ts Timestamp) { } func (cct *CreateCollectionTask) PreExecute() error { + // validate collection name + if err := ValidateCollectionName(cct.schema.Name); err != nil { + return err + } + + // validate field name + for _, field := range cct.schema.Fields { + if err := ValidateFieldName(field.Name); err != nil { + return err + } + } + return nil } func (cct *CreateCollectionTask) Execute() error { + schemaBytes, _ := proto.Marshal(cct.schema) + cct.CreateCollectionRequest.Schema.Value = schemaBytes resp, err := cct.masterClient.CreateCollection(cct.ctx, &cct.CreateCollectionRequest) if err != nil { log.Printf("create collection failed, error= %v", err) @@ -201,6 +227,9 @@ func (dct *DropCollectionTask) SetTs(ts Timestamp) { } func (dct *DropCollectionTask) PreExecute() error { + if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil { + return err + } return nil } @@ -229,6 +258,7 @@ type QueryTask struct { resultBuf chan []*internalpb.SearchResult result *servicepb.QueryResult ctx context.Context + query *servicepb.Query } func (qt *QueryTask) ID() UniqueID { @@ -256,6 +286,15 @@ func (qt *QueryTask) SetTs(ts Timestamp) { } func (qt *QueryTask) PreExecute() error { + if err := ValidateCollectionName(qt.query.CollectionName); err != nil { + return err + } + + for _, tag := range qt.query.PartitionTags { + if err := ValidatePartitionTag(tag, false); err != nil { + return err + } + } return nil } @@ -367,6 +406,9 @@ func (hct *HasCollectionTask) SetTs(ts Timestamp) { } func (hct *HasCollectionTask) PreExecute() error { + if err := ValidateCollectionName(hct.CollectionName.CollectionName); err != nil { + return err + } return nil } @@ -424,6 +466,9 @@ func (dct *DescribeCollectionTask) SetTs(ts Timestamp) { } func (dct *DescribeCollectionTask) PreExecute() error { + if err := ValidateCollectionName(dct.CollectionName.CollectionName); err != nil { + return err + } return nil } @@ -532,6 +577,16 @@ func (cpt *CreatePartitionTask) SetTs(ts Timestamp) { } func (cpt *CreatePartitionTask) PreExecute() error { + collName, partitionTag := cpt.PartitionName.CollectionName, cpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } + return nil } @@ -577,6 +632,16 @@ func (dpt *DropPartitionTask) SetTs(ts Timestamp) { } func (dpt *DropPartitionTask) PreExecute() error { + collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } + return nil } @@ -622,6 +687,15 @@ func (hpt *HasPartitionTask) SetTs(ts Timestamp) { } func (hpt *HasPartitionTask) PreExecute() error { + collName, partitionTag := hpt.PartitionName.CollectionName, hpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } return nil } @@ -667,6 +741,15 @@ func (dpt *DescribePartitionTask) SetTs(ts Timestamp) { } func (dpt *DescribePartitionTask) PreExecute() error { + collName, partitionTag := dpt.PartitionName.CollectionName, dpt.PartitionName.Tag + + if err := ValidateCollectionName(collName); err != nil { + return err + } + + if err := ValidatePartitionTag(partitionTag, true); err != nil { + return err + } return nil } @@ -712,6 +795,9 @@ func (spt *ShowPartitionsTask) SetTs(ts Timestamp) { } func (spt *ShowPartitionsTask) PreExecute() error { + if err := ValidateCollectionName(spt.CollectionName.CollectionName); err != nil { + return err + } return nil } diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go new file mode 100644 index 0000000000..8049595f28 --- /dev/null +++ b/internal/proxy/validate_util.go @@ -0,0 +1,118 @@ +package proxy + +import ( + "strconv" + "strings" + + "github.com/zilliztech/milvus-distributed/internal/errors" +) + +func isAlpha(c uint8) bool { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') { + return false + } + return true +} + +func isNumber(c uint8) bool { + if c < '0' || c > '9' { + return false + } + return true +} + +func ValidateCollectionName(collName string) error { + collName = strings.TrimSpace(collName) + + if collName == "" { + return errors.New("Collection name should not be empty") + } + + invalidMsg := "Invalid collection name: " + collName + ". " + if int64(len(collName)) > Params.MaxNameLength() { + msg := invalidMsg + "The length of a collection name must be less than " + + strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + return errors.New(msg) + } + + firstChar := collName[0] + if firstChar != '_' && !isAlpha(firstChar) { + msg := invalidMsg + "The first character of a collection name must be an underscore or letter." + return errors.New(msg) + } + + for i := 1; i < len(collName); i++ { + c := collName[i] + if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) { + msg := invalidMsg + "Collection name can only contain numbers, letters, dollars and underscores." + return errors.New(msg) + } + } + return nil +} + +func ValidatePartitionTag(partitionTag string, strictCheck bool) error { + partitionTag = strings.TrimSpace(partitionTag) + + invalidMsg := "Invalid partition tag: " + partitionTag + ". " + if partitionTag == "" { + msg := invalidMsg + "Partition tag should not be empty." + return errors.New(msg) + } + + if int64(len(partitionTag)) > Params.MaxNameLength() { + msg := invalidMsg + "The length of a partition tag must be less than " + + strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + return errors.New(msg) + } + + if strictCheck { + firstChar := partitionTag[0] + if firstChar != '_' && !isAlpha(firstChar) { + msg := invalidMsg + "The first character of a partition tag must be an underscore or letter." + return errors.New(msg) + } + + tagSize := len(partitionTag) + for i := 1; i < tagSize; i++ { + c := partitionTag[i] + if c != '_' && c != '$' && !isAlpha(c) && !isNumber(c) { + msg := invalidMsg + "Partition tag can only contain numbers, letters, dollars and underscores." + return errors.New(msg) + } + } + } + + return nil +} + +func ValidateFieldName(fieldName string) error { + fieldName = strings.TrimSpace(fieldName) + + if fieldName == "" { + return errors.New("Field name should not be empty") + } + + invalidMsg := "Invalid field name: " + fieldName + ". " + if int64(len(fieldName)) > Params.MaxNameLength() { + msg := invalidMsg + "The length of a field name must be less than " + + strconv.FormatInt(Params.MaxNameLength(), 10) + " characters." + return errors.New(msg) + } + + firstChar := fieldName[0] + if firstChar != '_' && !isAlpha(firstChar) { + msg := invalidMsg + "The first character of a field name must be an underscore or letter." + return errors.New(msg) + } + + fieldNameSize := len(fieldName) + for i := 1; i < fieldNameSize; i++ { + c := fieldName[i] + if c != '_' && !isAlpha(c) && !isNumber(c) { + msg := invalidMsg + "Field name cannot only contain numbers, letters, and underscores." + return errors.New(msg) + } + } + return nil +} diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go new file mode 100644 index 0000000000..336425f9de --- /dev/null +++ b/internal/proxy/validate_util_test.go @@ -0,0 +1,84 @@ +package proxy + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateCollectionName(t *testing.T) { + Params.Init() + assert.Nil(t, ValidateCollectionName("abc")) + assert.Nil(t, ValidateCollectionName("_123abc")) + assert.Nil(t, ValidateCollectionName("abc123_$")) + + longName := make([]byte, 256) + for i := 0; i < len(longName); i++ { + longName[i] = 'a' + } + invalidNames := []string{ + "123abc", + "$abc", + "_12 ac", + " ", + "", + string(longName), + "中文", + } + + for _, name := range invalidNames { + assert.NotNil(t, ValidateCollectionName(name)) + } +} + +func TestValidatePartitionTag(t *testing.T) { + Params.Init() + assert.Nil(t, ValidatePartitionTag("abc", true)) + assert.Nil(t, ValidatePartitionTag("_123abc", true)) + assert.Nil(t, ValidatePartitionTag("abc123_$", true)) + + longName := make([]byte, 256) + for i := 0; i < len(longName); i++ { + longName[i] = 'a' + } + invalidNames := []string{ + "123abc", + "$abc", + "_12 ac", + " ", + "", + string(longName), + "中文", + } + + for _, name := range invalidNames { + assert.NotNil(t, ValidatePartitionTag(name, true)) + } + + assert.Nil(t, ValidatePartitionTag("ab cd", false)) + assert.Nil(t, ValidatePartitionTag("ab*", false)) +} + +func TestValidateFieldName(t *testing.T) { + Params.Init() + assert.Nil(t, ValidateFieldName("abc")) + assert.Nil(t, ValidateFieldName("_123abc")) + + longName := make([]byte, 256) + for i := 0; i < len(longName); i++ { + longName[i] = 'a' + } + invalidNames := []string{ + "123abc", + "$abc", + "_12 ac", + " ", + "", + string(longName), + "中文", + } + + for _, name := range invalidNames { + assert.NotNil(t, ValidateFieldName(name)) + } +} diff --git a/internal/reader/segment_test.go b/internal/reader/segment_test.go index aab0d9f78c..b970c18e30 100644 --- a/internal/reader/segment_test.go +++ b/internal/reader/segment_test.go @@ -3,7 +3,6 @@ package reader import ( "context" "encoding/binary" - "fmt" "log" "math" "testing" @@ -463,8 +462,7 @@ func TestSegment_segmentInsert(t *testing.T) { assert.GreaterOrEqual(t, offset, int64(0)) err := segment.segmentInsert(offset, &ids, ×tamps, &records) - //assert.NoError(t, err) - fmt.Println(err) + assert.NoError(t, err) deleteSegment(segment) deleteCollection(collection) } diff --git a/scripts/core_build.sh b/scripts/core_build.sh index 7edd18ed2e..8c9f730b31 100755 --- a/scripts/core_build.sh +++ b/scripts/core_build.sh @@ -138,7 +138,7 @@ ${CMAKE_CMD} if [[ ${RUN_CPPLINT} == "ON" ]]; then # cpplint check - make lint || true + make lint if [ $? -ne 0 ]; then echo "ERROR! cpplint check failed" exit 1