From 70710dee4724aa6446cf96bcdc0b7b01d124a30c Mon Sep 17 00:00:00 2001 From: neza2017 Date: Sat, 5 Dec 2020 16:11:03 +0800 Subject: [PATCH] Add parquet payload Signed-off-by: neza2017 --- .gitignore | 4 - .jenkins/modules/Build/Build.groovy | 4 +- .../Regression/PythonRegression.groovy | 5 +- Makefile | 8 +- build/docker/test/.env | 4 - cmd/proxy/proxy.go | 4 +- cmd/querynode/query_node.go | 18 +- configs/advanced/channel.yaml | 2 +- configs/config.yaml | 118 ++ configs/milvus.yaml | 9 +- docs/developer_guides/chap08_binlog.md | 223 --- internal/conf/conf.go | 150 ++ internal/conf/conf_test.go | 10 + internal/core/src/common/CMakeLists.txt | 13 +- internal/core/src/common/FieldMeta.h | 40 +- internal/core/src/common/Schema.cpp | 45 +- internal/core/src/common/Schema.h | 36 +- internal/core/src/common/Types.cpp | 45 - internal/core/src/common/Types.h | 40 - internal/core/src/query/BruteForceSearch.cpp | 81 +- internal/core/src/query/BruteForceSearch.h | 20 +- internal/core/src/query/Plan.cpp | 91 +- internal/core/src/query/PlanImpl.h | 4 + internal/core/src/query/Search.cpp | 92 +- internal/core/src/query/Search.h | 11 +- .../core/src/query/deprecated/GeneralQuery.h | 2 +- .../src/query/generated/ExecExprVisitor.h | 4 - .../src/query/generated/ExecPlanNodeVisitor.h | 2 +- .../src/query/visitors/ExecExprVisitor.cpp | 110 +- .../query/visitors/ExecPlanNodeVisitor.cpp | 19 +- .../query/visitors/ShowPlanNodeVisitor.cpp | 19 +- internal/core/src/segcore/Collection.cpp | 3 +- internal/core/src/segcore/ConcurrentVector.h | 12 +- internal/core/src/segcore/IndexingEntry.cpp | 6 +- internal/core/src/segcore/SegmentBase.h | 2 +- internal/core/src/segcore/SegmentNaive.cpp | 8 +- .../core/src/segcore/SegmentSmallIndex.cpp | 6 +- internal/core/src/segcore/segment_c.cpp | 4 +- internal/core/src/utils/EasyAssert.cpp | 5 +- internal/core/src/utils/EasyAssert.h | 5 - internal/core/unittest/CMakeLists.txt | 1 - internal/core/unittest/test_binary.cpp | 2 +- internal/core/unittest/test_common.cpp | 12 - .../core/unittest/test_concurrent_vector.cpp | 2 +- internal/core/unittest/test_expr.cpp | 135 +- internal/core/unittest/test_indexing.cpp | 102 +- internal/core/unittest/test_query.cpp | 135 +- internal/core/unittest/test_segcore.cpp | 6 +- internal/core/unittest/test_utils/DataGen.h | 41 +- internal/kv/mockkv/mock_etcd.go | 14 + internal/master/collection_task_test.go | 5 +- internal/master/config_task_test.go | 5 +- internal/master/global_allocator_test.go | 11 + internal/master/grpc_service_test.go | 5 +- internal/master/master_test.go | 57 +- internal/master/param_table.go | 232 ++- internal/master/param_table_test.go | 31 +- internal/master/partition_task_test.go | 8 +- internal/master/segment_manager_test.go | 10 +- internal/proxy/paramtable.go | 239 +++- internal/proxy/proxy.go | 7 +- internal/proxy/proxy_test.go | 21 +- internal/proxy/task_scheduler.go | 2 +- internal/querynode/collection_replica.go | 8 +- internal/querynode/collection_replica_test.go | 1239 +++++++++++++++-- internal/querynode/collection_test.go | 171 ++- internal/querynode/data_sync_service.go | 4 +- internal/querynode/data_sync_service_test.go | 91 +- internal/querynode/flow_graph_insert_node.go | 14 +- .../querynode/flow_graph_service_time_node.go | 6 +- internal/querynode/meta_service.go | 20 +- internal/querynode/meta_service_test.go | 474 +++++-- internal/querynode/param_table.go | 267 +++- internal/querynode/param_table_test.go | 73 +- internal/querynode/partition_test.go | 73 +- internal/querynode/plan_test.go | 101 +- internal/querynode/query_node.go | 45 +- internal/querynode/query_node_test.go | 99 +- internal/querynode/reader.go | 15 + internal/querynode/reduce_test.go | 56 +- internal/querynode/search_service.go | 57 +- internal/querynode/search_service_test.go | 154 +- internal/querynode/segment_test.go | 525 ++++++- internal/querynode/stats_service.go | 18 +- internal/querynode/stats_service_test.go | 179 ++- internal/querynode/tsafe.go | 4 +- internal/querynode/tsafe_test.go | 6 +- internal/storage/cwrapper/.gitignore | 1 - internal/storage/cwrapper/CMakeLists.txt | 33 +- internal/storage/cwrapper/ParquetWrapper.cpp | 15 +- internal/storage/cwrapper/ParquetWrapper.h | 3 +- internal/storage/cwrapper/build.sh | 58 - .../cmake/Modules/ConfigureArrow.cmake | 19 +- .../Templates/Arrow.CMakeLists.txt.cmake | 4 +- internal/storage/cwrapper/test/CMakeLists.txt | 2 - .../cwrapper/test/ParquetWrapperTest.cpp | 54 +- internal/storage/payload.go | 626 --------- internal/storage/payload_test.go | 426 ------ internal/util/paramtable/paramtable.go | 154 +- internal/util/paramtable/paramtable_test.go | 40 +- scripts/cwrapper_build.sh | 59 - scripts/run_cpp_unittest.sh | 10 - scripts/run_go_unittest.sh | 2 +- tests/python/utils.py | 4 +- 104 files changed, 4045 insertions(+), 3531 deletions(-) delete mode 100644 build/docker/test/.env create mode 100644 configs/config.yaml delete mode 100644 docs/developer_guides/chap08_binlog.md create mode 100644 internal/conf/conf.go create mode 100644 internal/conf/conf_test.go delete mode 100644 internal/core/src/common/Types.cpp delete mode 100644 internal/core/src/common/Types.h delete mode 100644 internal/core/unittest/test_common.cpp create mode 100644 internal/kv/mockkv/mock_etcd.go create mode 100644 internal/querynode/reader.go delete mode 100755 internal/storage/cwrapper/build.sh delete mode 100644 internal/storage/payload.go delete mode 100644 internal/storage/payload_test.go delete mode 100755 scripts/cwrapper_build.sh diff --git a/.gitignore b/.gitignore index f62f7bb3a8..27d518ecc3 100644 --- a/.gitignore +++ b/.gitignore @@ -55,7 +55,3 @@ cmake_build/ .DS_Store *.swp -cwrapper_build -**/.clangd/* -**/compile_commands.json -**/.lint diff --git a/.jenkins/modules/Build/Build.groovy b/.jenkins/modules/Build/Build.groovy index bf0958c7cd..998ed76e99 100644 --- a/.jenkins/modules/Build/Build.groovy +++ b/.jenkins/modules/Build/Build.groovy @@ -1,9 +1,9 @@ -timeout(time: 10, unit: 'MINUTES') { +timeout(time: 5, unit: 'MINUTES') { dir ("scripts") { sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"ccache files not found!\"' } - sh '. ./scripts/before-install.sh && make install' + sh '. ./scripts/before-install.sh && make check-proto-product && make verifiers && make install' dir ("scripts") { withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { diff --git a/.jenkins/modules/Regression/PythonRegression.groovy b/.jenkins/modules/Regression/PythonRegression.groovy index d9bf7bf13f..8311dc1f4e 100644 --- a/.jenkins/modules/Regression/PythonRegression.groovy +++ b/.jenkins/modules/Regression/PythonRegression.groovy @@ -4,10 +4,7 @@ try { sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar' dir ('build/docker/deploy') { sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull' - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d master' - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d proxy' - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=1 -d querynode' - sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=2 -d querynode' + sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d' } dir ('build/docker/test') { diff --git a/Makefile b/Makefile index e30ba03587..e458293077 100644 --- a/Makefile +++ b/Makefile @@ -41,9 +41,9 @@ fmt: lint: @echo "Running $@ check" @GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean - @GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./internal/... - @GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./cmd/... - @GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./tests/go/... + @GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/... + @GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/... + @GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/... ruleguard: @echo "Running $@ check" @@ -65,11 +65,9 @@ build-go: build-cpp: @(env bash $(PWD)/scripts/core_build.sh) - @(env bash $(PWD)/scripts/cwrapper_build.sh -t Release) build-cpp-with-unittest: @(env bash $(PWD)/scripts/core_build.sh -u) - @(env bash $(PWD)/scripts/cwrapper_build.sh -t Release) # Runs the tests. unittest: test-cpp test-go diff --git a/build/docker/test/.env b/build/docker/test/.env deleted file mode 100644 index f94e4c5d5d..0000000000 --- a/build/docker/test/.env +++ /dev/null @@ -1,4 +0,0 @@ -SOURCE_REPO=milvusdb -TARGET_REPO=milvusdb -SOURCE_TAG=latest -TARGET_TAG=latest \ No newline at end of file diff --git a/cmd/proxy/proxy.go b/cmd/proxy/proxy.go index 4273e5efee..bde5fe37bd 100644 --- a/cmd/proxy/proxy.go +++ b/cmd/proxy/proxy.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "log" "os" "os/signal" @@ -14,7 +13,8 @@ import ( func main() { proxy.Init() - fmt.Println("ProxyID is", proxy.Params.ProxyID()) + + // Creates server. ctx, cancel := context.WithCancel(context.Background()) svr, err := proxy.CreateProxy(ctx) if err != nil { diff --git a/cmd/querynode/query_node.go b/cmd/querynode/query_node.go index 0a5eee2837..fd15c32379 100644 --- a/cmd/querynode/query_node.go +++ b/cmd/querynode/query_node.go @@ -2,24 +2,18 @@ package main import ( "context" - "fmt" - "log" "os" "os/signal" "syscall" - "go.uber.org/zap" - "github.com/zilliztech/milvus-distributed/internal/querynode" ) func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() querynode.Init() - fmt.Println("QueryNodeID is", querynode.Params.QueryNodeID()) - // Creates server. - ctx, cancel := context.WithCancel(context.Background()) - svr := querynode.NewQueryNode(ctx, 0) sc := make(chan os.Signal, 1) signal.Notify(sc, @@ -34,14 +28,8 @@ func main() { cancel() }() - if err := svr.Start(); err != nil { - log.Fatal("run server failed", zap.Error(err)) - } + querynode.StartQueryNode(ctx) - <-ctx.Done() - log.Print("Got signal to exit", zap.String("signal", sig.String())) - - svr.Close() switch sig { case syscall.SIGTERM: exit(0) diff --git a/configs/advanced/channel.yaml b/configs/advanced/channel.yaml index ed62dd977e..071441e48f 100644 --- a/configs/advanced/channel.yaml +++ b/configs/advanced/channel.yaml @@ -32,7 +32,7 @@ msgChannel: # default channel range [0, 1) channelRange: - insert: [0, 2] + insert: [0, 1] delete: [0, 1] dataDefinition: [0,1] k2s: [0, 1] diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000000..f9b203855e --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,118 @@ +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. + +master: + address: localhost + port: 53100 + pulsarmoniterinterval: 1 + pulsartopic: "monitor-topic" + + proxyidlist: [1, 2] + proxyTimeSyncChannels: ["proxy1", "proxy2"] + proxyTimeSyncSubName: "proxy-topic" + softTimeTickBarrierInterval: 500 + + writeidlist: [3, 4] + writeTimeSyncChannels: ["write3", "write4"] + writeTimeSyncSubName: "write-topic" + + dmTimeSyncChannels: ["dm5", "dm6"] + k2sTimeSyncChannels: ["k2s7", "k2s8"] + + defaultSizePerRecord: 1024 + minimumAssignSize: 1048576 + segmentThreshold: 536870912 + segmentExpireDuration: 2000 + segmentThresholdFactor: 0.75 + querynodenum: 1 + writenodenum: 1 + statsChannels: "statistic" + +etcd: + address: localhost + port: 2379 + rootpath: by-dev + segthreshold: 10000 + +minio: + address: localhost + port: 9000 + accessKeyID: minioadmin + secretAccessKey: minioadmin + useSSL: false + +timesync: + interval: 400 + +storage: + driver: TIKV + address: localhost + port: 2379 + accesskey: + secretkey: + +pulsar: + authentication: false + user: user-default + token: eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY + address: localhost + port: 6650 + topicnum: 128 + +reader: + clientid: 0 + stopflag: -1 + readerqueuesize: 10000 + searchchansize: 10000 + key2segchansize: 10000 + topicstart: 0 + topicend: 128 + +writer: + clientid: 0 + stopflag: -2 + readerqueuesize: 10000 + searchbyidchansize: 10000 + parallelism: 100 + topicstart: 0 + topicend: 128 + bucket: "zilliz-hz" + +proxy: + timezone: UTC+8 + proxy_id: 1 + numReaderNodes: 2 + tsoSaveInterval: 200 + timeTickInterval: 200 + + pulsarTopics: + readerTopicPrefix: "milvusReader" + numReaderTopics: 2 + deleteTopic: "milvusDeleter" + queryTopic: "milvusQuery" + resultTopic: "milvusResult" + resultGroup: "milvusResultGroup" + timeTickTopic: "milvusTimeTick" + + network: + address: 0.0.0.0 + port: 19530 + + logs: + level: debug + trace.enable: true + path: /tmp/logs + max_log_file_size: 1024MB + log_rotate_num: 0 + + storage: + path: /var/lib/milvus + auto_flush_interval: 1 diff --git a/configs/milvus.yaml b/configs/milvus.yaml index f74a61fc75..b0eca24ad6 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -12,7 +12,7 @@ nodeID: # will be deprecated after v0.2 proxyIDList: [0] - queryNodeIDList: [1, 2] + queryNodeIDList: [2] writeNodeIDList: [3] etcd: @@ -23,13 +23,6 @@ etcd: kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath segThreshold: 10000 -minio: - address: localhost - port: 9000 - accessKeyID: minioadmin - secretAccessKey: minioadmin - useSSL: false - pulsar: address: localhost port: 6650 diff --git a/docs/developer_guides/chap08_binlog.md b/docs/developer_guides/chap08_binlog.md deleted file mode 100644 index fafb3d598a..0000000000 --- a/docs/developer_guides/chap08_binlog.md +++ /dev/null @@ -1,223 +0,0 @@ -## Binlog - -InsertBinlog、DeleteBinlog、DDLBinlog - -Binlog is stored in a columnar storage format, every column in schema should be stored in a individual file. Timestamp, schema, row id and primary key allocated by system are four special columns. Schema column records the DDL of the collection. - - - -## Event format - -Binlog file consists of 4 bytes magic number and a series of events. The first event must be descriptor event. - -### Event format - -``` -+=====================================+ -| event | timestamp 0 : 8 | create timestamp -| header +----------------------------+ -| | type_code 8 : 1 | event type code -| +----------------------------+ -| | server_id 9 : 4 | write node id -| +----------------------------+ -| | event_length 13 : 4 | length of event, including header and data -| +----------------------------+ -| | next_position 17 : 4 | offset of next event from the start of file -| +----------------------------+ -| | extra_headers 21 : x-21 | reserved part -+=====================================+ -| event | fixed part x : y | -| data +----------------------------+ -| | variable part | -+=====================================+ -``` - - - -### Descriptor Event format - -``` -+=====================================+ -| event | timestamp 0 : 8 | create timestamp -| header +----------------------------+ -| | type_code 8 : 1 | event type code -| +----------------------------+ -| | server_id 9 : 4 | write node id -| +----------------------------+ -| | event_length 13 : 4 | length of event, including header and data -| +----------------------------+ -| | next_position 17 : 4 | offset of next event from the start of file -+=====================================+ -| event | binlog_version 21 : 2 | binlog version -| data +----------------------------+ -| | server_version 23 : 8 | write node version -| +----------------------------+ -| | commit_id 31 : 8 | commit id of the programe in git -| +----------------------------+ -| | header_length 39 : 1 | header length of other event -| +----------------------------+ -| | collection_id 40 : 8 | collection id -| +----------------------------+ -| | partition_id 48 : 8 | partition id (schema column does not need) -| +----------------------------+ -| | segment_id 56 : 8 | segment id (schema column does not need) -| +----------------------------+ -| | start_timestamp 64 : 1 | minimum timestamp allocated by master of all events in this file -| +----------------------------+ -| | end_timestamp 65 : 1 | maximum timestamp allocated by master of all events in this file -| +----------------------------+ -| | post-header 66 : n | array of n bytes, one byte per event type that the server knows about -| | lengths for all | -| | event types | -+=====================================+ -``` - - - -### Type code - -``` -DESCRIPTOR_EVENT -INSERT_EVENT -DELETE_EVENT -CREATE_COLLECTION_EVENT -DROP_COLLECTION_EVENT -CREATE_PARTITION_EVENT -DROP_PARTITION_EVENT -``` - -DESCRIPTOR_EVENT must appear in all column files and always be the first event. - -INSERT_EVENT 可以出现在除DDL binlog文件外的其他列的binlog - -DELETE_EVENT 只能用于primary key 的binlog文件(目前只有按照primary key删除) - -CREATE_COLLECTION_EVENT、DROP_COLLECTION_EVENT、CREATE_PARTITION_EVENT、DROP_PARTITION_EVENT 只出现在DDL binlog文件 - - - -### Event data part - -``` -event data part - -INSERT_EVENT: -+================================================+ -| event | fixed | start_timestamp x : 8 | min timestamp in this event -| data | part +------------------------------+ -| | | end_timestamp x+8 : 8 | max timestamp in this event -| | +------------------------------+ -| | | reserved x+16 : y-x-16 | reserved part -| +--------+------------------------------+ -| |variable| parquet payloI ad | payload in parquet format -| |part | | -+================================================+ - -other events is similar with INSERT_EVENT - - -``` - - - - - - - -### Example - -Schema - -​ string | int | float(optional) | vector(512) - - - -Request: - -​ InsertRequest rows(1W) - -​ DeleteRequest pk=1 - -​ DropPartition partitionTag="abc" - - - -insert binlogs: - -​ rowid, pk, ts, string, int, float, vector 6 files - -​ all events are INSERT_EVENT -​ float column file contains some NULL value - -delete binlogs: - -​ pk, ts 2 files - -​ pk's events are DELETE_EVENT, ts's events are INSERT_EVENT - -DDL binlogs: - -​ ddl, ts - -​ ddl's event is DROP_PARTITION_EVENT, ts's event is INSERT_EVENT - - - -C++ interface - -```c++ -typedef void* CPayloadWriter -typedef struct CBuffer { - char* data; - int length; -} CBuffer - -typedef struct CStatus { - int error_code; - const char* error_msg; -} CStatus - - -// C++ interface -// writer -CPayloadWriter NewPayloadWriter(int columnType); -CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length); -CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length); -CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length); -CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length); -CStatus AddInt64ToPayload(CPayloadWriter payloadWriter, int64_t *values, int length); -CStatus AddFloatToPayload(CPayloadWriter payloadWriter, float *values, int length); -CStatus AddDoubleToPayload(CPayloadWriter payloadWriter, double *values, int length); -CStatus AddOneStringToPayload(CPayloadWriter payloadWriter, char *cstr, int str_size); -CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_t *values, int dimension, int length); -CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *values, int dimension, int length); - -CStatus FinishPayloadWriter(CPayloadWriter payloadWriter); -CBuffer GetPayloadBufferFromWriter(CPayloadWriter payloadWriter); -int GetPayloadLengthFromWriter(CPayloadWriter payloadWriter); -CStatus ReleasePayloadWriter(CPayloadWriter handler); - -// reader -CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size); -CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length); -CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length); -CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length); -CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length); -CStatus GetInt64FromPayload(CPayloadReader payloadReader, int64_t **values, int *length); -CStatus GetFloatFromPayload(CPayloadReader payloadReader, float **values, int *length); -CStatus GetDoubleFromPayload(CPayloadReader payloadReader, double **values, int *length); -CStatus GetOneStringFromPayload(CPayloadReader payloadReader, int idx, char **cstr, int *str_size); -CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader, uint8_t **values, int *dimension, int *length); -CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, float **values, int *dimension, int *length); - -int GetPayloadLengthFromReader(CPayloadReader payloadReader); -CStatus ReleasePayloadReader(CPayloadReader payloadReader); - -``` - - - - - - - diff --git a/internal/conf/conf.go b/internal/conf/conf.go new file mode 100644 index 0000000000..ef9ada8c24 --- /dev/null +++ b/internal/conf/conf.go @@ -0,0 +1,150 @@ +package conf + +import ( + "io/ioutil" + "path" + "runtime" + + "github.com/zilliztech/milvus-distributed/internal/util/typeutil" + + storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type" + yaml "gopkg.in/yaml.v2" +) + +type UniqueID = typeutil.UniqueID + +// yaml.MapSlice + +type MasterConfig struct { + Address string + Port int32 + PulsarMonitorInterval int32 + PulsarTopic string + SegmentThreshold float32 + SegmentExpireDuration int64 + ProxyIDList []UniqueID + QueryNodeNum int + WriteNodeNum int +} + +type EtcdConfig struct { + Address string + Port int32 + Rootpath string + Segthreshold int64 +} + +type TimeSyncConfig struct { + Interval int32 +} + +type StorageConfig struct { + Driver storagetype.DriverType + Address string + Port int32 + Accesskey string + Secretkey string +} + +type PulsarConfig struct { + Authentication bool + User string + Token string + Address string + Port int32 + TopicNum int +} + +type ProxyConfig struct { + Timezone string `yaml:"timezone"` + ProxyID int `yaml:"proxy_id"` + NumReaderNodes int `yaml:"numReaderNodes"` + TosSaveInterval int `yaml:"tsoSaveInterval"` + TimeTickInterval int `yaml:"timeTickInterval"` + PulsarTopics struct { + ReaderTopicPrefix string `yaml:"readerTopicPrefix"` + NumReaderTopics int `yaml:"numReaderTopics"` + DeleteTopic string `yaml:"deleteTopic"` + QueryTopic string `yaml:"queryTopic"` + ResultTopic string `yaml:"resultTopic"` + ResultGroup string `yaml:"resultGroup"` + TimeTickTopic string `yaml:"timeTickTopic"` + } `yaml:"pulsarTopics"` + Network struct { + Address string `yaml:"address"` + Port int `yaml:"port"` + } `yaml:"network"` + Logs struct { + Level string `yaml:"level"` + TraceEnable bool `yaml:"trace.enable"` + Path string `yaml:"path"` + MaxLogFileSize string `yaml:"max_log_file_size"` + LogRotateNum int `yaml:"log_rotate_num"` + } `yaml:"logs"` + Storage struct { + Path string `yaml:"path"` + AutoFlushInterval int `yaml:"auto_flush_interval"` + } `yaml:"storage"` +} + +type Reader struct { + ClientID int + StopFlag int64 + ReaderQueueSize int + SearchChanSize int + Key2SegChanSize int + TopicStart int + TopicEnd int +} + +type Writer struct { + ClientID int + StopFlag int64 + ReaderQueueSize int + SearchByIDChanSize int + Parallelism int + TopicStart int + TopicEnd int + Bucket string +} + +type ServerConfig struct { + Master MasterConfig + Etcd EtcdConfig + Timesync TimeSyncConfig + Storage StorageConfig + Pulsar PulsarConfig + Writer Writer + Reader Reader + Proxy ProxyConfig +} + +var Config ServerConfig + +// func init() { +// load_config() +// } + +func getConfigsDir() string { + _, fpath, _, _ := runtime.Caller(0) + configPath := path.Dir(fpath) + "/../../configs/" + configPath = path.Dir(configPath) + return configPath +} + +func LoadConfigWithPath(yamlFilePath string) { + source, err := ioutil.ReadFile(yamlFilePath) + if err != nil { + panic(err) + } + err = yaml.Unmarshal(source, &Config) + if err != nil { + panic(err) + } + //fmt.Printf("Result: %v\n", Config) +} + +func LoadConfig(yamlFile string) { + filePath := path.Join(getConfigsDir(), yamlFile) + LoadConfigWithPath(filePath) +} diff --git a/internal/conf/conf_test.go b/internal/conf/conf_test.go new file mode 100644 index 0000000000..9255a701b8 --- /dev/null +++ b/internal/conf/conf_test.go @@ -0,0 +1,10 @@ +package conf + +import ( + "fmt" + "testing" +) + +func TestMain(m *testing.M) { + fmt.Printf("Result: %v\n", Config) +} diff --git a/internal/core/src/common/CMakeLists.txt b/internal/core/src/common/CMakeLists.txt index 6962797d58..f52a79e5b4 100644 --- a/internal/core/src/common/CMakeLists.txt +++ b/internal/core/src/common/CMakeLists.txt @@ -1,9 +1,8 @@ -set(COMMON_SRC - Schema.cpp - Types.cpp - ) +set(COMMON_SRC + Schema.cpp +) -add_library(milvus_common - ${COMMON_SRC} - ) +add_library(milvus_common + ${COMMON_SRC} +) target_link_libraries(milvus_common milvus_proto) diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h index 0adf92bd49..880d2ded47 100644 --- a/internal/core/src/common/FieldMeta.h +++ b/internal/core/src/common/FieldMeta.h @@ -10,13 +10,18 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #pragma once -#include "common/Types.h" +#include "utils/Types.h" #include "utils/Status.h" #include "utils/EasyAssert.h" + #include #include - namespace milvus { + +using Timestamp = uint64_t; // TODO: use TiKV-like timestamp +using engine::DataType; +using engine::FieldElementType; + inline int field_sizeof(DataType data_type, int dim = 1) { switch (data_type) { @@ -84,13 +89,7 @@ field_is_vector(DataType datatype) { struct FieldMeta { public: - FieldMeta(std::string_view name, DataType type) : name_(name), type_(type) { - Assert(!is_vector()); - } - - FieldMeta(std::string_view name, DataType type, int64_t dim, MetricType metric_type) - : name_(name), type_(type), vector_info_(VectorInfo{dim, metric_type}) { - Assert(is_vector()); + FieldMeta(std::string_view name, DataType type, int dim = 1) : name_(name), type_(type), dim_(dim) { } bool @@ -99,11 +98,14 @@ struct FieldMeta { return type_ == DataType::VECTOR_BINARY || type_ == DataType::VECTOR_FLOAT; } - int64_t + void + set_dim(int dim) { + dim_ = dim; + } + + int get_dim() const { - Assert(is_vector()); - Assert(vector_info_.has_value()); - return vector_info_->dim_; + return dim_; } const std::string& @@ -118,20 +120,12 @@ struct FieldMeta { int get_sizeof() const { - if (is_vector()) { - return field_sizeof(type_, get_dim()); - } else { - return field_sizeof(type_, 1); - } + return field_sizeof(type_, dim_); } private: - struct VectorInfo { - int64_t dim_; - MetricType metric_type_; - }; std::string name_; DataType type_ = DataType::NONE; - std::optional vector_info_; + int dim_ = 1; }; } // namespace milvus diff --git a/internal/core/src/common/Schema.cpp b/internal/core/src/common/Schema.cpp index 52e6912189..baa1a785d3 100644 --- a/internal/core/src/common/Schema.cpp +++ b/internal/core/src/common/Schema.cpp @@ -11,50 +11,35 @@ #include "common/Schema.h" #include -#include namespace milvus { - -using std::string; -static std::map -RepeatedKeyValToMap(const google::protobuf::RepeatedPtrField& kvs) { - std::map mapping; - for (auto& kv : kvs) { - AssertInfo(!mapping.count(kv.key()), "repeat key(" + kv.key() + ") in protobuf"); - mapping.emplace(kv.key(), kv.value()); - } - return mapping; -} - std::shared_ptr Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) { auto schema = std::make_shared(); schema->set_auto_id(schema_proto.autoid()); for (const milvus::proto::schema::FieldSchema& child : schema_proto.fields()) { + const auto& type_params = child.type_params(); + int64_t dim = -1; auto data_type = DataType(child.data_type()); + for (const auto& type_param : type_params) { + if (type_param.key() == "dim") { + dim = strtoll(type_param.value().c_str(), nullptr, 10); + } + } + + if (field_is_vector(data_type)) { + AssertInfo(dim != -1, "dim not found"); + } else { + AssertInfo(dim == 1 || dim == -1, "Invalid dim field. Should be 1 or not exists"); + dim = 1; + } if (child.is_primary_key()) { AssertInfo(!schema->primary_key_offset_opt_.has_value(), "repetitive primary key"); schema->primary_key_offset_opt_ = schema->size(); } - if (field_is_vector(data_type)) { - auto type_map = RepeatedKeyValToMap(child.type_params()); - auto index_map = RepeatedKeyValToMap(child.index_params()); - if (!index_map.count("metric_type")) { - auto default_metric_type = - data_type == DataType::VECTOR_FLOAT ? MetricType::METRIC_L2 : MetricType::METRIC_Jaccard; - index_map["metric_type"] = default_metric_type; - } - - AssertInfo(type_map.count("dim"), "dim not found"); - auto dim = boost::lexical_cast(type_map.at("dim")); - AssertInfo(index_map.count("metric_type"), "index not found"); - auto metric_type = GetMetricType(index_map.at("metric_type")); - schema->AddField(child.name(), data_type, dim, metric_type); - } else { - schema->AddField(child.name(), data_type); - } + schema->AddField(child.name(), data_type, dim); } return schema; } diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index 7e1f75af69..05710c7944 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -24,15 +24,19 @@ namespace milvus { class Schema { public: void - AddField(std::string_view field_name, DataType data_type) { - auto field_meta = FieldMeta(field_name, data_type); + AddField(std::string_view field_name, DataType data_type, int dim = 1) { + auto field_meta = FieldMeta(field_name, data_type, dim); this->AddField(std::move(field_meta)); } void - AddField(std::string_view field_name, DataType data_type, int64_t dim, MetricType metric_type) { - auto field_meta = FieldMeta(field_name, data_type, dim, metric_type); - this->AddField(std::move(field_meta)); + AddField(FieldMeta field_meta) { + auto offset = fields_.size(); + fields_.emplace_back(field_meta); + offsets_.emplace(field_meta.get_name(), offset); + auto field_sizeof = field_meta.get_sizeof(); + sizeof_infos_.push_back(field_sizeof); + total_sizeof_ += field_sizeof; } void @@ -40,6 +44,17 @@ class Schema { is_auto_id_ = is_auto_id; } + auto + begin() { + return fields_.begin(); + } + + auto + end() { + return fields_.end(); + } + + public: bool get_is_auto_id() const { return is_auto_id_; @@ -108,20 +123,11 @@ class Schema { static std::shared_ptr ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto); - void - AddField(FieldMeta&& field_meta) { - auto offset = fields_.size(); - fields_.emplace_back(field_meta); - offsets_.emplace(field_meta.get_name(), offset); - auto field_sizeof = field_meta.get_sizeof(); - sizeof_infos_.push_back(std::move(field_sizeof)); - total_sizeof_ += field_sizeof; - } - private: // this is where data holds std::vector fields_; + private: // a mapping for random access std::unordered_map offsets_; std::vector sizeof_infos_; diff --git a/internal/core/src/common/Types.cpp b/internal/core/src/common/Types.cpp deleted file mode 100644 index 0e4aa7d35c..0000000000 --- a/internal/core/src/common/Types.cpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -// -// Created by mike on 12/3/20. -// -#include "common/Types.h" -#include -#include "utils/EasyAssert.h" -#include -#include - -namespace milvus { - -using boost::algorithm::to_lower_copy; -namespace Metric = knowhere::Metric; -static auto map = [] { - boost::bimap mapping; - using pos = boost::bimap::value_type; - mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2)); - mapping.insert(pos(to_lower_copy(std::string(Metric::IP)), MetricType::METRIC_INNER_PRODUCT)); - mapping.insert(pos(to_lower_copy(std::string(Metric::JACCARD)), MetricType::METRIC_Jaccard)); - mapping.insert(pos(to_lower_copy(std::string(Metric::TANIMOTO)), MetricType::METRIC_Tanimoto)); - mapping.insert(pos(to_lower_copy(std::string(Metric::HAMMING)), MetricType::METRIC_Hamming)); - mapping.insert(pos(to_lower_copy(std::string(Metric::SUBSTRUCTURE)), MetricType::METRIC_Substructure)); - mapping.insert(pos(to_lower_copy(std::string(Metric::SUPERSTRUCTURE)), MetricType::METRIC_Superstructure)); - return mapping; -}(); - -MetricType -GetMetricType(const std::string& type_name) { - auto real_name = to_lower_copy(type_name); - AssertInfo(map.left.count(real_name), "metric type not found: (" + type_name + ")"); - return map.left.at(real_name); -} - -} // namespace milvus diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h deleted file mode 100644 index c46ee76c65..0000000000 --- a/internal/core/src/common/Types.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -#include "utils/Types.h" -#include -#include -#include -#include - -namespace milvus { -using Timestamp = uint64_t; // TODO: use TiKV-like timestamp -using engine::DataType; -using engine::FieldElementType; -using engine::QueryResult; -using MetricType = faiss::MetricType; - -faiss::MetricType -GetMetricType(const std::string& type); - -// NOTE: dependent type -// used at meta-template programming -template -constexpr std::true_type always_true{}; - -template -constexpr std::false_type always_false{}; - -template -using aligned_vector = std::vector>; - -} // namespace milvus diff --git a/internal/core/src/query/BruteForceSearch.cpp b/internal/core/src/query/BruteForceSearch.cpp index b8c334dffe..626fe54290 100644 --- a/internal/core/src/query/BruteForceSearch.cpp +++ b/internal/core/src/query/BruteForceSearch.cpp @@ -11,66 +11,20 @@ #include "BruteForceSearch.h" #include -#include -#include -#include namespace milvus::query { void -BinarySearchBruteForceNaive(MetricType metric_type, - int64_t code_size, - const uint8_t* binary_chunk, - int64_t chunk_size, - int64_t topk, - int64_t num_queries, - const uint8_t* query_data, - float* result_distances, - idx_t* result_labels, - faiss::ConcurrentBitsetPtr bitset) { - // THIS IS A NAIVE IMPLEMENTATION, ready for optimize - Assert(metric_type == faiss::METRIC_Jaccard); - Assert(code_size % 4 == 0); - - using T = std::tuple; - - for (int64_t q = 0; q < num_queries; ++q) { - auto query_ptr = query_data + code_size * q; - auto query = boost::dynamic_bitset(query_ptr, query_ptr + code_size); - std::vector max_heap(topk + 1, std::make_tuple(std::numeric_limits::max(), -1)); - - for (int64_t i = 0; i < chunk_size; ++i) { - auto element_ptr = binary_chunk + code_size * i; - auto element = boost::dynamic_bitset(element_ptr, element_ptr + code_size); - auto the_and = (query & element).count(); - auto the_or = (query | element).count(); - auto distance = the_or ? (float)(the_or - the_and) / the_or : 0; - if (distance < std::get<0>(max_heap[0])) { - max_heap[topk] = std::make_tuple(distance, i); - std::push_heap(max_heap.begin(), max_heap.end()); - std::pop_heap(max_heap.begin(), max_heap.end()); - } - } - std::sort(max_heap.begin(), max_heap.end()); - for (int k = 0; k < topk; ++k) { - auto info = max_heap[k]; - result_distances[k + q * topk] = std::get<0>(info); - result_labels[k + q * topk] = std::get<1>(info); - } - } -} - -void -BinarySearchBruteForceFast(MetricType metric_type, - int64_t code_size, - const uint8_t* binary_chunk, - int64_t chunk_size, - int64_t topk, - int64_t num_queries, - const uint8_t* query_data, - float* result_distances, - idx_t* result_labels, - faiss::ConcurrentBitsetPtr bitset) { +BinarySearchBruteForce(faiss::MetricType metric_type, + int64_t code_size, + const uint8_t* binary_chunk, + int64_t chunk_size, + int64_t topk, + int64_t num_queries, + const uint8_t* query_data, + float* result_distances, + idx_t* result_labels, + faiss::ConcurrentBitsetPtr bitset) { const idx_t block_size = segcore::DefaultElementPerChunk; bool use_heap = true; @@ -129,21 +83,6 @@ BinarySearchBruteForceFast(MetricType metric_type, for (int i = 0; i < num_queries; ++i) { result_distances[i] = static_cast(int_distances[i]); } - } else { - PanicInfo("Unsupported metric type"); } } - -void -BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset, - const uint8_t* binary_chunk, - int64_t chunk_size, - float* result_distances, - idx_t* result_labels, - faiss::ConcurrentBitsetPtr bitset) { - // TODO: refactor the internal function - BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size, - query_dataset.topk, query_dataset.num_queries, query_dataset.query_data, - result_distances, result_labels, bitset); -} } // namespace milvus::query diff --git a/internal/core/src/query/BruteForceSearch.h b/internal/core/src/query/BruteForceSearch.h index 4d9cba96df..1edc19e159 100644 --- a/internal/core/src/query/BruteForceSearch.h +++ b/internal/core/src/query/BruteForceSearch.h @@ -15,25 +15,15 @@ #include "common/Schema.h" namespace milvus::query { -using MetricType = faiss::MetricType; - -namespace dataset { -struct BinaryQueryDataset { - MetricType metric_type; - int64_t num_queries; - int64_t topk; - int64_t code_size; - const uint8_t* query_data; -}; - -} // namespace dataset - void -BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset, +BinarySearchBruteForce(faiss::MetricType metric_type, + int64_t code_size, const uint8_t* binary_chunk, int64_t chunk_size, + int64_t topk, + int64_t num_queries, + const uint8_t* query_data, float* result_distances, idx_t* result_labels, faiss::ConcurrentBitsetPtr bitset = nullptr); - } // namespace milvus::query diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 8e2a89cfb0..9993c05d82 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -26,25 +26,15 @@ static std::unique_ptr ParseVecNode(Plan* plan, const Json& out_body) { Assert(out_body.is_object()); // TODO add binary info + auto vec_node = std::make_unique(); Assert(out_body.size() == 1); auto iter = out_body.begin(); std::string field_name = iter.key(); - auto& vec_info = iter.value(); Assert(vec_info.is_object()); auto topK = vec_info["topk"]; AssertInfo(topK > 0, "topK must greater than 0"); AssertInfo(topK < 16384, "topK is too large"); - auto field_meta = plan->schema_.operator[](field_name); - - auto vec_node = [&]() -> std::unique_ptr { - auto data_type = field_meta.get_data_type(); - if (data_type == DataType::VECTOR_FLOAT) { - return std::make_unique(); - } else { - return std::make_unique(); - } - }(); vec_node->query_info_.topK_ = topK; vec_node->query_info_.metric_type_ = vec_info.at("metric_type"); vec_node->query_info_.search_params_ = vec_info.at("params"); @@ -70,6 +60,8 @@ to_lower(const std::string& raw) { return data; } +template +constexpr std::false_type always_false{}; template std::unique_ptr ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) { @@ -83,62 +75,31 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found"); auto op = RangeExpr::mapping_.at(op_name); - if constexpr (std::is_same_v) { - Assert(item.value().is_boolean()); - } else if constexpr (std::is_integral_v) { + if constexpr (std::is_integral_v) { Assert(item.value().is_number_integer()); } else if constexpr (std::is_floating_point_v) { Assert(item.value().is_number()); } else { static_assert(always_false, "unsupported type"); - __builtin_unreachable(); } T value = item.value(); expr->conditions_.emplace_back(op, value); } - std::sort(expr->conditions_.begin(), expr->conditions_.end()); - return expr; -} - -template -std::unique_ptr -ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) { - auto expr = std::make_unique>(); - auto data_type = schema[field_name].get_data_type(); - Assert(body.is_array()); - expr->field_id_ = field_name; - expr->data_type_ = data_type; - for (auto& value : body) { - if constexpr (std::is_same_v) { - Assert(value.is_boolean()); - } else if constexpr (std::is_integral_v) { - Assert(value.is_number_integer()); - } else if constexpr (std::is_floating_point_v) { - Assert(value.is_number()); - } else { - static_assert(always_false, "unsupported type"); - __builtin_unreachable(); - } - T real_value = value; - expr->terms_.push_back(real_value); - } - std::sort(expr->terms_.begin(), expr->terms_.end()); return expr; } std::unique_ptr ParseRangeNode(const Schema& schema, const Json& out_body) { - Assert(out_body.is_object()); Assert(out_body.size() == 1); auto out_iter = out_body.begin(); auto field_name = out_iter.key(); auto body = out_iter.value(); auto data_type = schema[field_name].get_data_type(); Assert(!field_is_vector(data_type)); - switch (data_type) { case DataType::BOOL: { - return ParseRangeNodeImpl(schema, field_name, body); + PanicInfo("bool is not supported in Range node"); + // return ParseRangeNodeImpl(schema, field_name, body); } case DataType::INT8: return ParseRangeNodeImpl(schema, field_name, body); @@ -157,42 +118,6 @@ ParseRangeNode(const Schema& schema, const Json& out_body) { } } -static std::unique_ptr -ParseTermNode(const Schema& schema, const Json& out_body) { - Assert(out_body.size() == 1); - auto out_iter = out_body.begin(); - auto field_name = out_iter.key(); - auto body = out_iter.value(); - auto data_type = schema[field_name].get_data_type(); - Assert(!field_is_vector(data_type)); - switch (data_type) { - case DataType::BOOL: { - return ParseTermNodeImpl(schema, field_name, body); - } - case DataType::INT8: { - return ParseTermNodeImpl(schema, field_name, body); - } - case DataType::INT16: { - return ParseTermNodeImpl(schema, field_name, body); - } - case DataType::INT32: { - return ParseTermNodeImpl(schema, field_name, body); - } - case DataType::INT64: { - return ParseTermNodeImpl(schema, field_name, body); - } - case DataType::FLOAT: { - return ParseTermNodeImpl(schema, field_name, body); - } - case DataType::DOUBLE: { - return ParseTermNodeImpl(schema, field_name, body); - } - default: { - PanicInfo("unsupported data_type"); - } - } -} - static std::unique_ptr CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) { auto plan = std::make_unique(schema); @@ -208,10 +133,6 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) { if (pack.contains("vector")) { auto& out_body = pack.at("vector"); plan->plan_node_ = ParseVecNode(plan.get(), out_body); - } else if (pack.contains("term")) { - AssertInfo(!predicate, "unsupported complex DSL"); - auto& out_body = pack.at("term"); - predicate = ParseTermNode(schema, out_body); } else if (pack.contains("range")) { AssertInfo(!predicate, "unsupported complex DSL"); auto& out_body = pack.at("range"); diff --git a/internal/core/src/query/PlanImpl.h b/internal/core/src/query/PlanImpl.h index 1d33eb18ae..453959e530 100644 --- a/internal/core/src/query/PlanImpl.h +++ b/internal/core/src/query/PlanImpl.h @@ -20,6 +20,7 @@ #include #include #include +#include namespace milvus::query { using Json = nlohmann::json; @@ -38,6 +39,9 @@ struct Plan { // TODO: add move extra info }; +template +using aligned_vector = std::vector>; + struct Placeholder { // milvus::proto::service::PlaceholderGroup group_; std::string tag_; diff --git a/internal/core/src/query/Search.cpp b/internal/core/src/query/Search.cpp index 084bb8e5de..b73014a86a 100644 --- a/internal/core/src/query/Search.cpp +++ b/internal/core/src/query/Search.cpp @@ -16,7 +16,6 @@ #include #include "utils/tools.h" -#include "query/BruteForceSearch.h" namespace milvus::query { using segcore::DefaultElementPerChunk; @@ -27,7 +26,7 @@ create_bitmap_view(std::optional bitmaps_opt, int64_t chunk return nullptr; } auto& bitmaps = *bitmaps_opt.value(); - auto src_vec = ~bitmaps.at(chunk_id); + auto& src_vec = bitmaps.at(chunk_id); auto dst = std::make_shared(src_vec.size()); auto iter = reinterpret_cast(dst->mutable_data()); @@ -42,7 +41,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment, int64_t num_queries, Timestamp timestamp, std::optional bitmaps_opt, - QueryResult& results) { + segcore::QueryResult& results) { auto& schema = segment.get_schema(); auto& indexing_record = segment.get_indexing_record(); auto& record = segment.get_insert_record(); @@ -132,92 +131,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment, } results.result_ids_ = std::move(final_uids); // TODO: deprecated code end + return Status::OK(); } - -Status -BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment, - const query::QueryInfo& info, - const uint8_t* query_data, - int64_t num_queries, - Timestamp timestamp, - std::optional bitmaps_opt, - QueryResult& results) { - auto& schema = segment.get_schema(); - auto& indexing_record = segment.get_indexing_record(); - auto& record = segment.get_insert_record(); - // step 1: binary search to find the barrier of the snapshot - auto ins_barrier = get_barrier(record, timestamp); - auto max_chunk = upper_div(ins_barrier, DefaultElementPerChunk); - auto metric_type = GetMetricType(info.metric_type_); - // auto del_barrier = get_barrier(deleted_record_, timestamp); - -#if 0 - auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); - Assert(bitmap_holder); - auto bitmap = bitmap_holder->bitmap_ptr; -#endif - - // step 2.1: get meta - // step 2.2: get which vector field to search - auto vecfield_offset_opt = schema.get_offset(info.field_id_); - Assert(vecfield_offset_opt.has_value()); - auto vecfield_offset = vecfield_offset_opt.value(); - auto& field = schema[vecfield_offset]; - - Assert(field.get_data_type() == DataType::VECTOR_BINARY); - auto dim = field.get_dim(); - auto code_size = dim / 8; - auto topK = info.topK_; - auto total_count = topK * num_queries; - - // step 3: small indexing search - std::vector final_uids(total_count, -1); - std::vector final_dis(total_count, std::numeric_limits::max()); - query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data}; - - using segcore::BinaryVector; - auto vec_ptr = record.get_entity(vecfield_offset); - - auto max_indexed_id = 0; - // step 4: brute force search where small indexing is unavailable - for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) { - std::vector buf_uids(total_count, -1); - std::vector buf_dis(total_count, std::numeric_limits::max()); - - auto& chunk = vec_ptr->get_chunk(chunk_id); - auto nsize = - chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk; - - auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id); - BinarySearchBruteForce(query_dataset, chunk.data(), nsize, buf_dis.data(), buf_uids.data(), bitmap_view); - - // convert chunk uid to segment uid - for (auto& x : buf_uids) { - if (x != -1) { - x += chunk_id * DefaultElementPerChunk; - } - } - - segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data()); - } - - results.result_distances_ = std::move(final_dis); - results.internal_seg_offsets_ = std::move(final_uids); - results.topK_ = topK; - results.num_queries_ = num_queries; - - // TODO: deprecated code begin - final_uids = results.internal_seg_offsets_; - for (auto& id : final_uids) { - if (id == -1) { - continue; - } - id = record.uids_[id]; - } - results.result_ids_ = std::move(final_uids); - // TODO: deprecated code end - return Status::OK(); -} - } // namespace milvus::query diff --git a/internal/core/src/query/Search.h b/internal/core/src/query/Search.h index a5334038ac..8426f47227 100644 --- a/internal/core/src/query/Search.h +++ b/internal/core/src/query/Search.h @@ -27,14 +27,5 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment, int64_t num_queries, Timestamp timestamp, std::optional bitmap_opt, - QueryResult& results); - -Status -BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment, - const query::QueryInfo& info, - const uint8_t* query_data, - int64_t num_queries, - Timestamp timestamp, - std::optional bitmaps_opt, - QueryResult& results); + segcore::QueryResult& results); } // namespace milvus::query diff --git a/internal/core/src/query/deprecated/GeneralQuery.h b/internal/core/src/query/deprecated/GeneralQuery.h index ad5421ddc4..c54e2aa2dd 100644 --- a/internal/core/src/query/deprecated/GeneralQuery.h +++ b/internal/core/src/query/deprecated/GeneralQuery.h @@ -18,7 +18,7 @@ #include #include -#include "common/Types.h" +#include "utils/Types.h" #include "utils/Json.h" namespace milvus { diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h index 250d68a6e5..86464b70b0 100644 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ b/internal/core/src/query/generated/ExecExprVisitor.h @@ -58,10 +58,6 @@ class ExecExprVisitor : ExprVisitor { auto ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType; - template - auto - ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType; - private: segcore::SegmentSmallIndex& segment_; std::optional ret_; diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h index 0eb33384d7..29e929bd87 100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h @@ -28,7 +28,7 @@ class ExecPlanNodeVisitor : PlanNodeVisitor { visit(BinaryVectorANNS& node) override; public: - using RetType = QueryResult; + using RetType = segcore::QueryResult; ExecPlanNodeVisitor(segcore::SegmentBase& segment, Timestamp timestamp, const PlaceholderGroup& placeholder_group) : segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) { } diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index d5eb3a026b..83e5782c91 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -46,10 +46,6 @@ class ExecExprVisitor : ExprVisitor { auto ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType; - template - auto - ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType; - private: segcore::SegmentSmallIndex& segment_; std::optional ret_; @@ -67,6 +63,11 @@ ExecExprVisitor::visit(BoolBinaryExpr& expr) { PanicInfo("unimplemented"); } +void +ExecExprVisitor::visit(TermExpr& expr) { + PanicInfo("unimplemented"); +} + template auto ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl& expr, IndexFunc index_func, ElementFunc element_func) @@ -83,17 +84,17 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl& expr, IndexFunc index_fu auto& indexing_record = segment_.get_indexing_record(); const segcore::ScalarIndexingEntry& entry = indexing_record.get_scalar_entry(field_offset); - RetType results(vec.num_chunk()); + RetType results(vec.chunk_size()); auto indexing_barrier = indexing_record.get_finished_ack(); for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) { auto& result = results[chunk_id]; auto indexing = entry.get_indexing(chunk_id); auto data = index_func(indexing); - result = std::move(*data); + result = ~std::move(*data); Assert(result.size() == segcore::DefaultElementPerChunk); } - for (auto chunk_id = indexing_barrier; chunk_id < vec.num_chunk(); ++chunk_id) { + for (auto chunk_id = indexing_barrier; chunk_id < vec.chunk_size(); ++chunk_id) { auto& result = results[chunk_id]; result.resize(segcore::DefaultElementPerChunk); auto chunk = vec.get_chunk(chunk_id); @@ -125,32 +126,32 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType { switch (op) { case OpType::Equal: { auto index_func = [val](Index* index) { return index->In(1, &val); }; - return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x == val); }); + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x == val); }); } case OpType::NotEqual: { auto index_func = [val](Index* index) { return index->NotIn(1, &val); }; - return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x != val); }); + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); }); } case OpType::GreaterEqual: { auto index_func = [val](Index* index) { return index->Range(val, Operator::GE); }; - return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x >= val); }); + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x >= val); }); } case OpType::GreaterThan: { auto index_func = [val](Index* index) { return index->Range(val, Operator::GT); }; - return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x > val); }); + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x > val); }); } case OpType::LessEqual: { auto index_func = [val](Index* index) { return index->Range(val, Operator::LE); }; - return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x <= val); }); + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x <= val); }); } case OpType::LessThan: { auto index_func = [val](Index* index) { return index->Range(val, Operator::LT); }; - return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x < val); }); + return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x < val); }); } default: { PanicInfo("unsupported range node"); @@ -166,16 +167,16 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType { if (false) { } else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessThan)) { auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); }; - return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x < val2); }); + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x < val2); }); } else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessEqual)) { auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); }; - return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x <= val2); }); + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x <= val2); }); } else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessThan)) { auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); }; - return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x < val2); }); + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x < val2); }); } else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessEqual)) { auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); }; - return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x <= val2); }); + return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x <= val2); }); } else { PanicInfo("unsupported range node"); } @@ -225,79 +226,4 @@ ExecExprVisitor::visit(RangeExpr& expr) { ret_ = std::move(ret); } -template -auto -ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType { - auto& expr = static_cast&>(expr_raw); - auto& records = segment_.get_insert_record(); - auto data_type = expr.data_type_; - auto& schema = segment_.get_schema(); - auto field_offset_opt = schema.get_offset(expr.field_id_); - Assert(field_offset_opt); - auto field_offset = field_offset_opt.value(); - auto& field_meta = schema[field_offset]; - auto vec_ptr = records.get_entity(field_offset); - auto& vec = *vec_ptr; - auto num_chunk = vec.num_chunk(); - RetType bitsets; - - auto N = records.ack_responder_.GetAck(); - - // small batch - for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) { - auto& chunk = vec.get_chunk(chunk_id); - - auto size = chunk_id == num_chunk - 1 ? N - chunk_id * segcore::DefaultElementPerChunk - : segcore::DefaultElementPerChunk; - - boost::dynamic_bitset<> bitset(segcore::DefaultElementPerChunk); - for (int i = 0; i < size; ++i) { - auto value = chunk[i]; - bool is_in = std::binary_search(expr.terms_.begin(), expr.terms_.end(), value); - bitset[i] = is_in; - } - bitsets.emplace_back(std::move(bitset)); - } - return bitsets; -} - -void -ExecExprVisitor::visit(TermExpr& expr) { - auto& field_meta = segment_.get_schema()[expr.field_id_]; - Assert(expr.data_type_ == field_meta.get_data_type()); - RetType ret; - switch (expr.data_type_) { - case DataType::BOOL: { - ret = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT8: { - ret = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT16: { - ret = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT32: { - ret = ExecTermVisitorImpl(expr); - break; - } - case DataType::INT64: { - ret = ExecTermVisitorImpl(expr); - break; - } - case DataType::FLOAT: { - ret = ExecTermVisitorImpl(expr); - break; - } - case DataType::DOUBLE: { - ret = ExecTermVisitorImpl(expr); - break; - } - default: - PanicInfo("unsupported"); - } - ret_ = std::move(ret); -} } // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index c7c0f9c65d..c00c7cd4f9 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -26,7 +26,7 @@ namespace impl { // WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ class ExecPlanNodeVisitor : PlanNodeVisitor { public: - using RetType = QueryResult; + using RetType = segcore::QueryResult; ExecPlanNodeVisitor(segcore::SegmentBase& segment, Timestamp timestamp, const PlaceholderGroup& placeholder_group) : segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) { } @@ -75,22 +75,7 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) { void ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) { - // TODO: optimize here, remove the dynamic cast - assert(!ret_.has_value()); - auto segment = dynamic_cast(&segment_); - AssertInfo(segment, "support SegmentSmallIndex Only"); - RetType ret; - auto& ph = placeholder_group_.at(0); - auto src_data = ph.get_blob(); - auto num_queries = ph.num_of_queries_; - if (node.predicate_.has_value()) { - auto bitmap = ExecExprVisitor(*segment).call_child(*node.predicate_.value()); - auto ptr = &bitmap; - BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, ptr, ret); - } else { - BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret); - } - ret_ = ret; + // TODO } } // namespace milvus::query diff --git a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp index 19dea7b40e..85216c5260 100644 --- a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp @@ -73,24 +73,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) { void ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) { - assert(!ret_); - auto& info = node.query_info_; - Json json_body{ - {"node_type", "BinaryVectorANNS"}, // - {"metric_type", info.metric_type_}, // - {"field_id_", info.field_id_}, // - {"topK", info.topK_}, // - {"search_params", info.search_params_}, // - {"placeholder_tag", node.placeholder_tag_}, // - }; - if (node.predicate_.has_value()) { - ShowExprVisitor expr_show; - Assert(node.predicate_.value()); - json_body["predicate"] = expr_show.call_child(node.predicate_->operator*()); - } else { - json_body["predicate"] = "None"; - } - ret_ = json_body; + // TODO } } // namespace milvus::query diff --git a/internal/core/src/segcore/Collection.cpp b/internal/core/src/segcore/Collection.cpp index a456877e31..2edf4a0d66 100644 --- a/internal/core/src/segcore/Collection.cpp +++ b/internal/core/src/segcore/Collection.cpp @@ -123,10 +123,9 @@ Collection::CreateIndex(std::string& index_config) { void Collection::parse() { if (collection_proto_.empty()) { - // TODO: remove hard code use unittests are ready std::cout << "WARN: Use default schema" << std::endl; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); schema_ = schema; return; diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index c168f811cf..74ece55121 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -196,7 +196,7 @@ class ConcurrentVectorImpl : public VectorBase { } ssize_t - num_chunk() const { + chunk_size() const { return chunks_.size(); } @@ -226,14 +226,8 @@ class ConcurrentVector : public ConcurrentVectorImpl { using ConcurrentVectorImpl::ConcurrentVectorImpl; }; -class VectorTrait {}; - -class FloatVector : public VectorTrait { - using embedded_type = float; -}; -class BinaryVector : public VectorTrait { - using embedded_type = uint8_t; -}; +class FloatVector {}; +class BinaryVector {}; template <> class ConcurrentVector : public ConcurrentVectorImpl { diff --git a/internal/core/src/segcore/IndexingEntry.cpp b/internal/core/src/segcore/IndexingEntry.cpp index aafa54da36..4930bf7f98 100644 --- a/internal/core/src/segcore/IndexingEntry.cpp +++ b/internal/core/src/segcore/IndexingEntry.cpp @@ -24,7 +24,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector auto source = dynamic_cast*>(vec_base); Assert(source); - auto chunk_size = source->num_chunk(); + auto chunk_size = source->chunk_size(); assert(ack_end <= chunk_size); auto conf = get_build_conf(); data_.grow_to_at_least(ack_end); @@ -85,9 +85,11 @@ IndexingRecord::UpdateResourceAck(int64_t chunk_ack, const InsertRecord& record) template void ScalarIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) { + auto dim = field_meta_.get_dim(); + auto source = dynamic_cast*>(vec_base); Assert(source); - auto chunk_size = source->num_chunk(); + auto chunk_size = source->chunk_size(); assert(ack_end <= chunk_size); data_.grow_to_at_least(ack_end); for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) { diff --git a/internal/core/src/segcore/SegmentBase.h b/internal/core/src/segcore/SegmentBase.h index c1ddbf94a1..6fa0950e2e 100644 --- a/internal/core/src/segcore/SegmentBase.h +++ b/internal/core/src/segcore/SegmentBase.h @@ -24,7 +24,7 @@ namespace milvus { namespace segcore { // using engine::DataChunk; // using engine::DataChunkPtr; -using QueryResult = milvus::QueryResult; +using engine::QueryResult; struct RowBasedRawData { void* raw_data; // schema int sizeof_per_row; // alignment diff --git a/internal/core/src/segcore/SegmentNaive.cpp b/internal/core/src/segcore/SegmentNaive.cpp index e08f9a290e..1091c6df25 100644 --- a/internal/core/src/segcore/SegmentNaive.cpp +++ b/internal/core/src/segcore/SegmentNaive.cpp @@ -467,16 +467,16 @@ SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) { auto dim = field.get_dim(); auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode); - auto chunk_size = record_.uids_.num_chunk(); + auto chunk_size = record_.uids_.chunk_size(); auto& uids = record_.uids_; auto entities = record_.get_entity(offset); std::vector datasets; - for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) { + for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) { auto entities_chunk = entities->get_chunk(chunk_id).data(); - int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk - : DefaultElementPerChunk; + int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk + : DefaultElementPerChunk; datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk)); } for (auto& ds : datasets) { diff --git a/internal/core/src/segcore/SegmentSmallIndex.cpp b/internal/core/src/segcore/SegmentSmallIndex.cpp index 565182983a..cae5160ff8 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.cpp +++ b/internal/core/src/segcore/SegmentSmallIndex.cpp @@ -241,10 +241,10 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) { auto entities = record_.get_entity(offset); std::vector datasets; - for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) { + for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) { auto entities_chunk = entities->get_chunk(chunk_id).data(); - int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk - : DefaultElementPerChunk; + int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk + : DefaultElementPerChunk; datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk)); } for (auto& ds : datasets) { diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index e151f3b68b..8c6786c4f0 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -42,7 +42,7 @@ DeleteSegment(CSegmentBase segment) { void DeleteQueryResult(CQueryResult query_result) { - auto res = (milvus::QueryResult*)query_result; + auto res = (milvus::segcore::QueryResult*)query_result; delete res; } @@ -134,7 +134,7 @@ Search(CSegmentBase c_segment, placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]); } - auto query_result = std::make_unique(); + auto query_result = std::make_unique(); auto status = CStatus(); try { diff --git a/internal/core/src/utils/EasyAssert.cpp b/internal/core/src/utils/EasyAssert.cpp index c6b5ebf05b..0ee5a86246 100644 --- a/internal/core/src/utils/EasyAssert.cpp +++ b/internal/core/src/utils/EasyAssert.cpp @@ -42,11 +42,8 @@ EasyAssertInfo( [[noreturn]] void ThrowWithTrace(const std::exception& exception) { - if (typeid(exception) == typeid(WrappedRuntimError)) { - throw exception; - } auto err_msg = exception.what() + std::string("\n") + EasyStackTrace(); - throw WrappedRuntimError(err_msg); + throw std::runtime_error(err_msg); } } // namespace milvus::impl diff --git a/internal/core/src/utils/EasyAssert.h b/internal/core/src/utils/EasyAssert.h index 374636102f..21891791f5 100644 --- a/internal/core/src/utils/EasyAssert.h +++ b/internal/core/src/utils/EasyAssert.h @@ -11,7 +11,6 @@ #pragma once #include -#include #include #include #include @@ -23,10 +22,6 @@ void EasyAssertInfo( bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info); -class WrappedRuntimError : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - [[noreturn]] void ThrowWithTrace(const std::exception& exception); diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 29fef50fb7..ebf2ef375f 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -26,5 +26,4 @@ target_link_libraries(all_tests pthread milvus_utils ) - install (TARGETS all_tests DESTINATION unittest) diff --git a/internal/core/unittest/test_binary.cpp b/internal/core/unittest/test_binary.cpp index 3e1d81ea26..e5adf3a269 100644 --- a/internal/core/unittest/test_binary.cpp +++ b/internal/core/unittest/test_binary.cpp @@ -21,7 +21,7 @@ TEST(Binary, Insert) { int64_t num_queries = 10; int64_t topK = 5; auto schema = std::make_shared(); - schema->AddField("vecbin", DataType::VECTOR_BINARY, 128, MetricType::METRIC_Jaccard); + schema->AddField("vecbin", DataType::VECTOR_BINARY, 128); schema->AddField("age", DataType::INT64); auto dataset = DataGen(schema, N, 10); auto segment = CreateSegment(schema); diff --git a/internal/core/unittest/test_common.cpp b/internal/core/unittest/test_common.cpp deleted file mode 100644 index 2be5c0bee0..0000000000 --- a/internal/core/unittest/test_common.cpp +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include diff --git a/internal/core/unittest/test_concurrent_vector.cpp b/internal/core/unittest/test_concurrent_vector.cpp index fe6687f890..ce0acfb11c 100644 --- a/internal/core/unittest/test_concurrent_vector.cpp +++ b/internal/core/unittest/test_concurrent_vector.cpp @@ -52,7 +52,7 @@ TEST(ConcurrentVector, TestSingle) { c_vec.set_data(total_count, vec.data(), insert_size); total_count += insert_size; } - ASSERT_EQ(c_vec.num_chunk(), (total_count + 31) / 32); + ASSERT_EQ(c_vec.chunk_size(), (total_count + 31) / 32); for (int i = 0; i < total_count; ++i) { for (int d = 0; d < dim; ++d) { auto std_data = d + i * dim; diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 2d968f09cf..c45909faf3 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -98,49 +98,7 @@ TEST(Expr, Range) { } })"; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); - schema->AddField("age", DataType::INT32); - auto plan = CreatePlan(*schema, dsl_string); - ShowPlanNodeVisitor shower; - Assert(plan->tag2field_.at("$0") == "fakevec"); - auto out = shower.call_child(*plan->plan_node_); - std::cout << out.dump(4); -} - -TEST(Expr, RangeBinary) { - SUCCEED(); - using namespace milvus; - using namespace milvus::query; - using namespace milvus::segcore; - std::string dsl_string = R"( -{ - "bool": { - "must": [ - { - "range": { - "age": { - "GT": 1, - "LT": 100 - } - } - }, - { - "vector": { - "fakevec": { - "metric_type": "Jaccard", - "params": { - "nprobe": 10 - }, - "query": "$0", - "topk": 10 - } - } - } - ] - } -})"; - auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); auto plan = CreatePlan(*schema, dsl_string); ShowPlanNodeVisitor shower; @@ -182,7 +140,7 @@ TEST(Expr, InvalidRange) { } })"; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string)); } @@ -221,7 +179,7 @@ TEST(Expr, InvalidDSL) { })"; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string)); } @@ -231,7 +189,7 @@ TEST(Expr, ShowExecutor) { using namespace milvus::segcore; auto node = std::make_unique(); auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); int64_t num_queries = 100L; auto raw_data = DataGen(schema, num_queries); auto& info = node->query_info_; @@ -290,7 +248,7 @@ TEST(Expr, TestRange) { } })"; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); auto seg = CreateSegment(schema); @@ -321,88 +279,7 @@ TEST(Expr, TestRange) { auto ans = final[vec_id][offset]; auto val = age_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } - } -} - -TEST(Expr, TestTerm) { - using namespace milvus::query; - using namespace milvus::segcore; - auto vec_2k_3k = [] { - std::string buf = "["; - for (int i = 2000; i < 3000 - 1; ++i) { - buf += std::to_string(i) + ", "; - } - buf += std::to_string(2999) + "]"; - return buf; - }(); - - std::vector>> testcases = { - {R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }}, - {R"([2000])", [](int v) { return v == 2000; }}, - {R"([3000])", [](int v) { return v == 3000; }}, - {vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }}, - }; - - std::string dsl_string_tmp = R"( -{ - "bool": { - "must": [ - { - "term": { - "age": @@@@ - } - }, - { - "vector": { - "fakevec": { - "metric_type": "L2", - "params": { - "nprobe": 10 - }, - "query": "$0", - "topk": 10 - } - } - } - ] - } -})"; - auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); - schema->AddField("age", DataType::INT32); - - auto seg = CreateSegment(schema); - int N = 10000; - std::vector age_col; - int num_iters = 100; - for (int iter = 0; iter < num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); - auto new_age_col = raw_data.get_col(1); - age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end()); - seg->PreInsert(N); - seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_); - } - - auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor(*seg_promote); - for (auto [clause, ref_func] : testcases) { - auto loc = dsl_string_tmp.find("@@@@"); - auto dsl_string = dsl_string_tmp; - dsl_string.replace(loc, 4, clause); - auto plan = CreatePlan(*schema, dsl_string); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); - EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk)); - - for (int i = 0; i < N * num_iters; ++i) { - auto vec_id = i / DefaultElementPerChunk; - auto offset = i % DefaultElementPerChunk; - auto ans = final[vec_id][offset]; - - auto val = age_col[i]; - auto ref = ref_func(val); + auto ref = !ref_func(val); ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; } } diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index e663a69e69..266549d01c 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -235,14 +235,14 @@ TEST(Indexing, IVFFlatNM) { } } -TEST(Indexing, BinaryBruteForce) { +TEST(Indexing, DISABLED_BinaryBruteForce) { int64_t N = 100000; int64_t num_queries = 10; int64_t topk = 5; - int64_t dim = 512; + int64_t dim = 64; auto result_count = topk * num_queries; auto schema = std::make_shared(); - schema->AddField("vecbin", DataType::VECTOR_BINARY, dim, MetricType::METRIC_Jaccard); + schema->AddField("vecbin", DataType::VECTOR_BINARY, dim); schema->AddField("age", DataType::INT64); auto dataset = DataGen(schema, N, 10); vector distances(result_count); @@ -250,16 +250,8 @@ TEST(Indexing, BinaryBruteForce) { auto bin_vec = dataset.get_col(0); auto line_sizeof = schema->operator[](0).get_sizeof(); auto query_data = 1024 * line_sizeof + bin_vec.data(); - query::dataset::BinaryQueryDataset query_dataset{ - faiss::MetricType::METRIC_Jaccard, // - num_queries, // - topk, // - line_sizeof, // - query_data // - }; - - query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N, distances.data(), ids.data()); - + query::BinarySearchBruteForce(faiss::MetricType::METRIC_Jaccard, line_sizeof, bin_vec.data(), N, topk, num_queries, + query_data, distances.data(), ids.data()); QueryResult qr; qr.num_queries_ = num_queries; qr.topK_ = topk; @@ -272,78 +264,76 @@ TEST(Indexing, BinaryBruteForce) { [ [ "1024->0.000000", - "43190->0.578804", - "5255->0.586207", - "23247->0.586486", - "4936->0.588889" + "86966->0.395349", + "24843->0.404762", + "13806->0.416667", + "44313->0.421053" ], [ "1025->0.000000", - "15147->0.562162", - "49910->0.564304", - "67435->0.567867", - "38292->0.569921" + "14226->0.348837", + "1488->0.365854", + "47337->0.377778", + "20913->0.377778" ], [ "1026->0.000000", - "15332->0.569061", - "56391->0.572559", - "17187->0.572603", - "26988->0.573771" + "81882->0.386364", + "9215->0.409091", + "95024->0.409091", + "54987->0.414634" ], [ "1027->0.000000", - "4502->0.559585", - "25879->0.566234", - "66937->0.566489", - "21228->0.566845" + "68981->0.394737", + "75528->0.404762", + "68794->0.405405", + "21975->0.425000" ], [ "1028->0.000000", - "38490->0.578804", - "12946->0.581717", - "31677->0.582173", - "94474->0.583569" + "90290->0.375000", + "34309->0.394737", + "58559->0.400000", + "33865->0.400000" ], [ "1029->0.000000", - "59011->0.551630", - "82575->0.555263", - "42914->0.561828", - "23705->0.564171" + "62722->0.388889", + "89070->0.394737", + "18528->0.414634", + "94971->0.421053" ], [ "1030->0.000000", - "39782->0.579946", - "65553->0.589947", - "82154->0.590028", - "13374->0.590164" + "67402->0.333333", + "3988->0.347826", + "86376->0.354167", + "84381->0.361702" ], [ "1031->0.000000", - "47826->0.582873", - "72669->0.587432", - "334->0.588076", - "80652->0.589333" + "81569->0.325581", + "12715->0.347826", + "40332->0.363636", + "21037->0.372093" ], [ "1032->0.000000", - "31968->0.573034", - "63545->0.575758", - "76913->0.575916", - "6286->0.576000" + "60536->0.428571", + "93293->0.432432", + "70969->0.435897", + "64048->0.450000" ], [ "1033->0.000000", - "95635->0.570248", - "93439->0.574866", - "6709->0.578534", - "6367->0.579634" + "99022->0.394737", + "11763->0.405405", + "50073->0.428571", + "97118->0.428571" ] ] ] )"); - auto json_str = json.dump(2); - auto ref_str = ref.dump(2); - ASSERT_EQ(json_str, ref_str); + ASSERT_EQ(json, ref); } diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 5bb3a75a4d..f25f3ddf62 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -72,7 +72,7 @@ TEST(Query, ShowExecutor) { using namespace milvus; auto node = std::make_unique(); auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); int64_t num_queries = 100L; auto raw_data = DataGen(schema, num_queries); auto& info = node->query_info_; @@ -98,7 +98,7 @@ TEST(Query, DSL) { "must": [ { "vector": { - "fakevec": { + "Vec": { "metric_type": "L2", "params": { "nprobe": 10 @@ -113,7 +113,7 @@ TEST(Query, DSL) { })"; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); auto plan = CreatePlan(*schema, dsl_string); auto res = shower.call_child(*plan->plan_node_); @@ -123,7 +123,7 @@ TEST(Query, DSL) { { "bool": { "vector": { - "fakevec": { + "Vec": { "metric_type": "L2", "params": { "nprobe": 10 @@ -159,7 +159,7 @@ TEST(Query, ParsePlaceholderGroup) { })"; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); auto plan = CreatePlan(*schema, dsl_string); int64_t num_queries = 100000; int dim = 16; @@ -172,7 +172,7 @@ TEST(Query, ExecWithPredicate) { using namespace milvus::query; using namespace milvus::segcore; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::FLOAT); std::string dsl = R"({ "bool": { @@ -217,8 +217,8 @@ TEST(Query, ExecWithPredicate) { int topk = 5; Json json = QueryResultToJson(qr); - auto ref = Json::parse(R"( -[ + + auto ref = Json::parse(R"([ [ [ "980486->3.149221", @@ -257,14 +257,15 @@ TEST(Query, ExecWithPredicate) { ] ] ])"); - ASSERT_EQ(json.dump(2), ref.dump(2)); + + ASSERT_EQ(json, ref); } -TEST(Query, ExecWithoutPredicate) { +TEST(Query, ExecWihtoutPredicate) { using namespace milvus::query; using namespace milvus::segcore; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::FLOAT); std::string dsl = R"({ "bool": { @@ -300,49 +301,18 @@ TEST(Query, ExecWithoutPredicate) { segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); std::vector> results; int topk = 5; - auto json = QueryResultToJson(qr); - auto ref = Json::parse(R"( -[ - [ - [ - "980486->3.149221", - "318367->3.661235", - "302798->4.553688", - "321424->4.757450", - "565529->5.083780" - ], - [ - "233390->7.931535", - "238958->8.109344", - "230645->8.439169", - "901939->8.658772", - "380328->8.731251" - ], - [ - "749862->3.398494", - "701321->3.632437", - "897246->3.749835", - "750683->3.897577", - "105995->4.073595" - ], - [ - "138274->3.454446", - "124548->3.783290", - "840855->4.782170", - "936719->5.026924", - "709627->5.063170" - ], - [ - "810401->3.926393", - "46575->4.054171", - "201740->4.274491", - "669040->4.399628", - "231500->4.831223" - ] - ] -] -)"); - ASSERT_EQ(json.dump(2), ref.dump(2)); + for (int q = 0; q < num_queries; ++q) { + std::vector result; + for (int k = 0; k < topk; ++k) { + int index = q * topk + k; + result.emplace_back(std::to_string(qr.result_ids_[index]) + "->" + + std::to_string(qr.result_distances_[index])); + } + results.emplace_back(std::move(result)); + } + + Json json{results}; + std::cout << json.dump(2); } TEST(Query, FillSegment) { @@ -361,9 +331,6 @@ TEST(Query, FillSegment) { auto param = field->add_type_params(); param->set_key("dim"); param->set_value("16"); - auto iparam = field->add_index_params(); - iparam->set_key("metric_type"); - iparam->set_value("L2"); } { @@ -425,57 +392,3 @@ TEST(Query, FillSegment) { ++std_index; } } - -TEST(Query, ExecWithPredicateBinary) { - using namespace milvus::query; - using namespace milvus::segcore; - auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard); - schema->AddField("age", DataType::FLOAT); - std::string dsl = R"({ - "bool": { - "must": [ - { - "range": { - "age": { - "GE": -1, - "LT": 1 - } - } - }, - { - "vector": { - "fakevec": { - "metric_type": "Jaccard", - "params": { - "nprobe": 10 - }, - "query": "$0", - "topk": 5 - } - } - } - ] - } - })"; - int64_t N = 1000 * 1000; - auto dataset = DataGen(schema, N); - auto segment = std::make_unique(schema); - segment->PreInsert(N); - segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); - auto vec_ptr = dataset.get_col(0); - - auto plan = CreatePlan(*schema, dsl); - auto num_queries = 5; - auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(num_queries, 512, vec_ptr.data() + 1024 * 512 / 8); - auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - QueryResult qr; - Timestamp time = 1000000; - std::vector ph_group_arr = {ph_group.get()}; - segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); - int topk = 5; - - Json json = QueryResultToJson(qr); - std::cout << json.dump(2); - // ASSERT_EQ(json.dump(2), ref.dump(2)); -} diff --git a/internal/core/unittest/test_segcore.cpp b/internal/core/unittest/test_segcore.cpp index 1e68c43c1a..dfe744a71f 100644 --- a/internal/core/unittest/test_segcore.cpp +++ b/internal/core/unittest/test_segcore.cpp @@ -63,7 +63,7 @@ TEST(SegmentCoreTest, NormalDistributionTest) { using namespace milvus::segcore; using namespace milvus::engine; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); int N = 1000 * 1000; auto [raw_data, timestamps, uids] = generate_data(N); @@ -76,7 +76,7 @@ TEST(SegmentCoreTest, MockTest) { using namespace milvus::segcore; using namespace milvus::engine; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); std::vector raw_data; std::vector timestamps; @@ -116,7 +116,7 @@ TEST(SegmentCoreTest, SmallIndex) { using namespace milvus::segcore; using namespace milvus::engine; auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); int N = 1024 * 1024; auto data = DataGen(schema, N); diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 7725eb5ebd..4ce51c3612 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -31,14 +31,6 @@ struct GeneratedData { memcpy(ret.data(), target.data(), target.size()); return ret; } - template - auto - get_mutable_col(int index) { - auto& target = cols_.at(index); - assert(target.size() == row_ids_.size() * sizeof(T)); - auto ptr = reinterpret_cast(target.data()); - return ptr; - } private: GeneratedData() = default; @@ -66,9 +58,6 @@ GeneratedData::generate_rows(int N, SchemaPtr schema) { } } rows_ = std::move(result); - raw_.raw_data = rows_.data(); - raw_.sizeof_per_row = schema->get_total_sizeof(); - raw_.count = N; } inline GeneratedData @@ -140,12 +129,14 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) { } GeneratedData res; res.cols_ = std::move(cols); + res.generate_rows(N, schema); for (int i = 0; i < N; ++i) { res.row_ids_.push_back(i); res.timestamps_.push_back(i); } - - res.generate_rows(N, schema); + res.raw_.raw_data = res.rows_.data(); + res.raw_.sizeof_per_row = schema->get_total_sizeof(); + res.raw_.count = N; return std::move(res); } @@ -176,7 +167,7 @@ CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42 ser::PlaceholderGroup raw_group; auto value = raw_group.add_placeholders(); value->set_tag("$0"); - value->set_type(ser::PlaceholderType::VECTOR_BINARY); + value->set_type(ser::PlaceholderType::VECTOR_FLOAT); std::default_random_engine e(seed); for (int i = 0; i < num_queries; ++i) { std::vector vec; @@ -184,27 +175,7 @@ CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42 vec.push_back(e()); } // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); - value->add_values(vec.data(), vec.size()); - } - return raw_group; -} - -inline auto -CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries, int64_t dim, const uint8_t* ptr) { - assert(dim % 8 == 0); - namespace ser = milvus::proto::service; - ser::PlaceholderGroup raw_group; - auto value = raw_group.add_placeholders(); - value->set_tag("$0"); - value->set_type(ser::PlaceholderType::VECTOR_BINARY); - for (int i = 0; i < num_queries; ++i) { - std::vector vec; - for (int d = 0; d < dim / 8; ++d) { - vec.push_back(*ptr); - ++ptr; - } - // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); - value->add_values(vec.data(), vec.size()); + value->add_values(vec.data(), vec.size() * sizeof(float)); } return raw_group; } diff --git a/internal/kv/mockkv/mock_etcd.go b/internal/kv/mockkv/mock_etcd.go new file mode 100644 index 0000000000..932a7bdf2b --- /dev/null +++ b/internal/kv/mockkv/mock_etcd.go @@ -0,0 +1,14 @@ +package mockkv + +import ( + memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem" +) + +// use MemoryKV to mock EtcdKV +func NewEtcdKV() *memkv.MemoryKV { + return memkv.NewMemoryKV() +} + +func NewMemoryKV() *memkv.MemoryKV { + return memkv.NewMemoryKV() +} diff --git a/internal/master/collection_task_test.go b/internal/master/collection_task_test.go index 772a65ca88..08d9f95d79 100644 --- a/internal/master/collection_task_test.go +++ b/internal/master/collection_task_test.go @@ -19,7 +19,6 @@ import ( func TestMaster_CollectionTask(t *testing.T) { Init() - refreshMasterAddress() ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -65,10 +64,10 @@ func TestMaster_CollectionTask(t *testing.T) { svr, err := CreateServer(ctx) assert.Nil(t, err) - err = svr.Run(int64(Params.Port)) + err = svr.Run(10002) assert.Nil(t, err) - conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock()) assert.Nil(t, err) defer conn.Close() diff --git a/internal/master/config_task_test.go b/internal/master/config_task_test.go index 5de544c156..2cb789aa04 100644 --- a/internal/master/config_task_test.go +++ b/internal/master/config_task_test.go @@ -16,7 +16,6 @@ import ( func TestMaster_ConfigTask(t *testing.T) { Init() - refreshMasterAddress() ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -60,11 +59,11 @@ func TestMaster_ConfigTask(t *testing.T) { svr, err := CreateServer(ctx) require.Nil(t, err) - err = svr.Run(int64(Params.Port)) + err = svr.Run(10002) defer svr.Close() require.Nil(t, err) - conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock()) require.Nil(t, err) defer conn.Close() diff --git a/internal/master/global_allocator_test.go b/internal/master/global_allocator_test.go index 35abae57ec..f33b934dc2 100644 --- a/internal/master/global_allocator_test.go +++ b/internal/master/global_allocator_test.go @@ -1,6 +1,7 @@ package master import ( + "os" "testing" "time" @@ -11,6 +12,16 @@ import ( var gTestTsoAllocator Allocator var gTestIDAllocator *GlobalIDAllocator +func TestMain(m *testing.M) { + Params.Init() + + etcdAddr := Params.EtcdAddress + gTestTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "tso")) + gTestIDAllocator = NewGlobalIDAllocator("idTimestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "gid")) + exitCode := m.Run() + os.Exit(exitCode) +} + func TestGlobalTSOAllocator_Initialize(t *testing.T) { err := gTestTsoAllocator.Initialize() assert.Nil(t, err) diff --git a/internal/master/grpc_service_test.go b/internal/master/grpc_service_test.go index ec7f2ce9fe..644a84999c 100644 --- a/internal/master/grpc_service_test.go +++ b/internal/master/grpc_service_test.go @@ -17,7 +17,6 @@ import ( func TestMaster_CreateCollection(t *testing.T) { Init() - refreshMasterAddress() ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -63,10 +62,10 @@ func TestMaster_CreateCollection(t *testing.T) { svr, err := CreateServer(ctx) assert.Nil(t, err) - err = svr.Run(int64(Params.Port)) + err = svr.Run(10001) assert.Nil(t, err) - conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "127.0.0.1:10001", grpc.WithInsecure(), grpc.WithBlock()) assert.Nil(t, err) defer conn.Close() diff --git a/internal/master/master_test.go b/internal/master/master_test.go index 5af957c527..c3cdbbdfbd 100644 --- a/internal/master/master_test.go +++ b/internal/master/master_test.go @@ -4,12 +4,9 @@ import ( "context" "log" "math/rand" - "os" "strconv" "testing" - "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" ms "github.com/zilliztech/milvus-distributed/internal/msgstream" @@ -23,42 +20,6 @@ import ( "google.golang.org/grpc" ) -var testPORT = 53200 - -func genMasterTestPort() int64 { - testPORT++ - return int64(testPORT) -} - -func refreshMasterAddress() { - masterPort := genMasterTestPort() - Params.Port = int(masterPort) - masterAddr := makeMasterAddress(masterPort) - Params.Address = masterAddr -} - -func makeMasterAddress(port int64) string { - masterAddr := "127.0.0.1:" + strconv.FormatInt(port, 10) - return masterAddr -} - -func makeNewChannalNames(names []string, suffix string) []string { - var ret []string - for _, name := range names { - ret = append(ret, name+suffix) - } - return ret -} - -func refreshChannelNames() { - suffix := "_test" + strconv.FormatInt(rand.Int63n(100), 10) - Params.DDChannelNames = makeNewChannalNames(Params.DDChannelNames, suffix) - Params.WriteNodeTimeTickChannelNames = makeNewChannalNames(Params.WriteNodeTimeTickChannelNames, suffix) - Params.InsertChannelNames = makeNewChannalNames(Params.InsertChannelNames, suffix) - Params.K2SChannelNames = makeNewChannalNames(Params.K2SChannelNames, suffix) - Params.ProxyTimeTickChannelNames = makeNewChannalNames(Params.ProxyTimeTickChannelNames, suffix) -} - func receiveTimeTickMsg(stream *ms.MsgStream) bool { for { result := (*stream).Consume() @@ -76,25 +37,13 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack { return &msgPack } -func TestMain(m *testing.M) { - Init() - refreshMasterAddress() - refreshChannelNames() - etcdAddr := Params.EtcdAddress - gTestTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "tso")) - gTestIDAllocator = NewGlobalIDAllocator("idTimestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "gid")) - exitCode := m.Run() - os.Exit(exitCode) -} - func TestMaster(t *testing.T) { Init() - refreshMasterAddress() pulsarAddr := Params.PulsarAddress - Params.ProxyIDList = []UniqueID{0} - //Param + // Creates server. ctx, cancel := context.WithCancel(context.Background()) + svr, err := CreateServer(ctx) if err != nil { log.Print("create server failed", zap.Error(err)) @@ -149,7 +98,7 @@ func TestMaster(t *testing.T) { var k2sMsgstream ms.MsgStream = k2sMs assert.True(t, receiveTimeTickMsg(&k2sMsgstream)) - conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "127.0.0.1:53100", grpc.WithInsecure(), grpc.WithBlock()) assert.Nil(t, err) defer conn.Close() diff --git a/internal/master/param_table.go b/internal/master/param_table.go index 2fa2f59d17..970f7cbbe9 100644 --- a/internal/master/param_table.go +++ b/internal/master/param_table.go @@ -55,8 +55,19 @@ var Params ParamTable func (p *ParamTable) Init() { // load yaml p.BaseTable.Init() - - err := p.LoadYaml("advanced/master.yaml") + err := p.LoadYaml("milvus.yaml") + if err != nil { + panic(err) + } + err = p.LoadYaml("advanced/channel.yaml") + if err != nil { + panic(err) + } + err = p.LoadYaml("advanced/master.yaml") + if err != nil { + panic(err) + } + err = p.LoadYaml("advanced/common.yaml") if err != nil { panic(err) } @@ -104,7 +115,15 @@ func (p *ParamTable) initAddress() { } func (p *ParamTable) initPort() { - p.Port = p.ParseInt("master.port") + masterPort, err := p.Load("master.port") + if err != nil { + panic(err) + } + port, err := strconv.Atoi(masterPort) + if err != nil { + panic(err) + } + p.Port = port } func (p *ParamTable) initEtcdAddress() { @@ -148,40 +167,117 @@ func (p *ParamTable) initKvRootPath() { } func (p *ParamTable) initTopicNum() { - iRangeStr, err := p.Load("msgChannel.channelRange.insert") + insertChannelRange, err := p.Load("msgChannel.channelRange.insert") if err != nil { panic(err) } - rangeSlice := paramtable.ConvertRangeToIntRange(iRangeStr, ",") - p.TopicNum = rangeSlice[1] - rangeSlice[0] + + channelRange := strings.Split(insertChannelRange, ",") + if len(channelRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(channelRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(channelRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + p.TopicNum = channelEnd } func (p *ParamTable) initSegmentSize() { - p.SegmentSize = p.ParseFloat("master.segment.size") + threshold, err := p.Load("master.segment.size") + if err != nil { + panic(err) + } + segmentThreshold, err := strconv.ParseFloat(threshold, 64) + if err != nil { + panic(err) + } + p.SegmentSize = segmentThreshold } func (p *ParamTable) initSegmentSizeFactor() { - p.SegmentSizeFactor = p.ParseFloat("master.segment.sizeFactor") + segFactor, err := p.Load("master.segment.sizeFactor") + if err != nil { + panic(err) + } + factor, err := strconv.ParseFloat(segFactor, 64) + if err != nil { + panic(err) + } + p.SegmentSizeFactor = factor } func (p *ParamTable) initDefaultRecordSize() { - p.DefaultRecordSize = p.ParseInt64("master.segment.defaultSizePerRecord") + size, err := p.Load("master.segment.defaultSizePerRecord") + if err != nil { + panic(err) + } + res, err := strconv.ParseInt(size, 10, 64) + if err != nil { + panic(err) + } + p.DefaultRecordSize = res } func (p *ParamTable) initMinSegIDAssignCnt() { - p.MinSegIDAssignCnt = p.ParseInt64("master.segment.minIDAssignCnt") + size, err := p.Load("master.segment.minIDAssignCnt") + if err != nil { + panic(err) + } + res, err := strconv.ParseInt(size, 10, 64) + if err != nil { + panic(err) + } + p.MinSegIDAssignCnt = res } func (p *ParamTable) initMaxSegIDAssignCnt() { - p.MaxSegIDAssignCnt = p.ParseInt64("master.segment.maxIDAssignCnt") + size, err := p.Load("master.segment.maxIDAssignCnt") + if err != nil { + panic(err) + } + res, err := strconv.ParseInt(size, 10, 64) + if err != nil { + panic(err) + } + p.MaxSegIDAssignCnt = res } func (p *ParamTable) initSegIDAssignExpiration() { - p.SegIDAssignExpiration = p.ParseInt64("master.segment.IDAssignExpiration") + duration, err := p.Load("master.segment.IDAssignExpiration") + if err != nil { + panic(err) + } + res, err := strconv.ParseInt(duration, 10, 64) + if err != nil { + panic(err) + } + p.SegIDAssignExpiration = res } func (p *ParamTable) initQueryNodeNum() { - p.QueryNodeNum = len(p.QueryNodeIDList()) + id, err := p.Load("nodeID.queryNodeIDList") + if err != nil { + panic(err) + } + ids := strings.Split(id, ",") + for _, i := range ids { + _, err := strconv.ParseInt(i, 10, 64) + if err != nil { + log.Panicf("load proxy id list error, %s", err.Error()) + } + } + p.QueryNodeNum = len(ids) } func (p *ParamTable) initQueryNodeStatsChannelName() { @@ -193,7 +289,20 @@ func (p *ParamTable) initQueryNodeStatsChannelName() { } func (p *ParamTable) initProxyIDList() { - p.ProxyIDList = p.BaseTable.ProxyIDList() + id, err := p.Load("nodeID.proxyIDList") + if err != nil { + log.Panicf("load proxy id list error, %s", err.Error()) + } + ids := strings.Split(id, ",") + idList := make([]typeutil.UniqueID, 0, len(ids)) + for _, i := range ids { + v, err := strconv.ParseInt(i, 10, 64) + if err != nil { + log.Panicf("load proxy id list error, %s", err.Error()) + } + idList = append(idList, typeutil.UniqueID(v)) + } + p.ProxyIDList = idList } func (p *ParamTable) initProxyTimeTickChannelNames() { @@ -238,7 +347,20 @@ func (p *ParamTable) initSoftTimeTickBarrierInterval() { } func (p *ParamTable) initWriteNodeIDList() { - p.WriteNodeIDList = p.BaseTable.WriteNodeIDList() + id, err := p.Load("nodeID.writeNodeIDList") + if err != nil { + log.Panic(err) + } + ids := strings.Split(id, ",") + idlist := make([]typeutil.UniqueID, 0, len(ids)) + for _, i := range ids { + v, err := strconv.ParseInt(i, 10, 64) + if err != nil { + log.Panicf("load proxy id list error, %s", err.Error()) + } + idlist = append(idlist, typeutil.UniqueID(v)) + } + p.WriteNodeIDList = idlist } func (p *ParamTable) initWriteNodeTimeTickChannelNames() { @@ -263,57 +385,81 @@ func (p *ParamTable) initWriteNodeTimeTickChannelNames() { } func (p *ParamTable) initDDChannelNames() { - prefix, err := p.Load("msgChannel.chanNamePrefix.dataDefinition") + ch, err := p.Load("msgChannel.chanNamePrefix.dataDefinition") if err != nil { - panic(err) + log.Fatal(err) } - prefix += "-" - iRangeStr, err := p.Load("msgChannel.channelRange.dataDefinition") + id, err := p.Load("nodeID.queryNodeIDList") if err != nil { - panic(err) + log.Panicf("load query node id list error, %s", err.Error()) } - channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",") - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) + ids := strings.Split(id, ",") + channels := make([]string, 0, len(ids)) + for _, i := range ids { + _, err := strconv.ParseInt(i, 10, 64) + if err != nil { + log.Panicf("load query node id list error, %s", err.Error()) + } + channels = append(channels, ch+"-"+i) } - p.DDChannelNames = ret + p.DDChannelNames = channels } func (p *ParamTable) initInsertChannelNames() { - prefix, err := p.Load("msgChannel.chanNamePrefix.insert") + ch, err := p.Load("msgChannel.chanNamePrefix.insert") + if err != nil { + log.Fatal(err) + } + channelRange, err := p.Load("msgChannel.channelRange.insert") if err != nil { panic(err) } - prefix += "-" - iRangeStr, err := p.Load("msgChannel.channelRange.insert") + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",") - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) } - p.InsertChannelNames = ret + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + p.InsertChannelNames = channels } func (p *ParamTable) initK2SChannelNames() { - prefix, err := p.Load("msgChannel.chanNamePrefix.k2s") + ch, err := p.Load("msgChannel.chanNamePrefix.k2s") if err != nil { - panic(err) + log.Fatal(err) } - prefix += "-" - iRangeStr, err := p.Load("msgChannel.channelRange.k2s") + id, err := p.Load("nodeID.writeNodeIDList") if err != nil { - panic(err) + log.Panicf("load write node id list error, %s", err.Error()) } - channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",") - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) + ids := strings.Split(id, ",") + channels := make([]string, 0, len(ids)) + for _, i := range ids { + _, err := strconv.ParseInt(i, 10, 64) + if err != nil { + log.Panicf("load write node id list error, %s", err.Error()) + } + channels = append(channels, ch+"-"+i) } - p.K2SChannelNames = ret + p.K2SChannelNames = channels } func (p *ParamTable) initMaxPartitionNum() { diff --git a/internal/master/param_table_test.go b/internal/master/param_table_test.go index 8128c5f180..ed83c3af02 100644 --- a/internal/master/param_table_test.go +++ b/internal/master/param_table_test.go @@ -1,7 +1,6 @@ package master import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -12,111 +11,133 @@ func TestParamTable_Init(t *testing.T) { } func TestParamTable_Address(t *testing.T) { + Params.Init() address := Params.Address assert.Equal(t, address, "localhost") } func TestParamTable_Port(t *testing.T) { + Params.Init() port := Params.Port assert.Equal(t, port, 53100) } func TestParamTable_MetaRootPath(t *testing.T) { + Params.Init() path := Params.MetaRootPath assert.Equal(t, path, "by-dev/meta") } func TestParamTable_KVRootPath(t *testing.T) { + Params.Init() path := Params.KvRootPath assert.Equal(t, path, "by-dev/kv") } func TestParamTable_TopicNum(t *testing.T) { + Params.Init() num := Params.TopicNum - fmt.Println("TopicNum:", num) + assert.Equal(t, num, 1) } func TestParamTable_SegmentSize(t *testing.T) { + Params.Init() size := Params.SegmentSize assert.Equal(t, size, float64(512)) } func TestParamTable_SegmentSizeFactor(t *testing.T) { + Params.Init() factor := Params.SegmentSizeFactor assert.Equal(t, factor, 0.75) } func TestParamTable_DefaultRecordSize(t *testing.T) { + Params.Init() size := Params.DefaultRecordSize assert.Equal(t, size, int64(1024)) } func TestParamTable_MinSegIDAssignCnt(t *testing.T) { + Params.Init() cnt := Params.MinSegIDAssignCnt assert.Equal(t, cnt, int64(1024)) } func TestParamTable_MaxSegIDAssignCnt(t *testing.T) { + Params.Init() cnt := Params.MaxSegIDAssignCnt assert.Equal(t, cnt, int64(16384)) } func TestParamTable_SegIDAssignExpiration(t *testing.T) { + Params.Init() expiration := Params.SegIDAssignExpiration assert.Equal(t, expiration, int64(2000)) } func TestParamTable_QueryNodeNum(t *testing.T) { + Params.Init() num := Params.QueryNodeNum - fmt.Println("QueryNodeNum", num) + assert.Equal(t, num, 1) } func TestParamTable_QueryNodeStatsChannelName(t *testing.T) { + Params.Init() name := Params.QueryNodeStatsChannelName assert.Equal(t, name, "query-node-stats") } func TestParamTable_ProxyIDList(t *testing.T) { + Params.Init() ids := Params.ProxyIDList assert.Equal(t, len(ids), 1) assert.Equal(t, ids[0], int64(0)) } func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) { + Params.Init() names := Params.ProxyTimeTickChannelNames assert.Equal(t, len(names), 1) assert.Equal(t, names[0], "proxyTimeTick-0") } func TestParamTable_MsgChannelSubName(t *testing.T) { + Params.Init() name := Params.MsgChannelSubName assert.Equal(t, name, "master") } func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) { + Params.Init() interval := Params.SoftTimeTickBarrierInterval assert.Equal(t, interval, Timestamp(0x7d00000)) } func TestParamTable_WriteNodeIDList(t *testing.T) { + Params.Init() ids := Params.WriteNodeIDList assert.Equal(t, len(ids), 1) assert.Equal(t, ids[0], int64(3)) } func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) { + Params.Init() names := Params.WriteNodeTimeTickChannelNames assert.Equal(t, len(names), 1) assert.Equal(t, names[0], "writeNodeTimeTick-3") } func TestParamTable_InsertChannelNames(t *testing.T) { + Params.Init() names := Params.InsertChannelNames - assert.Equal(t, Params.TopicNum, len(names)) + assert.Equal(t, len(names), 1) + assert.Equal(t, names[0], "insert-0") } func TestParamTable_K2SChannelNames(t *testing.T) { + Params.Init() names := Params.K2SChannelNames assert.Equal(t, len(names), 1) - assert.Equal(t, names[0], "k2s-0") + assert.Equal(t, names[0], "k2s-3") } diff --git a/internal/master/partition_task_test.go b/internal/master/partition_task_test.go index e0eef54dff..c6642cad27 100644 --- a/internal/master/partition_task_test.go +++ b/internal/master/partition_task_test.go @@ -2,6 +2,8 @@ package master import ( "context" + "math/rand" + "strconv" "testing" "github.com/golang/protobuf/proto" @@ -18,7 +20,6 @@ import ( func TestMaster_Partition(t *testing.T) { Init() - refreshMasterAddress() ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -65,12 +66,13 @@ func TestMaster_Partition(t *testing.T) { DefaultPartitionTag: "_default", } + port := 10000 + rand.Intn(1000) svr, err := CreateServer(ctx) assert.Nil(t, err) - err = svr.Run(int64(Params.Port)) + err = svr.Run(int64(port)) assert.Nil(t, err) - conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "127.0.0.1:"+strconv.Itoa(port), grpc.WithInsecure(), grpc.WithBlock()) assert.Nil(t, err) defer conn.Close() diff --git a/internal/master/segment_manager_test.go b/internal/master/segment_manager_test.go index b54a1a7681..19209ed96e 100644 --- a/internal/master/segment_manager_test.go +++ b/internal/master/segment_manager_test.go @@ -35,7 +35,7 @@ var master *Master var masterCancelFunc context.CancelFunc func setup() { - Init() + Params.Init() etcdAddress := Params.EtcdAddress cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) @@ -218,8 +218,7 @@ func TestSegmentManager_SegmentStats(t *testing.T) { } func startupMaster() { - Init() - refreshMasterAddress() + Params.Init() etcdAddress := Params.EtcdAddress rootPath := "/test/root" ctx, cancel := context.WithCancel(context.TODO()) @@ -232,6 +231,7 @@ func startupMaster() { if err != nil { panic(err) } + Params = ParamTable{ Address: Params.Address, Port: Params.Port, @@ -272,7 +272,7 @@ func startupMaster() { if err != nil { panic(err) } - err = master.Run(int64(Params.Port)) + err = master.Run(10013) if err != nil { panic(err) @@ -289,7 +289,7 @@ func TestSegmentManager_RPC(t *testing.T) { defer shutdownMaster() ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - dialContext, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) + dialContext, err := grpc.DialContext(ctx, "127.0.0.1:10013", grpc.WithInsecure(), grpc.WithBlock()) assert.Nil(t, err) defer dialContext.Close() client := masterpb.NewMasterClient(dialContext) diff --git a/internal/proxy/paramtable.go b/internal/proxy/paramtable.go index 3283b3d916..b0dee63eee 100644 --- a/internal/proxy/paramtable.go +++ b/internal/proxy/paramtable.go @@ -19,8 +19,19 @@ var Params ParamTable func (pt *ParamTable) Init() { pt.BaseTable.Init() - - err := pt.LoadYaml("advanced/proxy.yaml") + err := pt.LoadYaml("milvus.yaml") + if err != nil { + panic(err) + } + err = pt.LoadYaml("advanced/proxy.yaml") + if err != nil { + panic(err) + } + err = pt.LoadYaml("advanced/channel.yaml") + if err != nil { + panic(err) + } + err = pt.LoadYaml("advanced/common.yaml") if err != nil { panic(err) } @@ -37,24 +48,15 @@ func (pt *ParamTable) Init() { pt.Save("_proxyID", proxyIDStr) } -func (pt *ParamTable) NetworkPort() int { - return pt.ParseInt("proxy.port") -} - -func (pt *ParamTable) NetworkAddress() string { - addr, err := pt.Load("proxy.address") +func (pt *ParamTable) NetWorkAddress() string { + addr, err := pt.Load("proxy.network.address") if err != nil { panic(err) } - - hostName, _ := net.LookupHost(addr) - if len(hostName) <= 0 { - if ip := net.ParseIP(addr); ip == nil { - panic("invalid ip proxy.address") - } + if ip := net.ParseIP(addr); ip == nil { + panic("invalid ip proxy.network.address") } - - port, err := pt.Load("proxy.port") + port, err := pt.Load("proxy.network.port") if err != nil { panic(err) } @@ -86,6 +88,23 @@ func (pt *ParamTable) ProxyNum() int { return len(ret) } +func (pt *ParamTable) ProxyIDList() []UniqueID { + proxyIDStr, err := pt.Load("nodeID.proxyIDList") + if err != nil { + panic(err) + } + var ret []UniqueID + proxyIDs := strings.Split(proxyIDStr, ",") + for _, i := range proxyIDs { + v, err := strconv.Atoi(i) + if err != nil { + log.Panicf("load proxy id list error, %s", err.Error()) + } + ret = append(ret, UniqueID(v)) + } + return ret +} + func (pt *ParamTable) queryNodeNum() int { return len(pt.queryNodeIDList()) } @@ -131,6 +150,25 @@ func (pt *ParamTable) TimeTickInterval() time.Duration { return time.Duration(interval) * time.Millisecond } +func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int { + channelIDs := strings.Split(rangeStr, sep) + startStr := channelIDs[0] + endStr := channelIDs[1] + start, err := strconv.Atoi(startStr) + if err != nil { + panic(err) + } + end, err := strconv.Atoi(endStr) + if err != nil { + panic(err) + } + var ret []int + for i := start; i < end; i++ { + ret = append(ret, i) + } + return ret +} + func (pt *ParamTable) sliceIndex() int { proxyID := pt.ProxyID() proxyIDList := pt.ProxyIDList() @@ -152,7 +190,7 @@ func (pt *ParamTable) InsertChannelNames() []string { if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",") + channelIDs := pt.convertRangeToSlice(iRangeStr, ",") var ret []string for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) @@ -178,12 +216,19 @@ func (pt *ParamTable) DeleteChannelNames() []string { if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(dRangeStr, ",") + channelIDs := pt.convertRangeToSlice(dRangeStr, ",") var ret []string for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) } - return ret + proxyNum := pt.ProxyNum() + sep := len(channelIDs) / proxyNum + index := pt.sliceIndex() + if index == -1 { + panic("ProxyID not Match with Config") + } + start := index * sep + return ret[start : start+sep] } func (pt *ParamTable) K2SChannelNames() []string { @@ -196,12 +241,19 @@ func (pt *ParamTable) K2SChannelNames() []string { if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(k2sRangeStr, ",") + channelIDs := pt.convertRangeToSlice(k2sRangeStr, ",") var ret []string for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) } - return ret + proxyNum := pt.ProxyNum() + sep := len(channelIDs) / proxyNum + index := pt.sliceIndex() + if index == -1 { + panic("ProxyID not Match with Config") + } + start := index * sep + return ret[start : start+sep] } func (pt *ParamTable) SearchChannelNames() []string { @@ -209,17 +261,8 @@ func (pt *ParamTable) SearchChannelNames() []string { if err != nil { panic(err) } - prefix += "-" - sRangeStr, err := pt.Load("msgChannel.channelRange.search") - if err != nil { - panic(err) - } - channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",") - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) - } - return ret + prefix += "-0" + return []string{prefix} } func (pt *ParamTable) SearchResultChannelNames() []string { @@ -232,7 +275,7 @@ func (pt *ParamTable) SearchResultChannelNames() []string { if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",") + channelIDs := pt.convertRangeToSlice(sRangeStr, ",") var ret []string for _, ID := range channelIDs { ret = append(ret, prefix+strconv.Itoa(ID)) @@ -278,24 +321,144 @@ func (pt *ParamTable) DataDefinitionChannelNames() []string { return []string{prefix} } +func (pt *ParamTable) parseInt64(key string) int64 { + valueStr, err := pt.Load(key) + if err != nil { + panic(err) + } + value, err := strconv.Atoi(valueStr) + if err != nil { + panic(err) + } + return int64(value) +} + func (pt *ParamTable) MsgStreamInsertBufSize() int64 { - return pt.ParseInt64("proxy.msgStream.insert.bufSize") + return pt.parseInt64("proxy.msgStream.insert.bufSize") } func (pt *ParamTable) MsgStreamSearchBufSize() int64 { - return pt.ParseInt64("proxy.msgStream.search.bufSize") + return pt.parseInt64("proxy.msgStream.search.bufSize") } func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 { - return pt.ParseInt64("proxy.msgStream.searchResult.recvBufSize") + return pt.parseInt64("proxy.msgStream.searchResult.recvBufSize") } func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 { - return pt.ParseInt64("proxy.msgStream.searchResult.pulsarBufSize") + return pt.parseInt64("proxy.msgStream.searchResult.pulsarBufSize") } func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 { - return pt.ParseInt64("proxy.msgStream.timeTick.bufSize") + return pt.parseInt64("proxy.msgStream.timeTick.bufSize") +} + +func (pt *ParamTable) insertChannelNames() []string { + ch, err := pt.Load("msgChannel.chanNamePrefix.insert") + if err != nil { + log.Fatal(err) + } + channelRange, err := pt.Load("msgChannel.channelRange.insert") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (pt *ParamTable) searchChannelNames() []string { + ch, err := pt.Load("msgChannel.chanNamePrefix.search") + if err != nil { + log.Fatal(err) + } + channelRange, err := pt.Load("msgChannel.channelRange.search") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels +} + +func (pt *ParamTable) searchResultChannelNames() []string { + ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult") + if err != nil { + log.Fatal(err) + } + channelRange, err := pt.Load("msgChannel.channelRange.searchResult") + if err != nil { + panic(err) + } + + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels } func (pt *ParamTable) MaxNameLength() int64 { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 02995dbffc..0f38241dad 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -5,7 +5,6 @@ import ( "log" "math/rand" "net" - "strconv" "sync" "time" @@ -60,7 +59,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) { p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize()) p.queryMsgStream.SetPulsarClient(pulsarAddress) - p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames()) + p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames()) masterAddr := Params.MasterAddress() idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr) @@ -84,7 +83,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) { p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize()) p.manipulationMsgStream.SetPulsarClient(pulsarAddress) - p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames()) + p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames()) repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false) } @@ -138,7 +137,7 @@ func (p *Proxy) AddCloseCallback(callbacks ...func()) { func (p *Proxy) grpcLoop() { defer p.proxyLoopWg.Done() - lis, err := net.Listen("tcp", ":"+strconv.Itoa(Params.NetworkPort())) + lis, err := net.Listen("tcp", Params.NetWorkAddress()) if err != nil { log.Fatalf("Proxy grpc server fatal error=%v", err) } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2c4bda85ce..4eb56f3086 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log" - "math/rand" "os" "strconv" "strings" @@ -35,26 +34,8 @@ var masterServer *master.Master var testNum = 10 -func makeNewChannalNames(names []string, suffix string) []string { - var ret []string - for _, name := range names { - ret = append(ret, name+suffix) - } - return ret -} - -func refreshChannelNames() { - suffix := "_test" + strconv.FormatInt(rand.Int63n(100), 10) - master.Params.DDChannelNames = makeNewChannalNames(master.Params.DDChannelNames, suffix) - master.Params.WriteNodeTimeTickChannelNames = makeNewChannalNames(master.Params.WriteNodeTimeTickChannelNames, suffix) - master.Params.InsertChannelNames = makeNewChannalNames(master.Params.InsertChannelNames, suffix) - master.Params.K2SChannelNames = makeNewChannalNames(master.Params.K2SChannelNames, suffix) - master.Params.ProxyTimeTickChannelNames = makeNewChannalNames(master.Params.ProxyTimeTickChannelNames, suffix) -} - func startMaster(ctx context.Context) { master.Init() - refreshChannelNames() etcdAddr := master.Params.EtcdAddress metaRootPath := master.Params.MetaRootPath @@ -100,7 +81,7 @@ func setup() { startMaster(ctx) startProxy(ctx) - proxyAddr := Params.NetworkAddress() + proxyAddr := Params.NetWorkAddress() addr := strings.Split(proxyAddr, ":") if addr[0] == "0.0.0.0" { proxyAddr = "127.0.0.1:" + addr[1] diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index b0c7debb29..b3a23272a8 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -364,7 +364,7 @@ func (sched *TaskScheduler) queryResultLoop() { unmarshal := msgstream.NewUnmarshalDispatcher() queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize()) queryResultMsgStream.SetPulsarClient(Params.PulsarAddress()) - queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(), + queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(), Params.ProxySubName(), unmarshal, Params.MsgStreamSearchResultPulsarBufSize()) diff --git a/internal/querynode/collection_replica.go b/internal/querynode/collection_replica.go index d4f39d628c..85e3370bd7 100644 --- a/internal/querynode/collection_replica.go +++ b/internal/querynode/collection_replica.go @@ -31,7 +31,7 @@ import ( * is up-to-date. */ type collectionReplica interface { - getTSafe() tSafe + getTSafe() *tSafe // collection getCollectionNum() int @@ -68,11 +68,11 @@ type collectionReplicaImpl struct { collections []*Collection segments map[UniqueID]*Segment - tSafe tSafe + tSafe *tSafe } //----------------------------------------------------------------------------------------------------- tSafe -func (colReplica *collectionReplicaImpl) getTSafe() tSafe { +func (colReplica *collectionReplicaImpl) getTSafe() *tSafe { return colReplica.tSafe } @@ -111,7 +111,6 @@ func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID) if col.ID() == collectionID { for _, p := range *col.Partitions() { for _, s := range *p.Segments() { - deleteSegment(colReplica.segments[s.ID()]) delete(colReplica.segments, s.ID()) } } @@ -203,7 +202,6 @@ func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID, for _, p := range *collection.Partitions() { if p.Tag() == partitionTag { for _, s := range *p.Segments() { - deleteSegment(colReplica.segments[s.ID()]) delete(colReplica.segments, s.ID()) } } else { diff --git a/internal/querynode/collection_replica_test.go b/internal/querynode/collection_replica_test.go index 083dd64845..31d26238cb 100644 --- a/internal/querynode/collection_replica_test.go +++ b/internal/querynode/collection_replica_test.go @@ -1,313 +1,1280 @@ package querynode import ( + "context" "testing" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) //----------------------------------------------------------------------------------------------------- collection func TestCollectionReplica_getCollectionNum(t *testing.T) { - node := newQueryNode() - initTestMeta(t, node, "collection0", 0, 0) - assert.Equal(t, node.replica.getCollectionNum(), 1) - node.Close() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + (*node.replica).freeAll() } func TestCollectionReplica_addCollection(t *testing.T) { - node := newQueryNode() - initTestMeta(t, node, "collection0", 0, 0) - node.Close() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + (*node.replica).freeAll() } func TestCollectionReplica_removeCollection(t *testing.T) { - node := newQueryNode() - initTestMeta(t, node, "collection0", 0, 0) - assert.Equal(t, node.replica.getCollectionNum(), 1) + ctx := context.Background() + node := NewQueryNode(ctx, 0) - err := node.replica.removeCollection(0) + collectionName := "collection0" + collectionID := UniqueID(0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) assert.NoError(t, err) - assert.Equal(t, node.replica.getCollectionNum(), 0) - node.Close() + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).removeCollection(collectionID) + assert.NoError(t, err) + assert.Equal(t, (*node.replica).getCollectionNum(), 0) + + (*node.replica).freeAll() } func TestCollectionReplica_getCollectionByID(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) - targetCollection, err := node.replica.getCollectionByID(collectionID) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + targetCollection, err := (*node.replica).getCollectionByID(UniqueID(0)) assert.NoError(t, err) assert.NotNil(t, targetCollection) - assert.Equal(t, targetCollection.meta.Schema.Name, collectionName) - assert.Equal(t, targetCollection.meta.ID, collectionID) - node.Close() + assert.Equal(t, targetCollection.meta.Schema.Name, "collection0") + assert.Equal(t, targetCollection.meta.ID, UniqueID(0)) + + (*node.replica).freeAll() } func TestCollectionReplica_getCollectionByName(t *testing.T) { - node := newQueryNode() - collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + ctx := context.Background() + node := NewQueryNode(ctx, 0) - targetCollection, err := node.replica.getCollectionByName(collectionName) + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + targetCollection, err := (*node.replica).getCollectionByName("collection0") assert.NoError(t, err) assert.NotNil(t, targetCollection) - assert.Equal(t, targetCollection.meta.Schema.Name, collectionName) - assert.Equal(t, targetCollection.meta.ID, collectionID) + assert.Equal(t, targetCollection.meta.Schema.Name, "collection0") + assert.Equal(t, targetCollection.meta.ID, UniqueID(0)) - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_hasCollection(t *testing.T) { - node := newQueryNode() - collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + ctx := context.Background() + node := NewQueryNode(ctx, 0) - hasCollection := node.replica.hasCollection(collectionID) + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + hasCollection := (*node.replica).hasCollection(UniqueID(0)) assert.Equal(t, hasCollection, true) - hasCollection = node.replica.hasCollection(UniqueID(1)) + hasCollection = (*node.replica).hasCollection(UniqueID(1)) assert.Equal(t, hasCollection, false) - node.Close() + (*node.replica).freeAll() } //----------------------------------------------------------------------------------------------------- partition func TestCollectionReplica_getPartitionNum(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) - - partitionTags := []string{"a", "b", "c"} - for _, tag := range partitionTags { - err := node.replica.addPartition(collectionID, tag) - assert.NoError(t, err) - partition, err := node.replica.getPartitionByTag(collectionID, tag) - assert.NoError(t, err) - assert.Equal(t, partition.partitionTag, tag) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, } - partitionNum, err := node.replica.getPartitionNum(collectionID) + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) assert.NoError(t, err) - assert.Equal(t, partitionNum, len(partitionTags)+1) // _default - node.Close() + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + for _, tag := range collectionMeta.PartitionTags { + err := (*node.replica).addPartition(collectionID, tag) + assert.NoError(t, err) + partition, err := (*node.replica).getPartitionByTag(collectionID, tag) + assert.NoError(t, err) + assert.Equal(t, partition.partitionTag, "default") + } + + partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0)) + assert.NoError(t, err) + assert.Equal(t, partitionNum, 1) + + (*node.replica).freeAll() } func TestCollectionReplica_addPartition(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) - - partitionTags := []string{"a", "b", "c"} - for _, tag := range partitionTags { - err := node.replica.addPartition(collectionID, tag) - assert.NoError(t, err) - partition, err := node.replica.getPartitionByTag(collectionID, tag) - assert.NoError(t, err) - assert.Equal(t, partition.partitionTag, tag) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, } - node.Close() + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + for _, tag := range collectionMeta.PartitionTags { + err := (*node.replica).addPartition(collectionID, tag) + assert.NoError(t, err) + partition, err := (*node.replica).getPartitionByTag(collectionID, tag) + assert.NoError(t, err) + assert.Equal(t, partition.partitionTag, "default") + } + + (*node.replica).freeAll() } func TestCollectionReplica_removePartition(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + partitionTag := "default" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } - partitionTags := []string{"a", "b", "c"} + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } - for _, tag := range partitionTags { - err := node.replica.addPartition(collectionID, tag) + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{partitionTag}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + for _, tag := range collectionMeta.PartitionTags { + err := (*node.replica).addPartition(collectionID, tag) assert.NoError(t, err) - partition, err := node.replica.getPartitionByTag(collectionID, tag) + partition, err := (*node.replica).getPartitionByTag(collectionID, tag) assert.NoError(t, err) - assert.Equal(t, partition.partitionTag, tag) - err = node.replica.removePartition(collectionID, tag) + assert.Equal(t, partition.partitionTag, partitionTag) + err = (*node.replica).removePartition(collectionID, partitionTag) assert.NoError(t, err) } - node.Close() + + (*node.replica).freeAll() } func TestCollectionReplica_addPartitionsByCollectionMeta(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"p0"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) collectionMeta.PartitionTags = []string{"p0", "p1", "p2"} - err := node.replica.addPartitionsByCollectionMeta(collectionMeta) + err = (*node.replica).addPartitionsByCollectionMeta(&collectionMeta) assert.NoError(t, err) - partitionNum, err := node.replica.getPartitionNum(UniqueID(0)) + partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0)) assert.NoError(t, err) - assert.Equal(t, partitionNum, len(collectionMeta.PartitionTags)+1) - hasPartition := node.replica.hasPartition(UniqueID(0), "p0") + assert.Equal(t, partitionNum, 3) + hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p2") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2") assert.Equal(t, hasPartition, true) - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_removePartitionsByCollectionMeta(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"p0", "p1", "p2"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) collectionMeta.PartitionTags = []string{"p0"} - err := node.replica.addPartitionsByCollectionMeta(collectionMeta) + err = (*node.replica).addPartitionsByCollectionMeta(&collectionMeta) assert.NoError(t, err) - partitionNum, err := node.replica.getPartitionNum(UniqueID(0)) + partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0)) assert.NoError(t, err) - assert.Equal(t, partitionNum, len(collectionMeta.PartitionTags)+1) - - hasPartition := node.replica.hasPartition(UniqueID(0), "p0") + assert.Equal(t, partitionNum, 1) + hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1") assert.Equal(t, hasPartition, false) - hasPartition = node.replica.hasPartition(UniqueID(0), "p2") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2") assert.Equal(t, hasPartition, false) - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_getPartitionByTag(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } - collectionMeta := genTestCollectionMeta(collectionName, collectionID) + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) for _, tag := range collectionMeta.PartitionTags { - err := node.replica.addPartition(collectionID, tag) + err := (*node.replica).addPartition(collectionID, tag) assert.NoError(t, err) - partition, err := node.replica.getPartitionByTag(collectionID, tag) + partition, err := (*node.replica).getPartitionByTag(collectionID, tag) assert.NoError(t, err) - assert.Equal(t, partition.partitionTag, tag) + assert.Equal(t, partition.partitionTag, "default") assert.NotNil(t, partition) } - node.Close() + + (*node.replica).freeAll() } func TestCollectionReplica_hasPartition(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - err := node.replica.addPartition(collectionID, collectionMeta.PartitionTags[0]) + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) assert.NoError(t, err) - hasPartition := node.replica.hasPartition(collectionID, "default") + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collectionID, collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + hasPartition := (*node.replica).hasPartition(UniqueID(0), "default") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(collectionID, "default1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "default1") assert.Equal(t, hasPartition, false) - node.Close() + + (*node.replica).freeAll() } //----------------------------------------------------------------------------------------------------- segment func TestCollectionReplica_addSegment(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collectionID, collectionMeta.PartitionTags[0]) + assert.NoError(t, err) const segmentNum = 3 - tag := "default" for i := 0; i < segmentNum; i++ { - err := node.replica.addSegment(UniqueID(i), tag, collectionID) + err := (*node.replica).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) assert.NoError(t, err) - targetSeg, err := node.replica.getSegmentByID(UniqueID(i)) + targetSeg, err := (*node.replica).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) } - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_removeSegment(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collectionID, collectionMeta.PartitionTags[0]) + assert.NoError(t, err) const segmentNum = 3 - tag := "default" - for i := 0; i < segmentNum; i++ { - err := node.replica.addSegment(UniqueID(i), tag, collectionID) + err := (*node.replica).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) assert.NoError(t, err) - targetSeg, err := node.replica.getSegmentByID(UniqueID(i)) + targetSeg, err := (*node.replica).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - err = node.replica.removeSegment(UniqueID(i)) + err = (*node.replica).removeSegment(UniqueID(i)) assert.NoError(t, err) } - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_getSegmentByID(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collectionID, collectionMeta.PartitionTags[0]) + assert.NoError(t, err) const segmentNum = 3 - tag := "default" - for i := 0; i < segmentNum; i++ { - err := node.replica.addSegment(UniqueID(i), tag, collectionID) + err := (*node.replica).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) assert.NoError(t, err) - targetSeg, err := node.replica.getSegmentByID(UniqueID(i)) + targetSeg, err := (*node.replica).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) } - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_hasSegment(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collectionID, collectionMeta.PartitionTags[0]) + assert.NoError(t, err) const segmentNum = 3 - tag := "default" - for i := 0; i < segmentNum; i++ { - err := node.replica.addSegment(UniqueID(i), tag, collectionID) + err := (*node.replica).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) assert.NoError(t, err) - targetSeg, err := node.replica.getSegmentByID(UniqueID(i)) + targetSeg, err := (*node.replica).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - hasSeg := node.replica.hasSegment(UniqueID(i)) + hasSeg := (*node.replica).hasSegment(UniqueID(i)) assert.Equal(t, hasSeg, true) - hasSeg = node.replica.hasSegment(UniqueID(i + 100)) + hasSeg = (*node.replica).hasSegment(UniqueID(i + 100)) assert.Equal(t, hasSeg, false) } - node.Close() + (*node.replica).freeAll() } func TestCollectionReplica_freeAll(t *testing.T) { - node := newQueryNode() + ctx := context.Background() + node := NewQueryNode(ctx, 0) + collectionName := "collection0" collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } - node.Close() + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: collectionID, + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collectionID, collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + const segmentNum = 3 + for i := 0; i < segmentNum; i++ { + err := (*node.replica).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) + assert.NoError(t, err) + targetSeg, err := (*node.replica).getSegmentByID(UniqueID(i)) + assert.NoError(t, err) + assert.Equal(t, targetSeg.segmentID, UniqueID(i)) + hasSeg := (*node.replica).hasSegment(UniqueID(i)) + assert.Equal(t, hasSeg, true) + hasSeg = (*node.replica).hasSegment(UniqueID(i + 100)) + assert.Equal(t, hasSeg, false) + } + + (*node.replica).freeAll() } diff --git a/internal/querynode/collection_test.go b/internal/querynode/collection_test.go index 3f9717c2fb..fdce3e7565 100644 --- a/internal/querynode/collection_test.go +++ b/internal/querynode/collection_test.go @@ -1,48 +1,179 @@ package querynode import ( + "context" "testing" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) func TestCollection_Partitions(t *testing.T) { - node := newQueryNode() - collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + ctx := context.Background() + node := NewQueryNode(ctx, 0) - collection, err := node.replica.getCollectionByName(collectionName) + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) assert.NoError(t, err) + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + for _, tag := range collectionMeta.PartitionTags { + err := (*node.replica).addPartition(collection.ID(), tag) + assert.NoError(t, err) + } + partitions := collection.Partitions() - assert.Equal(t, 1, len(*partitions)) + assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions)) } func TestCollection_newCollection(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) } func TestCollection_deleteCollection(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + deleteCollection(collection) } diff --git a/internal/querynode/data_sync_service.go b/internal/querynode/data_sync_service.go index 98bbab8626..e865d24b04 100644 --- a/internal/querynode/data_sync_service.go +++ b/internal/querynode/data_sync_service.go @@ -11,10 +11,10 @@ type dataSyncService struct { ctx context.Context fg *flowgraph.TimeTickedFlowGraph - replica collectionReplica + replica *collectionReplica } -func newDataSyncService(ctx context.Context, replica collectionReplica) *dataSyncService { +func newDataSyncService(ctx context.Context, replica *collectionReplica) *dataSyncService { return &dataSyncService{ ctx: ctx, diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index 1c07e396b3..4b1e8f147b 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -1,22 +1,101 @@ package querynode import ( + "context" "encoding/binary" "math" "testing" + "time" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) // NOTE: start pulsar before test func TestDataSyncService_Start(t *testing.T) { + Params.Init() + var ctx context.Context + + if closeWithDeadline { + var cancel context.CancelFunc + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel = context.WithDeadline(context.Background(), d) + defer cancel() + } else { + ctx = context.Background() + } + + // init query node + pulsarURL, _ := Params.pulsarAddress() + node := NewQueryNode(ctx, 0) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) - node := newQueryNode() - initTestMeta(t, node, "collection0", 0, 0) // test data generate const msgLength = 10 const DIM = 16 @@ -100,25 +179,25 @@ func TestDataSyncService_Start(t *testing.T) { // pulsar produce const receiveBufSize = 1024 producerChannels := Params.insertChannelNames() - pulsarURL, _ := Params.pulsarAddress() - insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) insertStream.SetPulsarClient(pulsarURL) insertStream.CreatePulsarProducers(producerChannels) var insertMsgStream msgstream.MsgStream = insertStream insertMsgStream.Start() - err := insertMsgStream.Produce(&msgPack) + err = insertMsgStream.Produce(&msgPack) assert.NoError(t, err) err = insertMsgStream.Broadcast(&timeTickMsgPack) assert.NoError(t, err) // dataSync - node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica) + node.dataSyncService = newDataSyncService(node.ctx, node.replica) go node.dataSyncService.start() node.Close() + <-ctx.Done() } diff --git a/internal/querynode/flow_graph_insert_node.go b/internal/querynode/flow_graph_insert_node.go index b75d214252..8508cccfb1 100644 --- a/internal/querynode/flow_graph_insert_node.go +++ b/internal/querynode/flow_graph_insert_node.go @@ -10,7 +10,7 @@ import ( type insertNode struct { BaseNode - replica collectionReplica + replica *collectionReplica } type InsertData struct { @@ -58,13 +58,13 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { insertData.insertRecords[task.SegmentID] = append(insertData.insertRecords[task.SegmentID], task.RowData...) // check if segment exists, if not, create this segment - if !iNode.replica.hasSegment(task.SegmentID) { - collection, err := iNode.replica.getCollectionByName(task.CollectionName) + if !(*iNode.replica).hasSegment(task.SegmentID) { + collection, err := (*iNode.replica).getCollectionByName(task.CollectionName) if err != nil { log.Println(err) continue } - err = iNode.replica.addSegment(task.SegmentID, task.PartitionTag, collection.ID()) + err = (*iNode.replica).addSegment(task.SegmentID, task.PartitionTag, collection.ID()) if err != nil { log.Println(err) continue @@ -74,7 +74,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { // 2. do preInsert for segmentID := range insertData.insertRecords { - var targetSegment, err = iNode.replica.getSegmentByID(segmentID) + var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID) if err != nil { log.Println("preInsert failed") // TODO: add error handling @@ -102,7 +102,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { } func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) { - var targetSegment, err = iNode.replica.getSegmentByID(segmentID) + var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID) if err != nil { log.Println("cannot find segment:", segmentID) // TODO: add error handling @@ -127,7 +127,7 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn wg.Done() } -func newInsertNode(replica collectionReplica) *insertNode { +func newInsertNode(replica *collectionReplica) *insertNode { maxQueueLength := Params.flowGraphMaxQueueLength() maxParallelism := Params.flowGraphMaxParallelism() diff --git a/internal/querynode/flow_graph_service_time_node.go b/internal/querynode/flow_graph_service_time_node.go index 761ad9e52e..e7dfe1a89c 100644 --- a/internal/querynode/flow_graph_service_time_node.go +++ b/internal/querynode/flow_graph_service_time_node.go @@ -6,7 +6,7 @@ import ( type serviceTimeNode struct { BaseNode - replica collectionReplica + replica *collectionReplica } func (stNode *serviceTimeNode) Name() string { @@ -28,12 +28,12 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg { } // update service time - stNode.replica.getTSafe().set(serviceTimeMsg.timeRange.timestampMax) + (*(*stNode.replica).getTSafe()).set(serviceTimeMsg.timeRange.timestampMax) //fmt.Println("update tSafe to:", getPhysicalTime(serviceTimeMsg.timeRange.timestampMax)) return nil } -func newServiceTimeNode(replica collectionReplica) *serviceTimeNode { +func newServiceTimeNode(replica *collectionReplica) *serviceTimeNode { maxQueueLength := Params.flowGraphMaxQueueLength() maxParallelism := Params.flowGraphMaxParallelism() diff --git a/internal/querynode/meta_service.go b/internal/querynode/meta_service.go index 79fa37bf6e..492388f4f6 100644 --- a/internal/querynode/meta_service.go +++ b/internal/querynode/meta_service.go @@ -26,10 +26,10 @@ const ( type metaService struct { ctx context.Context kvBase *etcdkv.EtcdKV - replica collectionReplica + replica *collectionReplica } -func newMetaService(ctx context.Context, replica collectionReplica) *metaService { +func newMetaService(ctx context.Context, replica *collectionReplica) *metaService { ETCDAddr := Params.etcdAddress() MetaRootPath := Params.metaRootPath() @@ -149,12 +149,12 @@ func (mService *metaService) processCollectionCreate(id string, value string) { col := mService.collectionUnmarshal(value) if col != nil { - err := mService.replica.addCollection(col, value) + err := (*mService.replica).addCollection(col, value) if err != nil { log.Println(err) } for _, partitionTag := range col.PartitionTags { - err = mService.replica.addPartition(col.ID, partitionTag) + err = (*mService.replica).addPartition(col.ID, partitionTag) if err != nil { log.Println(err) } @@ -173,7 +173,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) { // TODO: what if seg == nil? We need to notify master and return rpc request failed if seg != nil { - err := mService.replica.addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID) + err := (*mService.replica).addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID) if err != nil { log.Println(err) return @@ -202,7 +202,7 @@ func (mService *metaService) processSegmentModify(id string, value string) { } if seg != nil { - targetSegment, err := mService.replica.getSegmentByID(seg.SegmentID) + targetSegment, err := (*mService.replica).getSegmentByID(seg.SegmentID) if err != nil { log.Println(err) return @@ -218,11 +218,11 @@ func (mService *metaService) processCollectionModify(id string, value string) { col := mService.collectionUnmarshal(value) if col != nil { - err := mService.replica.addPartitionsByCollectionMeta(col) + err := (*mService.replica).addPartitionsByCollectionMeta(col) if err != nil { log.Println(err) } - err = mService.replica.removePartitionsByCollectionMeta(col) + err = (*mService.replica).removePartitionsByCollectionMeta(col) if err != nil { log.Println(err) } @@ -249,7 +249,7 @@ func (mService *metaService) processSegmentDelete(id string) { log.Println("Cannot parse segment id:" + id) } - err = mService.replica.removeSegment(segmentID) + err = (*mService.replica).removeSegment(segmentID) if err != nil { log.Println(err) return @@ -264,7 +264,7 @@ func (mService *metaService) processCollectionDelete(id string) { log.Println("Cannot parse collection id:" + id) } - err = mService.replica.removeCollection(collectionID) + err = (*mService.replica).removeCollection(collectionID) if err != nil { log.Println(err) return diff --git a/internal/querynode/meta_service_test.go b/internal/querynode/meta_service_test.go index e8207d08c3..bdc2f8835c 100644 --- a/internal/querynode/meta_service_test.go +++ b/internal/querynode/meta_service_test.go @@ -3,13 +3,23 @@ package querynode import ( "context" "math" + "os" "testing" "time" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) +func TestMain(m *testing.M) { + Params.Init() + exitCode := m.Run() + os.Exit(exitCode) +} + func TestMetaService_start(t *testing.T) { var ctx context.Context @@ -27,7 +37,6 @@ func TestMetaService_start(t *testing.T) { node.metaService = newMetaService(ctx, node.replica) (*node.metaService).start() - node.Close() } func TestMetaService_getCollectionObjId(t *testing.T) { @@ -110,9 +119,47 @@ func TestMetaService_isSegmentChannelRangeInQueryNodeChannelRange(t *testing.T) func TestMetaService_printCollectionStruct(t *testing.T) { collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - printCollectionStruct(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + printCollectionStruct(&collectionMeta) } func TestMetaService_printSegmentStruct(t *testing.T) { @@ -131,8 +178,13 @@ func TestMetaService_printSegmentStruct(t *testing.T) { } func TestMetaService_processCollectionCreate(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) id := "0" value := `schema: < @@ -144,10 +196,6 @@ func TestMetaService_processCollectionCreate(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -164,21 +212,71 @@ func TestMetaService_processCollectionCreate(t *testing.T) { node.metaService.processCollectionCreate(id, value) - collectionNum := node.replica.getCollectionNum() + collectionNum := (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 1) - collection, err := node.replica.getCollectionByName("test") + collection, err := (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) - node.Close() } func TestMetaService_processSegmentCreate(t *testing.T) { - node := newQueryNode() + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) + collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + colMetaBlob := proto.MarshalTextString(&collectionMeta) + + err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob)) + assert.NoError(t, err) + + err = (*node.replica).addPartition(UniqueID(0), "default") + assert.NoError(t, err) id := "0" value := `partition_tag: "default" @@ -189,15 +287,19 @@ func TestMetaService_processSegmentCreate(t *testing.T) { (*node.metaService).processSegmentCreate(id, value) - s, err := node.replica.getSegmentByID(UniqueID(0)) + s, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, s.segmentID, UniqueID(0)) - node.Close() } func TestMetaService_processCreate(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) key1 := "by-dev/meta/collection/0" msg1 := `schema: < @@ -209,10 +311,6 @@ func TestMetaService_processCreate(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -228,10 +326,10 @@ func TestMetaService_processCreate(t *testing.T) { ` (*node.metaService).processCreate(key1, msg1) - collectionNum := node.replica.getCollectionNum() + collectionNum := (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 1) - collection, err := node.replica.getCollectionByName("test") + collection, err := (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) @@ -243,19 +341,68 @@ func TestMetaService_processCreate(t *testing.T) { ` (*node.metaService).processCreate(key2, msg2) - s, err := node.replica.getSegmentByID(UniqueID(0)) + s, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, s.segmentID, UniqueID(0)) - node.Close() } func TestMetaService_processSegmentModify(t *testing.T) { - node := newQueryNode() + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) + collectionName := "collection0" - collectionID := UniqueID(0) - segmentID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, segmentID) - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + colMetaBlob := proto.MarshalTextString(&collectionMeta) + + err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob)) + assert.NoError(t, err) + + err = (*node.replica).addPartition(UniqueID(0), "default") + assert.NoError(t, err) id := "0" value := `partition_tag: "default" @@ -265,9 +412,9 @@ func TestMetaService_processSegmentModify(t *testing.T) { ` (*node.metaService).processSegmentCreate(id, value) - s, err := node.replica.getSegmentByID(segmentID) + s, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) - assert.Equal(t, s.segmentID, segmentID) + assert.Equal(t, s.segmentID, UniqueID(0)) newValue := `partition_tag: "default" channel_start: 0 @@ -277,15 +424,19 @@ func TestMetaService_processSegmentModify(t *testing.T) { // TODO: modify segment for testing processCollectionModify (*node.metaService).processSegmentModify(id, newValue) - seg, err := node.replica.getSegmentByID(segmentID) + seg, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) - assert.Equal(t, seg.segmentID, segmentID) - node.Close() + assert.Equal(t, seg.segmentID, UniqueID(0)) } func TestMetaService_processCollectionModify(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) id := "0" value := `schema: < @@ -297,10 +448,6 @@ func TestMetaService_processCollectionModify(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -318,24 +465,24 @@ func TestMetaService_processCollectionModify(t *testing.T) { ` (*node.metaService).processCollectionCreate(id, value) - collectionNum := node.replica.getCollectionNum() + collectionNum := (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 1) - collection, err := node.replica.getCollectionByName("test") + collection, err := (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) - partitionNum, err := node.replica.getPartitionNum(UniqueID(0)) + partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, partitionNum, 3) - hasPartition := node.replica.hasPartition(UniqueID(0), "p0") + hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p2") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p3") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3") assert.Equal(t, hasPartition, false) newValue := `schema: < @@ -347,10 +494,6 @@ func TestMetaService_processCollectionModify(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -368,28 +511,32 @@ func TestMetaService_processCollectionModify(t *testing.T) { ` (*node.metaService).processCollectionModify(id, newValue) - collection, err = node.replica.getCollectionByName("test") + collection, err = (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) - partitionNum, err = node.replica.getPartitionNum(UniqueID(0)) + partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, partitionNum, 3) - hasPartition = node.replica.hasPartition(UniqueID(0), "p0") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0") assert.Equal(t, hasPartition, false) - hasPartition = node.replica.hasPartition(UniqueID(0), "p1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p2") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p3") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3") assert.Equal(t, hasPartition, true) - node.Close() } func TestMetaService_processModify(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) key1 := "by-dev/meta/collection/0" msg1 := `schema: < @@ -401,10 +548,6 @@ func TestMetaService_processModify(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -422,24 +565,24 @@ func TestMetaService_processModify(t *testing.T) { ` (*node.metaService).processCreate(key1, msg1) - collectionNum := node.replica.getCollectionNum() + collectionNum := (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 1) - collection, err := node.replica.getCollectionByName("test") + collection, err := (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) - partitionNum, err := node.replica.getPartitionNum(UniqueID(0)) + partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, partitionNum, 3) - hasPartition := node.replica.hasPartition(UniqueID(0), "p0") + hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p2") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p3") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3") assert.Equal(t, hasPartition, false) key2 := "by-dev/meta/segment/0" @@ -450,7 +593,7 @@ func TestMetaService_processModify(t *testing.T) { ` (*node.metaService).processCreate(key2, msg2) - s, err := node.replica.getSegmentByID(UniqueID(0)) + s, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, s.segmentID, UniqueID(0)) @@ -465,10 +608,6 @@ func TestMetaService_processModify(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -486,21 +625,21 @@ func TestMetaService_processModify(t *testing.T) { ` (*node.metaService).processModify(key1, msg3) - collection, err = node.replica.getCollectionByName("test") + collection, err = (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) - partitionNum, err = node.replica.getPartitionNum(UniqueID(0)) + partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, partitionNum, 3) - hasPartition = node.replica.hasPartition(UniqueID(0), "p0") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0") assert.Equal(t, hasPartition, false) - hasPartition = node.replica.hasPartition(UniqueID(0), "p1") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p2") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2") assert.Equal(t, hasPartition, true) - hasPartition = node.replica.hasPartition(UniqueID(0), "p3") + hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3") assert.Equal(t, hasPartition, true) msg4 := `partition_tag: "p1" @@ -510,18 +649,68 @@ func TestMetaService_processModify(t *testing.T) { ` (*node.metaService).processModify(key2, msg4) - seg, err := node.replica.getSegmentByID(UniqueID(0)) + seg, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, seg.segmentID, UniqueID(0)) - node.Close() } func TestMetaService_processSegmentDelete(t *testing.T) { - node := newQueryNode() + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) + collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + colMetaBlob := proto.MarshalTextString(&collectionMeta) + + err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob)) + assert.NoError(t, err) + + err = (*node.replica).addPartition(UniqueID(0), "default") + assert.NoError(t, err) id := "0" value := `partition_tag: "default" @@ -531,19 +720,23 @@ func TestMetaService_processSegmentDelete(t *testing.T) { ` (*node.metaService).processSegmentCreate(id, value) - seg, err := node.replica.getSegmentByID(UniqueID(0)) + seg, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, seg.segmentID, UniqueID(0)) (*node.metaService).processSegmentDelete("0") - mapSize := node.replica.getSegmentNum() + mapSize := (*node.replica).getSegmentNum() assert.Equal(t, mapSize, 0) - node.Close() } func TestMetaService_processCollectionDelete(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) id := "0" value := `schema: < @@ -555,10 +748,6 @@ func TestMetaService_processCollectionDelete(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -574,22 +763,26 @@ func TestMetaService_processCollectionDelete(t *testing.T) { ` (*node.metaService).processCollectionCreate(id, value) - collectionNum := node.replica.getCollectionNum() + collectionNum := (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 1) - collection, err := node.replica.getCollectionByName("test") + collection, err := (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) (*node.metaService).processCollectionDelete(id) - collectionNum = node.replica.getCollectionNum() + collectionNum = (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 0) - node.Close() } func TestMetaService_processDelete(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel := context.WithDeadline(context.Background(), d) + defer cancel() + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) key1 := "by-dev/meta/collection/0" msg1 := `schema: < @@ -601,10 +794,6 @@ func TestMetaService_processDelete(t *testing.T) { key: "dim" value: "16" > - index_params: < - key: "metric_type" - value: "L2" - > > fields: < name: "age" @@ -620,10 +809,10 @@ func TestMetaService_processDelete(t *testing.T) { ` (*node.metaService).processCreate(key1, msg1) - collectionNum := node.replica.getCollectionNum() + collectionNum := (*node.replica).getCollectionNum() assert.Equal(t, collectionNum, 1) - collection, err := node.replica.getCollectionByName("test") + collection, err := (*node.replica).getCollectionByName("test") assert.NoError(t, err) assert.Equal(t, collection.ID(), UniqueID(0)) @@ -635,48 +824,77 @@ func TestMetaService_processDelete(t *testing.T) { ` (*node.metaService).processCreate(key2, msg2) - seg, err := node.replica.getSegmentByID(UniqueID(0)) + seg, err := (*node.replica).getSegmentByID(UniqueID(0)) assert.NoError(t, err) assert.Equal(t, seg.segmentID, UniqueID(0)) (*node.metaService).processDelete(key1) - collectionsSize := node.replica.getCollectionNum() + collectionsSize := (*node.replica).getCollectionNum() assert.Equal(t, collectionsSize, 0) - mapSize := node.replica.getSegmentNum() + mapSize := (*node.replica).getSegmentNum() assert.Equal(t, mapSize, 0) - node.Close() } func TestMetaService_processResp(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + var ctx context.Context + if closeWithDeadline { + var cancel context.CancelFunc + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel = context.WithDeadline(context.Background(), d) + defer cancel() + } else { + ctx = context.Background() + } + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) metaChan := (*node.metaService).kvBase.WatchWithPrefix("") select { - case <-node.queryNodeLoopCtx.Done(): + case <-node.ctx.Done(): return case resp := <-metaChan: _ = (*node.metaService).processResp(resp) } - node.Close() } func TestMetaService_loadCollections(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + var ctx context.Context + if closeWithDeadline { + var cancel context.CancelFunc + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel = context.WithDeadline(context.Background(), d) + defer cancel() + } else { + ctx = context.Background() + } + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) err2 := (*node.metaService).loadCollections() assert.Nil(t, err2) - node.Close() } func TestMetaService_loadSegments(t *testing.T) { - node := newQueryNode() - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) + var ctx context.Context + if closeWithDeadline { + var cancel context.CancelFunc + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel = context.WithDeadline(context.Background(), d) + defer cancel() + } else { + ctx = context.Background() + } + + // init metaService + node := NewQueryNode(ctx, 0) + node.metaService = newMetaService(ctx, node.replica) err2 := (*node.metaService).loadSegments() assert.Nil(t, err2) - node.Close() } diff --git a/internal/querynode/param_table.go b/internal/querynode/param_table.go index 6054a30bb8..026d116149 100644 --- a/internal/querynode/param_table.go +++ b/internal/querynode/param_table.go @@ -2,8 +2,8 @@ package querynode import ( "log" - "os" "strconv" + "strings" "github.com/zilliztech/milvus-distributed/internal/util/paramtable" ) @@ -21,16 +21,15 @@ func (p *ParamTable) Init() { panic(err) } - queryNodeIDStr := os.Getenv("QUERY_NODE_ID") - if queryNodeIDStr == "" { - queryNodeIDList := p.QueryNodeIDList() - if len(queryNodeIDList) <= 0 { - queryNodeIDStr = "0" - } else { - queryNodeIDStr = strconv.Itoa(int(queryNodeIDList[0])) - } + err = p.LoadYaml("milvus.yaml") + if err != nil { + panic(err) + } + + err = p.LoadYaml("advanced/channel.yaml") + if err != nil { + panic(err) } - p.Save("_queryNodeID", queryNodeIDStr) } func (p *ParamTable) pulsarAddress() (string, error) { @@ -41,8 +40,8 @@ func (p *ParamTable) pulsarAddress() (string, error) { return url, nil } -func (p *ParamTable) QueryNodeID() UniqueID { - queryNodeID, err := p.Load("_queryNodeID") +func (p *ParamTable) queryNodeID() int { + queryNodeID, err := p.Load("reader.clientid") if err != nil { panic(err) } @@ -50,7 +49,7 @@ func (p *ParamTable) QueryNodeID() UniqueID { if err != nil { panic(err) } - return UniqueID(id) + return id } func (p *ParamTable) insertChannelRange() []int { @@ -58,47 +57,138 @@ func (p *ParamTable) insertChannelRange() []int { if err != nil { panic(err) } - return paramtable.ConvertRangeToIntRange(insertChannelRange, ",") + + channelRange := strings.Split(insertChannelRange, ",") + if len(channelRange) != 2 { + panic("Illegal channel range num") + } + channelBegin, err := strconv.Atoi(channelRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(channelRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + return []int{channelBegin, channelEnd} } // advanced params // stats func (p *ParamTable) statsPublishInterval() int { - return p.ParseInt("queryNode.stats.publishInterval") + timeInterval, err := p.Load("queryNode.stats.publishInterval") + if err != nil { + panic(err) + } + interval, err := strconv.Atoi(timeInterval) + if err != nil { + panic(err) + } + return interval } // dataSync: func (p *ParamTable) flowGraphMaxQueueLength() int32 { - return p.ParseInt32("queryNode.dataSync.flowGraph.maxQueueLength") + queueLength, err := p.Load("queryNode.dataSync.flowGraph.maxQueueLength") + if err != nil { + panic(err) + } + length, err := strconv.Atoi(queueLength) + if err != nil { + panic(err) + } + return int32(length) } func (p *ParamTable) flowGraphMaxParallelism() int32 { - return p.ParseInt32("queryNode.dataSync.flowGraph.maxParallelism") + maxParallelism, err := p.Load("queryNode.dataSync.flowGraph.maxParallelism") + if err != nil { + panic(err) + } + maxPara, err := strconv.Atoi(maxParallelism) + if err != nil { + panic(err) + } + return int32(maxPara) } // msgStream func (p *ParamTable) insertReceiveBufSize() int64 { - return p.ParseInt64("queryNode.msgStream.insert.recvBufSize") + revBufSize, err := p.Load("queryNode.msgStream.insert.recvBufSize") + if err != nil { + panic(err) + } + bufSize, err := strconv.Atoi(revBufSize) + if err != nil { + panic(err) + } + return int64(bufSize) } func (p *ParamTable) insertPulsarBufSize() int64 { - return p.ParseInt64("queryNode.msgStream.insert.pulsarBufSize") + pulsarBufSize, err := p.Load("queryNode.msgStream.insert.pulsarBufSize") + if err != nil { + panic(err) + } + bufSize, err := strconv.Atoi(pulsarBufSize) + if err != nil { + panic(err) + } + return int64(bufSize) } func (p *ParamTable) searchReceiveBufSize() int64 { - return p.ParseInt64("queryNode.msgStream.search.recvBufSize") + revBufSize, err := p.Load("queryNode.msgStream.search.recvBufSize") + if err != nil { + panic(err) + } + bufSize, err := strconv.Atoi(revBufSize) + if err != nil { + panic(err) + } + return int64(bufSize) } func (p *ParamTable) searchPulsarBufSize() int64 { - return p.ParseInt64("queryNode.msgStream.search.pulsarBufSize") + pulsarBufSize, err := p.Load("queryNode.msgStream.search.pulsarBufSize") + if err != nil { + panic(err) + } + bufSize, err := strconv.Atoi(pulsarBufSize) + if err != nil { + panic(err) + } + return int64(bufSize) } func (p *ParamTable) searchResultReceiveBufSize() int64 { - return p.ParseInt64("queryNode.msgStream.searchResult.recvBufSize") + revBufSize, err := p.Load("queryNode.msgStream.searchResult.recvBufSize") + if err != nil { + panic(err) + } + bufSize, err := strconv.Atoi(revBufSize) + if err != nil { + panic(err) + } + return int64(bufSize) } func (p *ParamTable) statsReceiveBufSize() int64 { - return p.ParseInt64("queryNode.msgStream.stats.recvBufSize") + revBufSize, err := p.Load("queryNode.msgStream.stats.recvBufSize") + if err != nil { + panic(err) + } + bufSize, err := strconv.Atoi(revBufSize) + if err != nil { + panic(err) + } + return int64(bufSize) } func (p *ParamTable) etcdAddress() string { @@ -122,73 +212,123 @@ func (p *ParamTable) metaRootPath() string { } func (p *ParamTable) gracefulTime() int64 { - return p.ParseInt64("queryNode.gracefulTime") + gracefulTime, err := p.Load("queryNode.gracefulTime") + if err != nil { + panic(err) + } + time, err := strconv.Atoi(gracefulTime) + if err != nil { + panic(err) + } + return int64(time) } func (p *ParamTable) insertChannelNames() []string { - - prefix, err := p.Load("msgChannel.chanNamePrefix.insert") + ch, err := p.Load("msgChannel.chanNamePrefix.insert") if err != nil { log.Fatal(err) } - prefix += "-" channelRange, err := p.Load("msgChannel.channelRange.insert") if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",") - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") } - sep := len(channelIDs) / p.queryNodeNum() - index := p.sliceIndex() - if index == -1 { - panic("queryNodeID not Match with Config") + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) } - start := index * sep - return ret[start : start+sep] + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels } func (p *ParamTable) searchChannelNames() []string { - prefix, err := p.Load("msgChannel.chanNamePrefix.search") + ch, err := p.Load("msgChannel.chanNamePrefix.search") if err != nil { log.Fatal(err) } - prefix += "-" channelRange, err := p.Load("msgChannel.channelRange.search") if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",") - - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") } - return ret + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels } func (p *ParamTable) searchResultChannelNames() []string { - prefix, err := p.Load("msgChannel.chanNamePrefix.searchResult") + ch, err := p.Load("msgChannel.chanNamePrefix.searchResult") if err != nil { log.Fatal(err) } - prefix += "-" channelRange, err := p.Load("msgChannel.channelRange.searchResult") if err != nil { panic(err) } - channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",") - - var ret []string - for _, ID := range channelIDs { - ret = append(ret, prefix+strconv.Itoa(ID)) + chanRange := strings.Split(channelRange, ",") + if len(chanRange) != 2 { + panic("Illegal channel range num") } - return ret + channelBegin, err := strconv.Atoi(chanRange[0]) + if err != nil { + panic(err) + } + channelEnd, err := strconv.Atoi(chanRange[1]) + if err != nil { + panic(err) + } + if channelBegin < 0 || channelEnd < 0 { + panic("Illegal channel range value") + } + if channelBegin > channelEnd { + panic("Illegal channel range value") + } + + channels := make([]string, channelEnd-channelBegin) + for i := 0; i < channelEnd-channelBegin; i++ { + channels[i] = ch + "-" + strconv.Itoa(channelBegin+i) + } + return channels } func (p *ParamTable) msgChannelSubName() string { @@ -197,11 +337,7 @@ func (p *ParamTable) msgChannelSubName() string { if err != nil { log.Panic(err) } - queryNodeIDStr, err := p.Load("_QueryNodeID") - if err != nil { - panic(err) - } - return name + "-" + queryNodeIDStr + return name } func (p *ParamTable) statsChannelName() string { @@ -211,18 +347,3 @@ func (p *ParamTable) statsChannelName() string { } return channels } - -func (p *ParamTable) sliceIndex() int { - queryNodeID := p.QueryNodeID() - queryNodeIDList := p.QueryNodeIDList() - for i := 0; i < len(queryNodeIDList); i++ { - if queryNodeID == queryNodeIDList[i] { - return i - } - } - return -1 -} - -func (p *ParamTable) queryNodeNum() int { - return len(p.QueryNodeIDList()) -} diff --git a/internal/querynode/param_table_test.go b/internal/querynode/param_table_test.go index fd609deb98..544659134d 100644 --- a/internal/querynode/param_table_test.go +++ b/internal/querynode/param_table_test.go @@ -1,109 +1,128 @@ package querynode import ( - "fmt" "strings" "testing" "github.com/stretchr/testify/assert" ) +func TestParamTable_Init(t *testing.T) { + Params.Init() +} + func TestParamTable_PulsarAddress(t *testing.T) { + Params.Init() address, err := Params.pulsarAddress() assert.NoError(t, err) split := strings.Split(address, ":") - assert.Equal(t, "pulsar", split[0]) - assert.Equal(t, "6650", split[len(split)-1]) + assert.Equal(t, split[0], "pulsar") + assert.Equal(t, split[len(split)-1], "6650") } func TestParamTable_QueryNodeID(t *testing.T) { - id := Params.QueryNodeID() - assert.Contains(t, Params.QueryNodeIDList(), id) + Params.Init() + id := Params.queryNodeID() + assert.Equal(t, id, 0) } func TestParamTable_insertChannelRange(t *testing.T) { + Params.Init() channelRange := Params.insertChannelRange() - assert.Equal(t, 2, len(channelRange)) + assert.Equal(t, len(channelRange), 2) + assert.Equal(t, channelRange[0], 0) + assert.Equal(t, channelRange[1], 1) } func TestParamTable_statsServiceTimeInterval(t *testing.T) { + Params.Init() interval := Params.statsPublishInterval() - assert.Equal(t, 1000, interval) + assert.Equal(t, interval, 1000) } func TestParamTable_statsMsgStreamReceiveBufSize(t *testing.T) { + Params.Init() bufSize := Params.statsReceiveBufSize() - assert.Equal(t, int64(64), bufSize) + assert.Equal(t, bufSize, int64(64)) } func TestParamTable_insertMsgStreamReceiveBufSize(t *testing.T) { + Params.Init() bufSize := Params.insertReceiveBufSize() - assert.Equal(t, int64(1024), bufSize) + assert.Equal(t, bufSize, int64(1024)) } func TestParamTable_searchMsgStreamReceiveBufSize(t *testing.T) { + Params.Init() bufSize := Params.searchReceiveBufSize() - assert.Equal(t, int64(512), bufSize) + assert.Equal(t, bufSize, int64(512)) } func TestParamTable_searchResultMsgStreamReceiveBufSize(t *testing.T) { + Params.Init() bufSize := Params.searchResultReceiveBufSize() - assert.Equal(t, int64(64), bufSize) + assert.Equal(t, bufSize, int64(64)) } func TestParamTable_searchPulsarBufSize(t *testing.T) { + Params.Init() bufSize := Params.searchPulsarBufSize() - assert.Equal(t, int64(512), bufSize) + assert.Equal(t, bufSize, int64(512)) } func TestParamTable_insertPulsarBufSize(t *testing.T) { + Params.Init() bufSize := Params.insertPulsarBufSize() - assert.Equal(t, int64(1024), bufSize) + assert.Equal(t, bufSize, int64(1024)) } func TestParamTable_flowGraphMaxQueueLength(t *testing.T) { + Params.Init() length := Params.flowGraphMaxQueueLength() - assert.Equal(t, int32(1024), length) + assert.Equal(t, length, int32(1024)) } func TestParamTable_flowGraphMaxParallelism(t *testing.T) { + Params.Init() maxParallelism := Params.flowGraphMaxParallelism() - assert.Equal(t, int32(1024), maxParallelism) + assert.Equal(t, maxParallelism, int32(1024)) } func TestParamTable_insertChannelNames(t *testing.T) { + Params.Init() names := Params.insertChannelNames() - channelRange := Params.insertChannelRange() - num := channelRange[1] - channelRange[0] - num = num / Params.queryNodeNum() - assert.Equal(t, num, len(names)) - start := num * Params.sliceIndex() - assert.Equal(t, fmt.Sprintf("insert-%d", channelRange[start]), names[0]) + assert.Equal(t, len(names), 1) + assert.Equal(t, names[0], "insert-0") } func TestParamTable_searchChannelNames(t *testing.T) { + Params.Init() names := Params.searchChannelNames() assert.Equal(t, len(names), 1) - assert.Equal(t, "search-0", names[0]) + assert.Equal(t, names[0], "search-0") } func TestParamTable_searchResultChannelNames(t *testing.T) { + Params.Init() names := Params.searchResultChannelNames() - assert.NotNil(t, names) + assert.Equal(t, len(names), 1) + assert.Equal(t, names[0], "searchResult-0") } func TestParamTable_msgChannelSubName(t *testing.T) { + Params.Init() name := Params.msgChannelSubName() - expectName := fmt.Sprintf("queryNode-%d", Params.QueryNodeID()) - assert.Equal(t, expectName, name) + assert.Equal(t, name, "queryNode") } func TestParamTable_statsChannelName(t *testing.T) { + Params.Init() name := Params.statsChannelName() - assert.Equal(t, "query-node-stats", name) + assert.Equal(t, name, "query-node-stats") } func TestParamTable_metaRootPath(t *testing.T) { + Params.Init() path := Params.metaRootPath() - assert.Equal(t, "by-dev/meta", path) + assert.Equal(t, path, "by-dev/meta") } diff --git a/internal/querynode/partition_test.go b/internal/querynode/partition_test.go index 512117ff35..f9aa7d324c 100644 --- a/internal/querynode/partition_test.go +++ b/internal/querynode/partition_test.go @@ -1,20 +1,77 @@ package querynode import ( + "context" "testing" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) func TestPartition_Segments(t *testing.T) { - node := newQueryNode() - collectionName := "collection0" - collectionID := UniqueID(0) - initTestMeta(t, node, collectionName, collectionID, 0) + ctx := context.Background() + node := NewQueryNode(ctx, 0) - collection, err := node.replica.getCollectionByName(collectionName) + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) assert.NoError(t, err) - collectionMeta := collection.meta + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + for _, tag := range collectionMeta.PartitionTags { + err := (*node.replica).addPartition(collection.ID(), tag) + assert.NoError(t, err) + } partitions := collection.Partitions() assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions)) @@ -23,12 +80,12 @@ func TestPartition_Segments(t *testing.T) { const segmentNum = 3 for i := 0; i < segmentNum; i++ { - err := node.replica.addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID()) + err := (*node.replica).addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID()) assert.NoError(t, err) } segments := targetPartition.Segments() - assert.Equal(t, segmentNum+1, len(*segments)) + assert.Equal(t, segmentNum, len(*segments)) } func TestPartition_newPartition(t *testing.T) { diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index 611ff4eafc..0d26f90b9d 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -8,17 +8,59 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) func TestPlan_Plan(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" @@ -32,13 +74,52 @@ func TestPlan_Plan(t *testing.T) { } func TestPlan_PlaceholderGroup(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index c7c5f4bec8..f25c2b12d5 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -17,12 +17,11 @@ import ( ) type QueryNode struct { - queryNodeLoopCtx context.Context - queryNodeLoopCancel func() + ctx context.Context QueryNodeID uint64 - replica collectionReplica + replica *collectionReplica dataSyncService *dataSyncService metaService *metaService @@ -30,14 +29,7 @@ type QueryNode struct { statsService *statsService } -func Init() { - Params.Init() -} - func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode { - - ctx1, cancel := context.WithCancel(ctx) - segmentsMap := make(map[int64]*Segment) collections := make([]*Collection, 0) @@ -51,11 +43,11 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode { } return &QueryNode{ - queryNodeLoopCtx: ctx1, - queryNodeLoopCancel: cancel, - QueryNodeID: queryNodeID, + ctx: ctx, - replica: replica, + QueryNodeID: queryNodeID, + + replica: &replica, dataSyncService: nil, metaService: nil, @@ -64,34 +56,31 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode { } } -func (node *QueryNode) Start() error { - // todo add connectMaster logic - node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica) - node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica) - node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica) - node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica) +func (node *QueryNode) Start() { + node.dataSyncService = newDataSyncService(node.ctx, node.replica) + node.searchService = newSearchService(node.ctx, node.replica) + node.metaService = newMetaService(node.ctx, node.replica) + node.statsService = newStatsService(node.ctx, node.replica) go node.dataSyncService.start() go node.searchService.start() go node.metaService.start() - go node.statsService.start() - return nil + node.statsService.start() } func (node *QueryNode) Close() { - node.queryNodeLoopCancel() - + <-node.ctx.Done() // free collectionReplica - node.replica.freeAll() + (*node.replica).freeAll() // close services if node.dataSyncService != nil { - node.dataSyncService.close() + (*node.dataSyncService).close() } if node.searchService != nil { - node.searchService.close() + (*node.searchService).close() } if node.statsService != nil { - node.statsService.close() + (*node.statsService).close() } } diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 90065721e1..9cf6929fcd 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -2,93 +2,18 @@ package querynode import ( "context" - "os" "testing" "time" - - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) const ctxTimeInMillisecond = 200 const closeWithDeadline = true -func setup() { +// NOTE: start pulsar and etcd before test +func TestQueryNode_start(t *testing.T) { Params.Init() -} - -func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta { - fieldVec := schemapb.FieldSchema{ - Name: "vec", - IsPrimaryKey: false, - DataType: schemapb.DataType_VECTOR_FLOAT, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "metric_type", - Value: "L2", - }, - }, - } - - fieldInt := schemapb.FieldSchema{ - Name: "age", - IsPrimaryKey: false, - DataType: schemapb.DataType_INT32, - } - - schema := schemapb.CollectionSchema{ - Name: collectionName, - AutoID: true, - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - ID: collectionID, - Schema: &schema, - CreateTime: Timestamp(0), - SegmentIDs: []UniqueID{0}, - PartitionTags: []string{"default"}, - } - return &collectionMeta -} - -func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) { - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - - collectionMetaBlob := proto.MarshalTextString(collectionMeta) - assert.NotEqual(t, "", collectionMetaBlob) - - var err = node.replica.addCollection(collectionMeta, collectionMetaBlob) - assert.NoError(t, err) - - collection, err := node.replica.getCollectionByName(collectionName) - assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) - assert.Equal(t, node.replica.getCollectionNum(), 1) - - err = node.replica.addPartition(collection.ID(), collectionMeta.PartitionTags[0]) - assert.NoError(t, err) - - err = node.replica.addSegment(segmentID, collectionMeta.PartitionTags[0], collectionID) - assert.NoError(t, err) -} - -func newQueryNode() *QueryNode { var ctx context.Context - if closeWithDeadline { var cancel context.CancelFunc d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) @@ -98,21 +23,7 @@ func newQueryNode() *QueryNode { ctx = context.Background() } - svr := NewQueryNode(ctx, 0) - return svr - -} - -func TestMain(m *testing.M) { - setup() - exitCode := m.Run() - os.Exit(exitCode) -} - -// NOTE: start pulsar and etcd before test -func TestQueryNode_Start(t *testing.T) { - localNode := newQueryNode() - err := localNode.Start() - assert.Nil(t, err) - localNode.Close() + node := NewQueryNode(ctx, 0) + node.Start() + node.Close() } diff --git a/internal/querynode/reader.go b/internal/querynode/reader.go new file mode 100644 index 0000000000..feb5f73fd9 --- /dev/null +++ b/internal/querynode/reader.go @@ -0,0 +1,15 @@ +package querynode + +import ( + "context" +) + +func Init() { + Params.Init() +} + +func StartQueryNode(ctx context.Context) { + node := NewQueryNode(ctx, 0) + + node.Start() +} diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index cac3b51d32..a14ae3a919 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -9,19 +9,63 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) func TestReduce_AllFunc(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - segmentID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + + segmentID := UniqueID(0) segment := newSegment(collection, segmentID) + assert.Equal(t, segmentID, segment.segmentID) const DIM = 16 var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} diff --git a/internal/querynode/search_service.go b/internal/querynode/search_service.go index 3101b6dd92..a5cda00534 100644 --- a/internal/querynode/search_service.go +++ b/internal/querynode/search_service.go @@ -4,12 +4,10 @@ import "C" import ( "context" "errors" - "fmt" + "github.com/golang/protobuf/proto" "log" "sync" - "github.com/golang/protobuf/proto" - "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" @@ -21,7 +19,7 @@ type searchService struct { wait sync.WaitGroup cancel context.CancelFunc - replica collectionReplica + replica *collectionReplica tSafeWatcher *tSafeWatcher serviceableTime Timestamp @@ -29,14 +27,13 @@ type searchService struct { msgBuffer chan msgstream.TsMsg unsolvedMsg []msgstream.TsMsg - searchMsgStream msgstream.MsgStream - searchResultMsgStream msgstream.MsgStream - queryNodeID UniqueID + searchMsgStream *msgstream.MsgStream + searchResultMsgStream *msgstream.MsgStream } type ResultEntityIds []UniqueID -func newSearchService(ctx context.Context, replica collectionReplica) *searchService { +func newSearchService(ctx context.Context, replica *collectionReplica) *searchService { receiveBufSize := Params.searchReceiveBufSize() pulsarBufSize := Params.searchPulsarBufSize() @@ -72,15 +69,14 @@ func newSearchService(ctx context.Context, replica collectionReplica) *searchSer replica: replica, tSafeWatcher: newTSafeWatcher(), - searchMsgStream: inputStream, - searchResultMsgStream: outputStream, - queryNodeID: Params.QueryNodeID(), + searchMsgStream: &inputStream, + searchResultMsgStream: &outputStream, } } func (ss *searchService) start() { - ss.searchMsgStream.Start() - ss.searchResultMsgStream.Start() + (*ss.searchMsgStream).Start() + (*ss.searchResultMsgStream).Start() ss.register() ss.wait.Add(2) go ss.receiveSearchMsg() @@ -89,24 +85,20 @@ func (ss *searchService) start() { } func (ss *searchService) close() { - if ss.searchMsgStream != nil { - ss.searchMsgStream.Close() - } - if ss.searchResultMsgStream != nil { - ss.searchResultMsgStream.Close() - } + (*ss.searchMsgStream).Close() + (*ss.searchResultMsgStream).Close() ss.cancel() } func (ss *searchService) register() { - tSafe := ss.replica.getTSafe() - tSafe.registerTSafeWatcher(ss.tSafeWatcher) + tSafe := (*(ss.replica)).getTSafe() + (*tSafe).registerTSafeWatcher(ss.tSafeWatcher) } func (ss *searchService) waitNewTSafe() Timestamp { // block until dataSyncService updating tSafe ss.tSafeWatcher.hasUpdate() - timestamp := ss.replica.getTSafe().get() + timestamp := (*(*ss.replica).getTSafe()).get() return timestamp } @@ -130,7 +122,7 @@ func (ss *searchService) receiveSearchMsg() { case <-ss.ctx.Done(): return default: - msgPack := ss.searchMsgStream.Consume() + msgPack := (*ss.searchMsgStream).Consume() if msgPack == nil || len(msgPack.Msgs) <= 0 { continue } @@ -227,7 +219,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } collectionName := query.CollectionName partitionTags := query.PartitionTags - collection, err := ss.replica.getCollectionByName(collectionName) + collection, err := (*ss.replica).getCollectionByName(collectionName) if err != nil { return err } @@ -249,14 +241,14 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { matchedSegments := make([]*Segment, 0) for _, partitionTag := range partitionTags { - hasPartition := ss.replica.hasPartition(collectionID, partitionTag) + hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag) if !hasPartition { return errors.New("search Failed, invalid partitionTag") } } for _, partitionTag := range partitionTags { - partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag) + partition, _ := (*ss.replica).getPartitionByTag(collectionID, partitionTag) for _, segment := range partition.segments { //fmt.Println("dsl = ", dsl) @@ -276,13 +268,13 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, ReqID: searchMsg.ReqID, ProxyID: searchMsg.ProxyID, - QueryNodeID: ss.queryNodeID, + QueryNodeID: searchMsg.ProxyID, Timestamp: searchTimestamp, ResultChannelID: searchMsg.ResultChannelID, Hits: nil, } searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}}, + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, SearchResult: results, } err = ss.publishSearchResult(searchResultMsg) @@ -341,7 +333,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { Hits: hits, } searchResultMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}}, + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, SearchResult: results, } err = ss.publishSearchResult(searchResultMsg) @@ -358,10 +350,9 @@ func (ss *searchService) search(msg msgstream.TsMsg) error { } func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error { - fmt.Println("Public SearchResult", msg.HashKeys()) msgPack := msgstream.MsgPack{} msgPack.Msgs = append(msgPack.Msgs, msg) - err := ss.searchResultMsgStream.Produce(&msgPack) + err := (*ss.searchResultMsgStream).Produce(&msgPack) if err != nil { return err } @@ -386,11 +377,11 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg s } tsMsg := &msgstream.SearchResultMsg{ - BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}}, + BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}}, SearchResult: results, } msgPack.Msgs = append(msgPack.Msgs, tsMsg) - err := ss.searchResultMsgStream.Produce(&msgPack) + err := (*ss.searchResultMsgStream).Produce(&msgPack) if err != nil { return err } diff --git a/internal/querynode/search_service_test.go b/internal/querynode/search_service_test.go index 2ce122a19e..3624e9b99e 100644 --- a/internal/querynode/search_service_test.go +++ b/internal/querynode/search_service_test.go @@ -13,15 +13,80 @@ import ( "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) func TestSearch_Search(t *testing.T) { - node := NewQueryNode(context.Background(), 0) - initTestMeta(t, node, "collection0", 0, 0) + Params.Init() + ctx, cancel := context.WithCancel(context.Background()) + // init query node pulsarURL, _ := Params.pulsarAddress() + node := NewQueryNode(ctx, 0) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) // test data generate const msgLength = 10 @@ -93,14 +158,14 @@ func TestSearch_Search(t *testing.T) { msgPackSearch := msgstream.MsgPack{} msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg) - searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) searchStream.SetPulsarClient(pulsarURL) searchStream.CreatePulsarProducers(searchProducerChannels) searchStream.Start() err = searchStream.Produce(&msgPackSearch) assert.NoError(t, err) - node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica) + node.searchService = newSearchService(node.ctx, node.replica) go node.searchService.start() // start insert @@ -170,7 +235,7 @@ func TestSearch_Search(t *testing.T) { timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) // pulsar produce - insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) insertStream.SetPulsarClient(pulsarURL) insertStream.CreatePulsarProducers(insertProducerChannels) insertStream.Start() @@ -180,19 +245,83 @@ func TestSearch_Search(t *testing.T) { assert.NoError(t, err) // dataSync - node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica) + node.dataSyncService = newDataSyncService(node.ctx, node.replica) go node.dataSyncService.start() time.Sleep(1 * time.Second) + cancel() node.Close() } func TestSearch_SearchMultiSegments(t *testing.T) { - node := NewQueryNode(context.Background(), 0) - initTestMeta(t, node, "collection0", 0, 0) + Params.Init() + ctx, cancel := context.WithCancel(context.Background()) + // init query node pulsarURL, _ := Params.pulsarAddress() + node := NewQueryNode(ctx, 0) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) // test data generate const msgLength = 1024 @@ -264,14 +393,14 @@ func TestSearch_SearchMultiSegments(t *testing.T) { msgPackSearch := msgstream.MsgPack{} msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg) - searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) searchStream.SetPulsarClient(pulsarURL) searchStream.CreatePulsarProducers(searchProducerChannels) searchStream.Start() err = searchStream.Produce(&msgPackSearch) assert.NoError(t, err) - node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica) + node.searchService = newSearchService(node.ctx, node.replica) go node.searchService.start() // start insert @@ -345,7 +474,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) { timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) // pulsar produce - insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) insertStream.SetPulsarClient(pulsarURL) insertStream.CreatePulsarProducers(insertProducerChannels) insertStream.Start() @@ -355,10 +484,11 @@ func TestSearch_SearchMultiSegments(t *testing.T) { assert.NoError(t, err) // dataSync - node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica) + node.dataSyncService = newDataSyncService(node.ctx, node.replica) go node.dataSyncService.start() time.Sleep(1 * time.Second) + cancel() node.Close() } diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index 63743a35c6..4522289cac 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -1,6 +1,7 @@ package querynode import ( + "context" "encoding/binary" "log" "math" @@ -9,21 +10,61 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) //-------------------------------------------------------------------------------------- constructor and destructor func TestSegment_newSegment(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -33,15 +74,52 @@ func TestSegment_newSegment(t *testing.T) { } func TestSegment_deleteSegment(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -53,15 +131,52 @@ func TestSegment_deleteSegment(t *testing.T) { //-------------------------------------------------------------------------------------- stats functions func TestSegment_getRowCount(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -104,15 +219,52 @@ func TestSegment_getRowCount(t *testing.T) { } func TestSegment_getDeletedCount(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -161,15 +313,52 @@ func TestSegment_getDeletedCount(t *testing.T) { } func TestSegment_getMemSize(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -213,15 +402,53 @@ func TestSegment_getMemSize(t *testing.T) { //-------------------------------------------------------------------------------------- dm & search functions func TestSegment_segmentInsert(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + segmentID := UniqueID(0) segment := newSegment(collection, segmentID) assert.Equal(t, segmentID, segment.segmentID) @@ -259,15 +486,52 @@ func TestSegment_segmentInsert(t *testing.T) { } func TestSegment_segmentDelete(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -312,15 +576,55 @@ func TestSegment_segmentDelete(t *testing.T) { } func TestSegment_segmentSearch(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -357,6 +661,13 @@ func TestSegment_segmentSearch(t *testing.T) { dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + pulsarURL, _ := Params.pulsarAddress() + const receiveBufSize = 1024 + searchProducerChannels := Params.searchChannelNames() + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + searchStream.SetPulsarClient(pulsarURL) + searchStream.CreatePulsarProducers(searchProducerChannels) + var searchRawData []byte for _, ele := range vec { buf := make([]byte, 4) @@ -397,15 +708,52 @@ func TestSegment_segmentSearch(t *testing.T) { //-------------------------------------------------------------------------------------- preDm functions func TestSegment_segmentPreInsert(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) @@ -439,15 +787,52 @@ func TestSegment_segmentPreInsert(t *testing.T) { } func TestSegment_segmentPreDelete(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(0) - collectionMeta := genTestCollectionMeta(collectionName, collectionID) - collectionMetaBlob := proto.MarshalTextString(collectionMeta) + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - collection := newCollection(collectionMeta, collectionMetaBlob) - assert.Equal(t, collection.meta.Schema.Name, collectionName) - assert.Equal(t, collection.meta.ID, collectionID) + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) segmentID := UniqueID(0) segment := newSegment(collection, segmentID) diff --git a/internal/querynode/stats_service.go b/internal/querynode/stats_service.go index 440bc360a3..0ded253326 100644 --- a/internal/querynode/stats_service.go +++ b/internal/querynode/stats_service.go @@ -14,11 +14,11 @@ import ( type statsService struct { ctx context.Context - statsStream msgstream.MsgStream - replica collectionReplica + statsStream *msgstream.MsgStream + replica *collectionReplica } -func newStatsService(ctx context.Context, replica collectionReplica) *statsService { +func newStatsService(ctx context.Context, replica *collectionReplica) *statsService { return &statsService{ ctx: ctx, @@ -44,8 +44,8 @@ func (sService *statsService) start() { var statsMsgStream msgstream.MsgStream = statsStream - sService.statsStream = statsMsgStream - sService.statsStream.Start() + sService.statsStream = &statsMsgStream + (*sService.statsStream).Start() // start service fmt.Println("do segments statistic in ", strconv.Itoa(sleepTimeInterval), "ms") @@ -60,13 +60,11 @@ func (sService *statsService) start() { } func (sService *statsService) close() { - if sService.statsStream != nil { - sService.statsStream.Close() - } + (*sService.statsStream).Close() } func (sService *statsService) sendSegmentStatistic() { - statisticData := sService.replica.getSegmentStatistics() + statisticData := (*sService.replica).getSegmentStatistics() // fmt.Println("Publish segment statistic") // fmt.Println(statisticData) @@ -84,7 +82,7 @@ func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSeg var msgPack = msgstream.MsgPack{ Msgs: []msgstream.TsMsg{msg}, } - err := sService.statsStream.Produce(&msgPack) + err := (*sService.statsStream).Produce(&msgPack) if err != nil { log.Println(err) } diff --git a/internal/querynode/stats_service_test.go b/internal/querynode/stats_service_test.go index 3e6005a463..68347928a6 100644 --- a/internal/querynode/stats_service_test.go +++ b/internal/querynode/stats_service_test.go @@ -1,42 +1,193 @@ package querynode import ( + "context" "testing" + "time" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) // NOTE: start pulsar before test func TestStatsService_start(t *testing.T) { - node := newQueryNode() - initTestMeta(t, node, "collection0", 0, 0) - node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica) + Params.Init() + var ctx context.Context + + if closeWithDeadline { + var cancel context.CancelFunc + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel = context.WithDeadline(context.Background(), d) + defer cancel() + } else { + ctx = context.Background() + } + + // init query node + node := NewQueryNode(ctx, 0) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) + + // start stats service + node.statsService = newStatsService(node.ctx, node.replica) node.statsService.start() - node.Close() } -//NOTE: start pulsar before test +// NOTE: start pulsar before test func TestSegmentManagement_SegmentStatisticService(t *testing.T) { - node := newQueryNode() - initTestMeta(t, node, "collection0", 0, 0) + Params.Init() + var ctx context.Context + + if closeWithDeadline { + var cancel context.CancelFunc + d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) + ctx, cancel = context.WithDeadline(context.Background(), d) + defer cancel() + } else { + ctx = context.Background() + } + + // init query node + pulsarURL, _ := Params.pulsarAddress() + node := NewQueryNode(ctx, 0) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + IsPrimaryKey: false, + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + IsPrimaryKey: false, + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.replica).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.replica).getCollectionNum(), 1) + + err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) const receiveBufSize = 1024 // start pulsar producerChannels := []string{Params.statsChannelName()} - pulsarURL, _ := Params.pulsarAddress() - - statsStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + statsStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) statsStream.SetPulsarClient(pulsarURL) statsStream.CreatePulsarProducers(producerChannels) var statsMsgStream msgstream.MsgStream = statsStream - node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica) - node.statsService.statsStream = statsMsgStream - node.statsService.statsStream.Start() + node.statsService = newStatsService(node.ctx, node.replica) + node.statsService.statsStream = &statsMsgStream + (*node.statsService.statsStream).Start() // send stats node.statsService.sendSegmentStatistic() - node.Close() } diff --git a/internal/querynode/tsafe.go b/internal/querynode/tsafe.go index 27a1b64004..bbafc10e46 100644 --- a/internal/querynode/tsafe.go +++ b/internal/querynode/tsafe.go @@ -36,11 +36,11 @@ type tSafeImpl struct { watcherList []*tSafeWatcher } -func newTSafe() tSafe { +func newTSafe() *tSafe { var t tSafe = &tSafeImpl{ watcherList: make([]*tSafeWatcher, 0), } - return t + return &t } func (ts *tSafeImpl) registerTSafeWatcher(t *tSafeWatcher) { diff --git a/internal/querynode/tsafe_test.go b/internal/querynode/tsafe_test.go index 1ae166f7f7..e02b915072 100644 --- a/internal/querynode/tsafe_test.go +++ b/internal/querynode/tsafe_test.go @@ -9,13 +9,13 @@ import ( func TestTSafe_GetAndSet(t *testing.T) { tSafe := newTSafe() watcher := newTSafeWatcher() - tSafe.registerTSafeWatcher(watcher) + (*tSafe).registerTSafeWatcher(watcher) go func() { watcher.hasUpdate() - timestamp := tSafe.get() + timestamp := (*tSafe).get() assert.Equal(t, timestamp, Timestamp(1000)) }() - tSafe.set(Timestamp(1000)) + (*tSafe).set(Timestamp(1000)) } diff --git a/internal/storage/cwrapper/.gitignore b/internal/storage/cwrapper/.gitignore index 87a73e597c..d663cac4d6 100644 --- a/internal/storage/cwrapper/.gitignore +++ b/internal/storage/cwrapper/.gitignore @@ -1,4 +1,3 @@ -output cmake-build-debug .idea cmake_build diff --git a/internal/storage/cwrapper/CMakeLists.txt b/internal/storage/cwrapper/CMakeLists.txt index e38f300fa2..387274044e 100644 --- a/internal/storage/cwrapper/CMakeLists.txt +++ b/internal/storage/cwrapper/CMakeLists.txt @@ -2,17 +2,6 @@ cmake_minimum_required(VERSION 3.14...3.17 FATAL_ERROR) project(wrapper) set(CMAKE_CXX_STANDARD 17) -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -if (NOT GIT_ARROW_REPO) - set(GIT_ARROW_REPO "https://github.com/apache/arrow.git") -endif () -message(STATUS "Arrow Repo:" ${GIT_ARROW_REPO}) - -if (NOT GIT_ARROW_TAG) - set(GIT_ARROW_TAG "apache-arrow-2.0.0") -endif () -message(STATUS "Arrow Tag:" ${GIT_ARROW_TAG}) ################################################################################################### # - cmake modules --------------------------------------------------------------------------------- @@ -25,39 +14,29 @@ set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/" ${CMAKE_MODUL message(STATUS "BUILDING ARROW") include(ConfigureArrow) -if (ARROW_FOUND) +if(ARROW_FOUND) message(STATUS "Apache Arrow found in ${ARROW_INCLUDE_DIR}") -else () +else() message(FATAL_ERROR "Apache Arrow not found, please check your settings.") -endif (ARROW_FOUND) +endif(ARROW_FOUND) add_library(arrow STATIC IMPORTED ${ARROW_LIB}) add_library(parquet STATIC IMPORTED ${PARQUET_LIB}) add_library(thrift STATIC IMPORTED ${THRIFT_LIB}) add_library(utf8proc STATIC IMPORTED ${UTF8PROC_LIB}) -if (ARROW_FOUND) +if(ARROW_FOUND) set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${ARROW_LIB}) set_target_properties(parquet PROPERTIES IMPORTED_LOCATION ${PARQUET_LIB}) set_target_properties(thrift PROPERTIES IMPORTED_LOCATION ${THRIFT_LIB}) set_target_properties(utf8proc PROPERTIES IMPORTED_LOCATION ${UTF8PROC_LIB}) -endif (ARROW_FOUND) +endif(ARROW_FOUND) ################################################################################################### include_directories(${ARROW_INCLUDE_DIR}) include_directories(${PROJECT_SOURCE_DIR}) -add_library(wrapper STATIC) -target_sources(wrapper PUBLIC ParquetWrapper.cpp - PayloadStream.cpp) - -target_link_libraries(wrapper PUBLIC parquet arrow thrift utf8proc pthread) - -if(NOT CMAKE_INSTALL_PREFIX) - set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}) -endif() -install(TARGETS wrapper DESTINATION ${CMAKE_INSTALL_PREFIX}) -install(FILES ${ARROW_LIB} ${PARQUET_LIB} ${THRIFT_LIB} ${UTF8PROC_LIB} DESTINATION ${CMAKE_INSTALL_PREFIX}) +add_library(wrapper ParquetWrapper.cpp ParquetWrapper.h ColumnType.h PayloadStream.h PayloadStream.cpp) add_subdirectory(test) \ No newline at end of file diff --git a/internal/storage/cwrapper/ParquetWrapper.cpp b/internal/storage/cwrapper/ParquetWrapper.cpp index c586597286..9061f07f02 100644 --- a/internal/storage/cwrapper/ParquetWrapper.cpp +++ b/internal/storage/cwrapper/ParquetWrapper.cpp @@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_ st.error_msg = ErrorMsg("payload has finished"); return st; } - auto ast = builder->AppendValues(values, length); + auto ast = builder->AppendValues(values, (dimension / 8) * length); if (!ast.ok()) { st.error_code = static_cast(ErrorCode::UNEXPECTED_ERROR); st.error_msg = ErrorMsg(ast.message()); @@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float * st.error_msg = ErrorMsg("payload has finished"); return st; } - auto ast = builder->AppendValues(reinterpret_cast(values), length); + auto ast = builder->AppendValues(reinterpret_cast(values), dimension * length * sizeof(float)); if (!ast.ok()) { st.error_code = static_cast(ErrorCode::UNEXPECTED_ERROR); st.error_msg = ErrorMsg(ast.message()); @@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader, return st; } *dimension = array->byte_width() * 8; - *length = array->length(); + *length = array->length() / array->byte_width(); *values = (uint8_t *) array->raw_values(); return st; } @@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, return st; } *dimension = array->byte_width() / sizeof(float); - *length = array->length(); + *length = array->length() / array->byte_width(); *values = (float *) array->raw_values(); return st; } @@ -478,7 +478,12 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) { auto p = reinterpret_cast(payloadReader); if (p->array == nullptr) return 0; - return p->array->length(); + auto ba = std::dynamic_pointer_cast(p->array); + if (ba == nullptr) { + return p->array->length(); + } else { + return ba->length() / ba->byte_width(); + } } extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) { diff --git a/internal/storage/cwrapper/ParquetWrapper.h b/internal/storage/cwrapper/ParquetWrapper.h index 748894e861..d7475cde2e 100644 --- a/internal/storage/cwrapper/ParquetWrapper.h +++ b/internal/storage/cwrapper/ParquetWrapper.h @@ -5,7 +5,6 @@ extern "C" { #endif #include -#include typedef void *CPayloadWriter; @@ -56,4 +55,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader); #ifdef __cplusplus } -#endif +#endif \ No newline at end of file diff --git a/internal/storage/cwrapper/build.sh b/internal/storage/cwrapper/build.sh deleted file mode 100755 index 8989401957..0000000000 --- a/internal/storage/cwrapper/build.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -SOURCE=${BASH_SOURCE[0]} -while [ -h $SOURCE ]; do # resolve $SOURCE until the file is no longer a symlink - DIR=$( cd -P $( dirname $SOURCE ) && pwd ) - SOURCE=$(readlink $SOURCE) - [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located -done -DIR=$( cd -P $( dirname $SOURCE ) && pwd ) -# echo $DIR - -CMAKE_BUILD=${DIR}/cmake_build -OUTPUT_LIB=${DIR}/output - -if [ ! -d ${CMAKE_BUILD} ];then - mkdir ${CMAKE_BUILD} -fi - -if [ -d ${OUTPUT_LIB} ];then - rm -rf ${OUTPUT_LIB} -fi -mkdir ${OUTPUT_LIB} - -BUILD_TYPE="Debug" -GIT_ARROW_REPO="https://github.com/apache/arrow.git" -GIT_ARROW_TAG="apache-arrow-2.0.0" - -while getopts "a:b:t:h" arg; do - case $arg in - t) - BUILD_TYPE=$OPTARG # BUILD_TYPE - ;; - a) - GIT_ARROW_REPO=$OPTARG - ;; - b) - GIT_ARROW_TAG=$OPTARG - ;; - h) # help - echo "-t: build type(default: Debug) --a: arrow repo(default: https://github.com/apache/arrow.git) --b: arrow tag(default: apache-arrow-2.0.0) --h: help - " - exit 0 - ;; - ?) - echo "ERROR! unknown argument" - exit 1 - ;; - esac -done -echo "BUILD_TYPE: " $BUILD_TYPE -echo "GIT_ARROW_REPO: " $GIT_ARROW_REPO -echo "GIT_ARROW_TAG: " $GIT_ARROW_TAG - -pushd ${CMAKE_BUILD} -cmake -DCMAKE_INSTALL_PREFIX=${OUTPUT_LIB} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DGIT_ARROW_REPO=${GIT_ARROW_REPO} -DGIT_ARROW_TAG=${GIT_ARROW_TAG} .. && make && make install diff --git a/internal/storage/cwrapper/cmake/Modules/ConfigureArrow.cmake b/internal/storage/cwrapper/cmake/Modules/ConfigureArrow.cmake index 8c0e19a371..6cd9003a6d 100644 --- a/internal/storage/cwrapper/cmake/Modules/ConfigureArrow.cmake +++ b/internal/storage/cwrapper/cmake/Modules/ConfigureArrow.cmake @@ -35,14 +35,11 @@ if(ARROW_CONFIG) message(FATAL_ERROR "Configuring Arrow failed: " ${ARROW_CONFIG}) endif(ARROW_CONFIG) -#set(PARALLEL_BUILD -j) -#if($ENV{PARALLEL_LEVEL}) -# set(NUM_JOBS $ENV{PARALLEL_LEVEL}) -# set(PARALLEL_BUILD "${PARALLEL_BUILD}${NUM_JOBS}") -#endif($ENV{PARALLEL_LEVEL}) -set(NUM_JOBS 4) -set(PARALLEL_BUILD "-j${NUM_JOBS}") - +set(PARALLEL_BUILD -j) +if($ENV{PARALLEL_LEVEL}) + set(NUM_JOBS $ENV{PARALLEL_LEVEL}) + set(PARALLEL_BUILD "${PARALLEL_BUILD}${NUM_JOBS}") +endif($ENV{PARALLEL_LEVEL}) if(${NUM_JOBS}) if(${NUM_JOBS} EQUAL 1) @@ -91,8 +88,8 @@ if(ARROW_LIB AND PARQUET_LIB AND THRIFT_LIB AND UTF8PROC_LIB) set(ARROW_FOUND TRUE) endif(ARROW_LIB AND PARQUET_LIB AND THRIFT_LIB AND UTF8PROC_LIB) -# message(STATUS "FlatBuffers installed here: " ${FLATBUFFERS_ROOT}) -# set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_ROOT}/include") -# set(FLATBUFFERS_LIBRARY_DIR "${FLATBUFFERS_ROOT}/lib") +message(STATUS "FlatBuffers installed here: " ${FLATBUFFERS_ROOT}) +set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_ROOT}/include") +set(FLATBUFFERS_LIBRARY_DIR "${FLATBUFFERS_ROOT}/lib") add_definitions(-DARROW_METADATA_V4) diff --git a/internal/storage/cwrapper/cmake/Templates/Arrow.CMakeLists.txt.cmake b/internal/storage/cwrapper/cmake/Templates/Arrow.CMakeLists.txt.cmake index 0eb3da4088..4ba5ce6aee 100644 --- a/internal/storage/cwrapper/cmake/Templates/Arrow.CMakeLists.txt.cmake +++ b/internal/storage/cwrapper/cmake/Templates/Arrow.CMakeLists.txt.cmake @@ -20,8 +20,8 @@ project(wrapper-Arrow) include(ExternalProject) ExternalProject_Add(Arrow - GIT_REPOSITORY ${GIT_ARROW_REPO} - GIT_TAG ${GIT_ARROW_TAG} + GIT_REPOSITORY https://github.com/apache/arrow.git + GIT_TAG apache-arrow-2.0.0 GIT_SHALLOW true SOURCE_DIR "${ARROW_ROOT}/arrow" SOURCE_SUBDIR "cpp" diff --git a/internal/storage/cwrapper/test/CMakeLists.txt b/internal/storage/cwrapper/test/CMakeLists.txt index 3ee32d277b..09f9020a4b 100644 --- a/internal/storage/cwrapper/test/CMakeLists.txt +++ b/internal/storage/cwrapper/test/CMakeLists.txt @@ -14,8 +14,6 @@ target_link_libraries(wrapper_test parquet arrow thrift utf8proc pthread ) -install(TARGETS wrapper_test DESTINATION ${CMAKE_INSTALL_PREFIX}) - # Defines `gtest_discover_tests()`. #include(GoogleTest) #gtest_discover_tests(milvusd_test) \ No newline at end of file diff --git a/internal/storage/cwrapper/test/ParquetWrapperTest.cpp b/internal/storage/cwrapper/test/ParquetWrapperTest.cpp index c598185a14..0376161b16 100644 --- a/internal/storage/cwrapper/test/ParquetWrapperTest.cpp +++ b/internal/storage/cwrapper/test/ParquetWrapperTest.cpp @@ -71,36 +71,36 @@ TEST(wrapper, inoutstream) { } TEST(wrapper, boolean) { - auto payload = NewPayloadWriter(ColumnType::BOOL); - bool data[] = {true, false, true, false}; + auto payload = NewPayloadWriter(ColumnType::BOOL); + bool data[] = {true, false, true, false}; - auto st = AddBooleanToPayload(payload, data, 4); - ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); - st = FinishPayloadWriter(payload); - ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); - auto cb = GetPayloadBufferFromWriter(payload); - ASSERT_GT(cb.length, 0); - ASSERT_NE(cb.data, nullptr); - auto nums = GetPayloadLengthFromWriter(payload); - ASSERT_EQ(nums, 4); + auto st = AddBooleanToPayload(payload, data, 4); + ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); + st = FinishPayloadWriter(payload); + ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); + auto cb = GetPayloadBufferFromWriter(payload); + ASSERT_GT(cb.length, 0); + ASSERT_NE(cb.data, nullptr); + auto nums = GetPayloadLengthFromWriter(payload); + ASSERT_EQ(nums, 4); - auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length); - bool *values; - int length; - st = GetBoolFromPayload(reader, &values, &length); - ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); - ASSERT_NE(values, nullptr); - ASSERT_EQ(length, 4); - length = GetPayloadLengthFromReader(reader); - ASSERT_EQ(length, 4); - for (int i = 0; i < length; i++) { - ASSERT_EQ(data[i], values[i]); - } + auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length); + bool *values; + int length; + st = GetBoolFromPayload(reader, &values, &length); + ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); + ASSERT_NE(values, nullptr); + ASSERT_EQ(length, 4); + length = GetPayloadLengthFromReader(reader); + ASSERT_EQ(length, 4); + for (int i = 0; i < length; i++) { + ASSERT_EQ(data[i], values[i]); + } - st = ReleasePayloadWriter(payload); - ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); - st = ReleasePayloadReader(reader); - ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); + st = ReleasePayloadWriter(payload); + ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); + st = ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::SUCCESS); } #define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \ diff --git a/internal/storage/payload.go b/internal/storage/payload.go deleted file mode 100644 index 3240daa6fa..0000000000 --- a/internal/storage/payload.go +++ /dev/null @@ -1,626 +0,0 @@ -package storage - -/* -#cgo CFLAGS: -I${SRCDIR}/cwrapper - -#cgo LDFLAGS: -L${SRCDIR}/cwrapper/output -l:libwrapper.a -l:libparquet.a -l:libarrow.a -l:libthrift.a -l:libutf8proc.a -lstdc++ -lm -#include -#include "ParquetWrapper.h" -*/ -import "C" -import ( - "unsafe" - - "github.com/zilliztech/milvus-distributed/internal/errors" - "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" -) - -type ( - PayloadWriter struct { - payloadWriterPtr C.CPayloadWriter - colType schemapb.DataType - } - - PayloadReader struct { - payloadReaderPtr C.CPayloadReader - colType schemapb.DataType - } -) - -func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) { - w := C.NewPayloadWriter(C.int(colType)) - if w == nil { - return nil, errors.New("create Payload writer failed") - } - return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil -} - -func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error { - switch len(dim) { - case 0: - switch w.colType { - case schemapb.DataType_BOOL: - val, ok := msgs.([]bool) - if !ok { - return errors.New("incorrect data type") - } - return w.AddBoolToPayload(val) - - case schemapb.DataType_INT8: - val, ok := msgs.([]int8) - if !ok { - return errors.New("incorrect data type") - } - return w.AddInt8ToPayload(val) - - case schemapb.DataType_INT16: - val, ok := msgs.([]int16) - if !ok { - return errors.New("incorrect data type") - } - return w.AddInt16ToPayload(val) - - case schemapb.DataType_INT32: - val, ok := msgs.([]int32) - if !ok { - return errors.New("incorrect data type") - } - return w.AddInt32ToPayload(val) - - case schemapb.DataType_INT64: - val, ok := msgs.([]int64) - if !ok { - return errors.New("incorrect data type") - } - return w.AddInt64ToPayload(val) - - case schemapb.DataType_FLOAT: - val, ok := msgs.([]float32) - if !ok { - return errors.New("incorrect data type") - } - return w.AddFloatToPayload(val) - - case schemapb.DataType_DOUBLE: - val, ok := msgs.([]float64) - if !ok { - return errors.New("incorrect data type") - } - return w.AddDoubleToPayload(val) - - case schemapb.DataType_STRING: - val, ok := msgs.(string) - if !ok { - return errors.New("incorrect data type") - } - return w.AddOneStringToPayload(val) - } - case 1: - switch w.colType { - case schemapb.DataType_VECTOR_BINARY: - val, ok := msgs.([]byte) - if !ok { - return errors.New("incorrect data type") - } - return w.AddBinaryVectorToPayload(val, dim[0]) - - case schemapb.DataType_VECTOR_FLOAT: - val, ok := msgs.([]float32) - if !ok { - return errors.New("incorrect data type") - } - return w.AddFloatVectorToPayload(val, dim[0]) - } - - default: - return errors.New("incorrect input numbers") - - } - return nil -} - -func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - - cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - - cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - - cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - - cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - - cMsgs := (*C.float)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error { - length := len(msgs) - if length <= 0 { - return errors.Errorf("can't add empty msgs into payload") - } - - cMsgs := (*C.double)(unsafe.Pointer(&msgs[0])) - cLength := C.int(length) - - status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) AddOneStringToPayload(msg string) error { - length := len(msg) - if length == 0 { - return errors.New("can't add empty string into payload") - } - - cmsg := C.CString(msg) - clength := C.int(length) - defer C.free(unsafe.Pointer(cmsg)) - - st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength) - - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil -} - -// dimension > 0 && (%8 == 0) -func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error { - length := len(binVec) - if length <= 0 { - return errors.New("can't add empty binVec into payload") - } - - if dim <= 0 { - return errors.New("dimension should be greater than 0") - } - - cBinVec := (*C.uint8_t)(&binVec[0]) - cDim := C.int(dim) - cLength := C.int(length / (dim / 8)) - - st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil -} - -// dimension > 0 && (%8 == 0) -func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error { - length := len(floatVec) - if length <= 0 { - return errors.New("can't add empty floatVec into payload") - } - - if dim <= 0 { - return errors.New("dimension should be greater than 0") - } - - cBinVec := (*C.float)(&floatVec[0]) - cDim := C.int(dim) - cLength := C.int(length / dim) - - st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) FinishPayloadWriter() error { - st := C.FinishPayloadWriter(w.payloadWriterPtr) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) { - cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr) - pointer := unsafe.Pointer(cb.data) - length := int(cb.length) - if length <= 0 { - return nil, errors.New("empty buffer") - } - // refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices - slice := (*[1 << 28]byte)(pointer)[:length:length] - return slice, nil -} - -func (w *PayloadWriter) GetPayloadLengthFromWriter() (int, error) { - length := C.GetPayloadLengthFromWriter(w.payloadWriterPtr) - return int(length), nil -} - -func (w *PayloadWriter) ReleasePayloadWriter() error { - st := C.ReleasePayloadWriter(w.payloadWriterPtr) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil -} - -func (w *PayloadWriter) Close() error { - return w.ReleasePayloadWriter() -} - -func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) { - if len(buf) == 0 { - return nil, errors.New("create Payload reader failed, buffer is empty") - } - r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf))) - return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil -} - -// Params: -// `idx`: String index -// Return: -// `interface{}`: all types. -// `int`: length, only meaningful to FLOAT/BINARY VECTOR type. -// `error`: error. -func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) { - switch len(idx) { - case 1: - switch r.colType { - case schemapb.DataType_STRING: - val, err := r.GetOneStringFromPayload(idx[0]) - return val, 0, err - } - case 0: - switch r.colType { - case schemapb.DataType_BOOL: - val, err := r.GetBoolFromPayload() - return val, 0, err - - case schemapb.DataType_INT8: - val, err := r.GetInt8FromPayload() - return val, 0, err - - case schemapb.DataType_INT16: - val, err := r.GetInt16FromPayload() - return val, 0, err - - case schemapb.DataType_INT32: - val, err := r.GetInt32FromPayload() - return val, 0, err - - case schemapb.DataType_INT64: - val, err := r.GetInt64FromPayload() - return val, 0, err - - case schemapb.DataType_FLOAT: - val, err := r.GetFloatFromPayload() - return val, 0, err - - case schemapb.DataType_DOUBLE: - val, err := r.GetDoubleFromPayload() - return val, 0, err - - case schemapb.DataType_VECTOR_BINARY: - return r.GetBinaryVectorFromPayload() - - case schemapb.DataType_VECTOR_FLOAT: - return r.GetFloatVectorFromPayload() - default: - return nil, 0, errors.New("Unknown type") - } - default: - return nil, 0, errors.New("incorrect number of index") - } - - return nil, 0, errors.New("unknown error") -} - -func (r *PayloadReader) ReleasePayloadReader() error { - st := C.ReleasePayloadReader(r.payloadReaderPtr) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil -} - -func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) { - var cMsg *C.bool - var cSize C.int - - st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) { - var cMsg *C.int8_t - var cSize C.int - - st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) { - var cMsg *C.int16_t - var cSize C.int - - st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) { - var cMsg *C.int32_t - var cSize C.int - - st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) { - var cMsg *C.int64_t - var cSize C.int - - st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) { - var cMsg *C.float - var cSize C.int - - st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) { - var cMsg *C.double - var cSize C.int - - st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) - } - - slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize] - return slice, nil -} - -func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) { - var cStr *C.char - var cSize C.int - - st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize) - - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return "", errors.New(msg) - } - return C.GoStringN(cStr, cSize), nil -} - -// ,dimension, error -func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) { - var cMsg *C.uint8_t - var cDim C.int - var cLen C.int - - st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, 0, errors.New(msg) - } - length := (cDim / 8) * cLen - - slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length] - return slice, int(cDim), nil -} - -// ,dimension, error -func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) { - var cMsg *C.float - var cDim C.int - var cLen C.int - - st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_SUCCESS { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, 0, errors.New(msg) - } - length := cDim * cLen - - slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length] - return slice, int(cDim), nil -} - -func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) { - length := C.GetPayloadLengthFromReader(r.payloadReaderPtr) - return int(length), nil -} - -func (r *PayloadReader) Close() error { - return r.ReleasePayloadReader() -} diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go deleted file mode 100644 index 94150fef1d..0000000000 --- a/internal/storage/payload_test.go +++ /dev/null @@ -1,426 +0,0 @@ -package storage - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" -) - -func TestPayload_ReaderandWriter(t *testing.T) { - - t.Run("TestBool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BOOL) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddBoolToPayload([]bool{false, false, false, false}) - assert.Nil(t, err) - err = w.AddDataToPayload([]bool{false, false, false, false}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 8, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 8) - bools, err := r.GetBoolFromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools) - ibools, _, err := r.GetDataFromPayload() - bools = ibools.([]bool) - assert.Nil(t, err) - assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools) - defer r.ReleasePayloadReader() - - }) - - t.Run("TestInt8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_INT8) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddInt8ToPayload([]int8{1, 2, 3}) - assert.Nil(t, err) - err = w.AddDataToPayload([]int8{4, 5, 6}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 6, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_INT8, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 6) - - int8s, err := r.GetInt8FromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s) - - iint8s, _, err := r.GetDataFromPayload() - int8s = iint8s.([]int8) - assert.Nil(t, err) - - assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s) - defer r.ReleasePayloadReader() - }) - - t.Run("TestInt16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_INT16) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddInt16ToPayload([]int16{1, 2, 3}) - assert.Nil(t, err) - err = w.AddDataToPayload([]int16{1, 2, 3}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 6, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_INT16, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 6) - int16s, err := r.GetInt16FromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s) - - iint16s, _, err := r.GetDataFromPayload() - int16s = iint16s.([]int16) - assert.Nil(t, err) - assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s) - defer r.ReleasePayloadReader() - }) - - t.Run("TestInt32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_INT32) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddInt32ToPayload([]int32{1, 2, 3}) - assert.Nil(t, err) - err = w.AddDataToPayload([]int32{1, 2, 3}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 6, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_INT32, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 6) - - int32s, err := r.GetInt32FromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s) - - iint32s, _, err := r.GetDataFromPayload() - int32s = iint32s.([]int32) - assert.Nil(t, err) - assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s) - defer r.ReleasePayloadReader() - }) - - t.Run("TestInt64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_INT64) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddInt64ToPayload([]int64{1, 2, 3}) - assert.Nil(t, err) - err = w.AddDataToPayload([]int64{1, 2, 3}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 6, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_INT64, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 6) - - int64s, err := r.GetInt64FromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s) - - iint64s, _, err := r.GetDataFromPayload() - int64s = iint64s.([]int64) - assert.Nil(t, err) - assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s) - defer r.ReleasePayloadReader() - }) - - t.Run("TestFloat32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FLOAT) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0}) - assert.Nil(t, err) - err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 6, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 6) - - float32s, err := r.GetFloatFromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s) - - ifloat32s, _, err := r.GetDataFromPayload() - float32s = ifloat32s.([]float32) - assert.Nil(t, err) - assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s) - defer r.ReleasePayloadReader() - }) - - t.Run("TestDouble", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_DOUBLE) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0}) - assert.Nil(t, err) - err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0}) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 6, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 6) - - float64s, err := r.GetDoubleFromPayload() - assert.Nil(t, err) - assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s) - - ifloat64s, _, err := r.GetDataFromPayload() - float64s = ifloat64s.([]float64) - assert.Nil(t, err) - assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s) - defer r.ReleasePayloadReader() - }) - - t.Run("TestAddOneString", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_STRING) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddOneStringToPayload("hello0") - assert.Nil(t, err) - err = w.AddOneStringToPayload("hello1") - assert.Nil(t, err) - err = w.AddOneStringToPayload("hello2") - assert.Nil(t, err) - err = w.AddDataToPayload("hello3") - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, length, 4) - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_STRING, buffer) - assert.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 4) - str0, err := r.GetOneStringFromPayload(0) - assert.Nil(t, err) - assert.Equal(t, str0, "hello0") - str1, err := r.GetOneStringFromPayload(1) - assert.Nil(t, err) - assert.Equal(t, str1, "hello1") - str2, err := r.GetOneStringFromPayload(2) - assert.Nil(t, err) - assert.Equal(t, str2, "hello2") - str3, err := r.GetOneStringFromPayload(3) - assert.Nil(t, err) - assert.Equal(t, str3, "hello3") - - istr0, _, err := r.GetDataFromPayload(0) - str0 = istr0.(string) - assert.Nil(t, err) - assert.Equal(t, str0, "hello0") - - istr1, _, err := r.GetDataFromPayload(1) - str1 = istr1.(string) - assert.Nil(t, err) - assert.Equal(t, str1, "hello1") - - istr2, _, err := r.GetDataFromPayload(2) - str2 = istr2.(string) - assert.Nil(t, err) - assert.Equal(t, str2, "hello2") - - istr3, _, err := r.GetDataFromPayload(3) - str3 = istr3.(string) - assert.Nil(t, err) - assert.Equal(t, str3, "hello3") - - err = r.ReleasePayloadReader() - assert.Nil(t, err) - err = w.ReleasePayloadWriter() - assert.Nil(t, err) - }) - - t.Run("TestBinaryVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY) - require.Nil(t, err) - require.NotNil(t, w) - - in := make([]byte, 16) - for i := 0; i < 16; i++ { - in[i] = 1 - } - in2 := make([]byte, 8) - for i := 0; i < 8; i++ { - in2[i] = 1 - } - - err = w.AddBinaryVectorToPayload(in, 8) - assert.Nil(t, err) - err = w.AddDataToPayload(in2, 8) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 24, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 24) - - binVecs, dim, err := r.GetBinaryVectorFromPayload() - assert.Nil(t, err) - assert.Equal(t, 8, dim) - assert.Equal(t, 24, len(binVecs)) - fmt.Println(binVecs) - - ibinVecs, dim, err := r.GetDataFromPayload() - assert.Nil(t, err) - binVecs = ibinVecs.([]byte) - assert.Equal(t, 8, dim) - assert.Equal(t, 24, len(binVecs)) - defer r.ReleasePayloadReader() - }) - - t.Run("TestFloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT) - require.Nil(t, err) - require.NotNil(t, w) - - err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1) - assert.Nil(t, err) - err = w.AddDataToPayload([]float32{3.0, 4.0}, 1) - assert.Nil(t, err) - err = w.FinishPayloadWriter() - assert.Nil(t, err) - - length, err := w.GetPayloadLengthFromWriter() - assert.Nil(t, err) - assert.Equal(t, 4, length) - defer w.ReleasePayloadWriter() - - buffer, err := w.GetPayloadBufferFromWriter() - assert.Nil(t, err) - - r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer) - require.Nil(t, err) - length, err = r.GetPayloadLengthFromReader() - assert.Nil(t, err) - assert.Equal(t, length, 4) - - floatVecs, dim, err := r.GetFloatVectorFromPayload() - assert.Nil(t, err) - assert.Equal(t, 1, dim) - assert.Equal(t, 4, len(floatVecs)) - assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs) - - ifloatVecs, dim, err := r.GetDataFromPayload() - assert.Nil(t, err) - floatVecs = ifloatVecs.([]float32) - assert.Equal(t, 1, dim) - assert.Equal(t, 4, len(floatVecs)) - assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs) - defer r.ReleasePayloadReader() - }) -} diff --git a/internal/util/paramtable/paramtable.go b/internal/util/paramtable/paramtable.go index ef8d7f56f0..d16eaf0330 100644 --- a/internal/util/paramtable/paramtable.go +++ b/internal/util/paramtable/paramtable.go @@ -16,17 +16,13 @@ import ( "os" "path" "runtime" - "strconv" "strings" "github.com/spf13/cast" "github.com/spf13/viper" memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem" - "github.com/zilliztech/milvus-distributed/internal/util/typeutil" ) -type UniqueID = typeutil.UniqueID - type Base interface { Load(key string) (string, error) LoadRange(key, endKey string, limit int) ([]string, []string, error) @@ -42,18 +38,7 @@ type BaseTable struct { func (gp *BaseTable) Init() { gp.params = memkv.NewMemoryKV() - - err := gp.LoadYaml("milvus.yaml") - if err != nil { - panic(err) - } - - err = gp.LoadYaml("advanced/common.yaml") - if err != nil { - panic(err) - } - - err = gp.LoadYaml("advanced/channel.yaml") + err := gp.LoadYaml("config.yaml") if err != nil { panic(err) } @@ -161,140 +146,3 @@ func (gp *BaseTable) Remove(key string) error { func (gp *BaseTable) Save(key, value string) error { return gp.params.Save(strings.ToLower(key), value) } - -func (gp *BaseTable) ParseFloat(key string) float64 { - valueStr, err := gp.Load(key) - if err != nil { - panic(err) - } - value, err := strconv.ParseFloat(valueStr, 64) - if err != nil { - panic(err) - } - return value -} - -func (gp *BaseTable) ParseInt64(key string) int64 { - valueStr, err := gp.Load(key) - if err != nil { - panic(err) - } - value, err := strconv.Atoi(valueStr) - if err != nil { - panic(err) - } - return int64(value) -} - -func (gp *BaseTable) ParseInt32(key string) int32 { - valueStr, err := gp.Load(key) - if err != nil { - panic(err) - } - value, err := strconv.Atoi(valueStr) - if err != nil { - panic(err) - } - return int32(value) -} - -func (gp *BaseTable) ParseInt(key string) int { - valueStr, err := gp.Load(key) - if err != nil { - panic(err) - } - value, err := strconv.Atoi(valueStr) - if err != nil { - panic(err) - } - return value -} - -func (gp *BaseTable) WriteNodeIDList() []UniqueID { - proxyIDStr, err := gp.Load("nodeID.writeNodeIDList") - if err != nil { - panic(err) - } - var ret []UniqueID - proxyIDs := strings.Split(proxyIDStr, ",") - for _, i := range proxyIDs { - v, err := strconv.Atoi(i) - if err != nil { - log.Panicf("load write node id list error, %s", err.Error()) - } - ret = append(ret, UniqueID(v)) - } - return ret -} - -func (gp *BaseTable) ProxyIDList() []UniqueID { - proxyIDStr, err := gp.Load("nodeID.proxyIDList") - if err != nil { - panic(err) - } - var ret []UniqueID - proxyIDs := strings.Split(proxyIDStr, ",") - for _, i := range proxyIDs { - v, err := strconv.Atoi(i) - if err != nil { - log.Panicf("load proxy id list error, %s", err.Error()) - } - ret = append(ret, UniqueID(v)) - } - return ret -} - -func (gp *BaseTable) QueryNodeIDList() []UniqueID { - queryNodeIDStr, err := gp.Load("nodeID.queryNodeIDList") - if err != nil { - panic(err) - } - var ret []UniqueID - queryNodeIDs := strings.Split(queryNodeIDStr, ",") - for _, i := range queryNodeIDs { - v, err := strconv.Atoi(i) - if err != nil { - log.Panicf("load proxy id list error, %s", err.Error()) - } - ret = append(ret, UniqueID(v)) - } - return ret -} - -// package methods - -func ConvertRangeToIntRange(rangeStr, sep string) []int { - items := strings.Split(rangeStr, sep) - if len(items) != 2 { - panic("Illegal range ") - } - - startStr := items[0] - endStr := items[1] - start, err := strconv.Atoi(startStr) - if err != nil { - panic(err) - } - end, err := strconv.Atoi(endStr) - if err != nil { - panic(err) - } - - if start < 0 || end < 0 { - panic("Illegal range value") - } - if start > end { - panic("Illegal range value, start > end") - } - return []int{start, end} -} - -func ConvertRangeToIntSlice(rangeStr, sep string) []int { - rangeSlice := ConvertRangeToIntRange(rangeStr, sep) - start, end := rangeSlice[0], rangeSlice[1] - var ret []int - for i := start; i < end; i++ { - ret = append(ret, i) - } - return ret -} diff --git a/internal/util/paramtable/paramtable_test.go b/internal/util/paramtable/paramtable_test.go index c8caa90359..5957c72989 100644 --- a/internal/util/paramtable/paramtable_test.go +++ b/internal/util/paramtable/paramtable_test.go @@ -12,7 +12,6 @@ package paramtable import ( - "os" "testing" "github.com/stretchr/testify/assert" @@ -22,8 +21,6 @@ var Params = BaseTable{} func TestMain(m *testing.M) { Params.Init() - code := m.Run() - os.Exit(code) } //func TestMain @@ -58,13 +55,13 @@ func TestGlobalParamsTable_SaveAndLoad(t *testing.T) { } func TestGlobalParamsTable_LoadRange(t *testing.T) { - _ = Params.Save("xxxaab", "10") - _ = Params.Save("xxxfghz", "20") - _ = Params.Save("xxxbcde", "1.1") - _ = Params.Save("xxxabcd", "testSaveAndLoad") - _ = Params.Save("xxxzhi", "12") + _ = Params.Save("abc", "10") + _ = Params.Save("fghz", "20") + _ = Params.Save("bcde", "1.1") + _ = Params.Save("abcd", "testSaveAndLoad") + _ = Params.Save("zhi", "12") - keys, values, err := Params.LoadRange("xxxa", "xxxg", 10) + keys, values, err := Params.LoadRange("a", "g", 10) assert.Nil(t, err) assert.Equal(t, 4, len(keys)) assert.Equal(t, "10", values[0]) @@ -100,17 +97,24 @@ func TestGlobalParamsTable_Remove(t *testing.T) { } func TestGlobalParamsTable_LoadYaml(t *testing.T) { - err := Params.LoadYaml("milvus.yaml") + err := Params.LoadYaml("config.yaml") assert.Nil(t, err) - err = Params.LoadYaml("advanced/channel.yaml") - assert.Nil(t, err) + value1, err1 := Params.Load("etcd.address") + value2, err2 := Params.Load("pulsar.port") + value3, err3 := Params.Load("reader.topicend") + value4, err4 := Params.Load("proxy.pulsarTopics.readerTopicPrefix") + value5, err5 := Params.Load("proxy.network.address") - _, err = Params.Load("etcd.address") - assert.Nil(t, err) - _, err = Params.Load("pulsar.port") - assert.Nil(t, err) - _, err = Params.Load("msgChannel.channelRange.insert") - assert.Nil(t, err) + assert.Equal(t, value1, "localhost") + assert.Equal(t, value2, "6650") + assert.Equal(t, value3, "128") + assert.Equal(t, value4, "milvusReader") + assert.Equal(t, value5, "0.0.0.0") + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.Nil(t, err3) + assert.Nil(t, err4) + assert.Nil(t, err5) } diff --git a/scripts/cwrapper_build.sh b/scripts/cwrapper_build.sh deleted file mode 100755 index b27fe96f17..0000000000 --- a/scripts/cwrapper_build.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -SOURCE=${BASH_SOURCE[0]} -while [ -h $SOURCE ]; do # resolve $SOURCE until the file is no longer a symlink - DIR=$( cd -P $( dirname $SOURCE ) && pwd ) - SOURCE=$(readlink $SOURCE) - [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located -done -DIR=$( cd -P $( dirname $SOURCE ) && pwd ) -# DIR=${DIR}/../internal/storage/cwrapper - -CMAKE_BUILD=${DIR}/../cwrapper_build -OUTPUT_LIB=${DIR}/../internal/storage/cwrapper/output -SRC_DIR=${DIR}/../internal/storage/cwrapper - -if [ ! -d ${CMAKE_BUILD} ];then - mkdir ${CMAKE_BUILD} -fi - -if [ -d ${OUTPUT_LIB} ];then - rm -rf ${OUTPUT_LIB} -fi -mkdir ${OUTPUT_LIB} - -BUILD_TYPE="Debug" -GIT_ARROW_REPO="https://github.com/apache/arrow.git" -GIT_ARROW_TAG="apache-arrow-2.0.0" - -while getopts "a:b:t:h" arg; do - case $arg in - t) - BUILD_TYPE=$OPTARG # BUILD_TYPE - ;; - a) - GIT_ARROW_REPO=$OPTARG - ;; - b) - GIT_ARROW_TAG=$OPTARG - ;; - h) # help - echo "-t: build type(default: Debug) --a: arrow repo(default: https://github.com/apache/arrow.git) --b: arrow tag(default: apache-arrow-2.0.0) --h: help - " - exit 0 - ;; - ?) - echo "ERROR! unknown argument" - exit 1 - ;; - esac -done -echo "BUILD_TYPE: " $BUILD_TYPE -echo "GIT_ARROW_REPO: " $GIT_ARROW_REPO -echo "GIT_ARROW_TAG: " $GIT_ARROW_TAG - -pushd ${CMAKE_BUILD} -cmake -DCMAKE_INSTALL_PREFIX=${OUTPUT_LIB} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DGIT_ARROW_REPO=${GIT_ARROW_REPO} -DGIT_ARROW_TAG=${GIT_ARROW_TAG} ${SRC_DIR} && make && make install diff --git a/scripts/run_cpp_unittest.sh b/scripts/run_cpp_unittest.sh index e1bb3f96a7..1eb7c1db32 100755 --- a/scripts/run_cpp_unittest.sh +++ b/scripts/run_cpp_unittest.sh @@ -13,7 +13,6 @@ SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" MILVUS_CORE_DIR="${SCRIPTS_DIR}/../internal/core" CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/output" UNITTEST_DIRS=("${CORE_INSTALL_PREFIX}/unittest") -CWRAPPER_UNITTEST="${SCRIPTS_DIR}/../internal/storage/cwrapper/output/wrapper_test" # Currently core will install target lib to "core/output/lib" if [ -d "${CORE_INSTALL_PREFIX}/lib" ]; then @@ -43,12 +42,3 @@ for UNITTEST_DIR in "${UNITTEST_DIRS[@]}"; do # fi #done done - -# run cwrapper unittest -if [ -f ${CWRAPPER_UNITTEST} ];then - ${CWRAPPER_UNITTEST} - if [ $? -ne 0 ]; then - echo ${CWRAPPER_UNITTEST} " run failed" - exit 1 - fi -fi diff --git a/scripts/run_go_unittest.sh b/scripts/run_go_unittest.sh index a22617e263..f335a8922c 100755 --- a/scripts/run_go_unittest.sh +++ b/scripts/run_go_unittest.sh @@ -13,5 +13,5 @@ SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" # ignore Minio,S3 unittes MILVUS_DIR="${SCRIPTS_DIR}/../internal/" echo $MILVUS_DIR -go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/storage" "${MILVUS_DIR}/proxy/..." -failfast +go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." "${MILVUS_DIR}/proxy/..." -failfast #go test -cover "${MILVUS_DIR}/kv/..." "${MILVUS_DIR}/msgstream/..." "${MILVUS_DIR}/master/..." "${MILVUS_DIR}/querynode/..." -failfast diff --git a/tests/python/utils.py b/tests/python/utils.py index a3b4db8fdc..45841be022 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -260,8 +260,8 @@ def gen_binary_default_fields(auto_id=True): "fields": [ {"name": "int64", "type": DataType.INT64, "is_primary_key": not auto_id}, {"name": "float", "type": DataType.FLOAT}, - {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}, "indexes": [{"metric_type": "JACCARD"}]} - ], + {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}} + ], "segment_row_limit": default_segment_row_limit, "auto_id": auto_id }