diff --git a/Makefile b/Makefile index 9f62206990..7c599a2de1 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,7 @@ ifdef USE_ASAN use_asan =${USE_ASAN} endif -use_dynamic_simd = OFF +use_dynamic_simd = ON ifdef USE_DYNAMIC_SIMD use_dynamic_simd = ${USE_DYNAMIC_SIMD} endif diff --git a/configs/glog.conf b/configs/glog.conf index db36f674c1..c2874d892f 100644 --- a/configs/glog.conf +++ b/configs/glog.conf @@ -5,6 +5,11 @@ # `INFO``, ``WARNING``, ``ERROR``, and ``FATAL`` are 0, 1, 2, and 3 --minloglevel=0 --log_dir=/var/lib/milvus/logs/ +# using vlog to implement debug and trace log +# if set vmodule to 5, open debug level +# if set vmodule to 6, open trace level +# default 4, not open debug and trace +--v=4 # MB --max_log_size=200 ---stop_logging_if_full_disk=true \ No newline at end of file +--stop_logging_if_full_disk=true diff --git a/configs/milvus.yaml b/configs/milvus.yaml index b1724c6e8a..73a3c8c9cf 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -288,6 +288,7 @@ queryNode: # This parameter is only useful when enable-disk = true. # And this value should be a number greater than 1 and less than 32. chunkRows: 1024 # The number of vectors in a chunk. + exprEvalBatchSize: 8192 # The batch size for executor get next interimIndex: # build a vector temperate index for growing segment or binlog to accelerate search enableIndex: true nlist: 128 # segment index nlist diff --git a/internal/core/conanfile.py b/internal/core/conanfile.py index 1fb5a78e74..b879426bfd 100644 --- a/internal/core/conanfile.py +++ b/internal/core/conanfile.py @@ -36,7 +36,7 @@ class MilvusConan(ConanFile): "xz_utils/5.4.0", "prometheus-cpp/1.1.0", "re2/20230301", - "folly/2023.10.30.04@milvus/dev", + "folly/2023.10.30.05@milvus/dev", "google-cloud-cpp/2.5.0@milvus/dev", "opentelemetry-cpp/1.8.1.1@milvus/dev", "librdkafka/1.9.1", @@ -44,6 +44,9 @@ class MilvusConan(ConanFile): ) generators = ("cmake", "cmake_find_package") default_options = { + "libevent:shared": True, + "double-conversion:shared": True, + "folly:shared": True, "librdkafka:shared": True, "librdkafka:zstd": True, "librdkafka:ssl": True, diff --git a/internal/core/src/CMakeLists.txt b/internal/core/src/CMakeLists.txt index bdebcedf44..28480eb8a1 100644 --- a/internal/core/src/CMakeLists.txt +++ b/internal/core/src/CMakeLists.txt @@ -32,6 +32,7 @@ add_subdirectory( index ) add_subdirectory( query ) add_subdirectory( segcore ) add_subdirectory( indexbuilder ) +add_subdirectory( exec ) if(USE_DYNAMIC_SIMD) add_subdirectory( simd ) endif() diff --git a/internal/core/src/common/CMakeLists.txt b/internal/core/src/common/CMakeLists.txt index 5072728c20..e4ca81cb5f 100644 --- a/internal/core/src/common/CMakeLists.txt +++ b/internal/core/src/common/CMakeLists.txt @@ -22,6 +22,7 @@ set(COMMON_SRC Tracer.cpp IndexMeta.cpp EasyAssert.cpp + FieldData.cpp ) add_library(milvus_common SHARED ${COMMON_SRC}) diff --git a/internal/core/src/common/Common.cpp b/internal/core/src/common/Common.cpp index 63648e4334..c9bb37bd20 100644 --- a/internal/core/src/common/Common.cpp +++ b/internal/core/src/common/Common.cpp @@ -27,6 +27,7 @@ int64_t MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT = int64_t LOW_PRIORITY_THREAD_CORE_COEFFICIENT = DEFAULT_LOW_PRIORITY_THREAD_CORE_COEFFICIENT; int CPU_NUM = DEFAULT_CPU_NUM; +int64_t EXEC_EVAL_EXPR_BATCH_SIZE = DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE; void SetIndexSliceSize(const int64_t size) { @@ -56,6 +57,13 @@ SetLowPriorityThreadCoreCoefficient(const int64_t coefficient) { << LOW_PRIORITY_THREAD_CORE_COEFFICIENT; } +void +SetDefaultExecEvalExprBatchSize(int64_t val) { + EXEC_EVAL_EXPR_BATCH_SIZE = val; + LOG_SEGCORE_INFO_ << "set default expr eval batch size: " + << EXEC_EVAL_EXPR_BATCH_SIZE; +} + void SetCpuNum(const int num) { CPU_NUM = num; diff --git a/internal/core/src/common/Common.h b/internal/core/src/common/Common.h index c4ba4c0829..784fdfad04 100644 --- a/internal/core/src/common/Common.h +++ b/internal/core/src/common/Common.h @@ -26,6 +26,7 @@ extern int64_t HIGH_PRIORITY_THREAD_CORE_COEFFICIENT; extern int64_t MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT; extern int64_t LOW_PRIORITY_THREAD_CORE_COEFFICIENT; extern int CPU_NUM; +extern int64_t EXEC_EVAL_EXPR_BATCH_SIZE; void SetIndexSliceSize(const int64_t size); @@ -42,4 +43,7 @@ SetLowPriorityThreadCoreCoefficient(const int64_t coefficient); void SetCpuNum(const int core); +void +SetDefaultExecEvalExprBatchSize(int64_t val); + } // namespace milvus diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index ded5ffcdc7..ecf2572291 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -39,6 +39,10 @@ const char INDEX_BUILD_ID_KEY[] = "indexBuildID"; const char INDEX_ROOT_PATH[] = "index_files"; const char RAWDATA_ROOT_PATH[] = "raw_datas"; +const char DEFAULT_PLANNODE_ID[] = "0"; +const char DEAFULT_QUERY_ID[] = "0"; +const char DEFAULT_TASK_ID[] = "0"; + const int64_t DEFAULT_FIELD_MAX_MEMORY_LIMIT = 64 << 20; // bytes const int64_t DEFAULT_HIGH_PRIORITY_THREAD_CORE_COEFFICIENT = 10; const int64_t DEFAULT_MIDDLE_PRIORITY_THREAD_CORE_COEFFICIENT = 5; @@ -48,6 +52,8 @@ const int64_t DEFAULT_INDEX_FILE_SLICE_SIZE = 4 << 20; // bytes const int DEFAULT_CPU_NUM = 1; +const int64_t DEFAULT_EXEC_EVAL_EXPR_BATCH_SIZE = 8192; + constexpr const char* RADIUS = knowhere::meta::RADIUS; constexpr const char* RANGE_FILTER = knowhere::meta::RANGE_FILTER; diff --git a/internal/core/src/common/Exception.h b/internal/core/src/common/Exception.h new file mode 100644 index 0000000000..68941ba567 --- /dev/null +++ b/internal/core/src/common/Exception.h @@ -0,0 +1,218 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace milvus { + +class NotImplementedException : public std::exception { + public: + explicit NotImplementedException(const std::string& msg) + : std::exception(), exception_message_(msg) { + } + const char* + what() const noexcept { + return exception_message_.c_str(); + } + virtual ~NotImplementedException() { + } + + private: + std::string exception_message_; +}; + +class NotSupportedDataTypeException : public std::exception { + public: + explicit NotSupportedDataTypeException(const std::string& msg) + : std::exception(), exception_message_(msg) { + } + const char* + what() const noexcept { + return exception_message_.c_str(); + } + virtual ~NotSupportedDataTypeException() { + } + + private: + std::string exception_message_; +}; + +class UnistdException : public std::runtime_error { + public: + explicit UnistdException(const std::string& msg) : std::runtime_error(msg) { + } + + virtual ~UnistdException() { + } +}; + +// Exceptions for storage module +class LocalChunkManagerException : public std::runtime_error { + public: + explicit LocalChunkManagerException(const std::string& msg) + : std::runtime_error(msg) { + } + virtual ~LocalChunkManagerException() { + } +}; + +class InvalidPathException : public LocalChunkManagerException { + public: + explicit InvalidPathException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~InvalidPathException() { + } +}; + +class OpenFileException : public LocalChunkManagerException { + public: + explicit OpenFileException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~OpenFileException() { + } +}; + +class CreateFileException : public LocalChunkManagerException { + public: + explicit CreateFileException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~CreateFileException() { + } +}; + +class ReadFileException : public LocalChunkManagerException { + public: + explicit ReadFileException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~ReadFileException() { + } +}; + +class WriteFileException : public LocalChunkManagerException { + public: + explicit WriteFileException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~WriteFileException() { + } +}; + +class PathAlreadyExistException : public LocalChunkManagerException { + public: + explicit PathAlreadyExistException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~PathAlreadyExistException() { + } +}; + +class DirNotExistException : public LocalChunkManagerException { + public: + explicit DirNotExistException(const std::string& msg) + : LocalChunkManagerException(msg) { + } + virtual ~DirNotExistException() { + } +}; + +class MinioException : public std::runtime_error { + public: + explicit MinioException(const std::string& msg) : std::runtime_error(msg) { + } + virtual ~MinioException() { + } +}; + +class InvalidBucketNameException : public MinioException { + public: + explicit InvalidBucketNameException(const std::string& msg) + : MinioException(msg) { + } + virtual ~InvalidBucketNameException() { + } +}; + +class ObjectNotExistException : public MinioException { + public: + explicit ObjectNotExistException(const std::string& msg) + : MinioException(msg) { + } + virtual ~ObjectNotExistException() { + } +}; +class S3ErrorException : public MinioException { + public: + explicit S3ErrorException(const std::string& msg) : MinioException(msg) { + } + virtual ~S3ErrorException() { + } +}; + +class DiskANNFileManagerException : public std::runtime_error { + public: + explicit DiskANNFileManagerException(const std::string& msg) + : std::runtime_error(msg) { + } + virtual ~DiskANNFileManagerException() { + } +}; + +class ArrowException : public std::runtime_error { + public: + explicit ArrowException(const std::string& msg) : std::runtime_error(msg) { + } + virtual ~ArrowException() { + } +}; + +// Exceptions for executor module +class ExecDriverException : public std::exception { + public: + explicit ExecDriverException(const std::string& msg) + : std::exception(), exception_message_(msg) { + } + const char* + what() const noexcept { + return exception_message_.c_str(); + } + virtual ~ExecDriverException() { + } + + private: + std::string exception_message_; +}; +class ExecOperatorException : public std::exception { + public: + explicit ExecOperatorException(const std::string& msg) + : std::exception(), exception_message_(msg) { + } + const char* + what() const noexcept { + return exception_message_.c_str(); + } + virtual ~ExecOperatorException() { + } + + private: + std::string exception_message_; +}; +} // namespace milvus diff --git a/internal/core/src/storage/FieldData.cpp b/internal/core/src/common/FieldData.cpp similarity index 84% rename from internal/core/src/storage/FieldData.cpp rename to internal/core/src/common/FieldData.cpp index 53c9878571..f85653e410 100644 --- a/internal/core/src/storage/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -14,15 +14,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "storage/FieldData.h" +#include "common/FieldData.h" + #include "arrow/array/array_binary.h" +#include "common/Array.h" #include "common/EasyAssert.h" +#include "common/Exception.h" +#include "common/FieldDataInterface.h" #include "common/Json.h" #include "simdjson/padded_string.h" -#include "common/Array.h" -#include "FieldDataInterface.h" -namespace milvus::storage { +namespace milvus { template void @@ -183,4 +185,33 @@ template class FieldDataImpl; template class FieldDataImpl; template class FieldDataImpl; -} // namespace milvus::storage +FieldDataPtr +InitScalarFieldData(const DataType& type, int64_t cap_rows) { + switch (type) { + case DataType::BOOL: + return std::make_shared>(type, cap_rows); + case DataType::INT8: + return std::make_shared>(type, cap_rows); + case DataType::INT16: + return std::make_shared>(type, cap_rows); + case DataType::INT32: + return std::make_shared>(type, cap_rows); + case DataType::INT64: + return std::make_shared>(type, cap_rows); + case DataType::FLOAT: + return std::make_shared>(type, cap_rows); + case DataType::DOUBLE: + return std::make_shared>(type, cap_rows); + case DataType::STRING: + case DataType::VARCHAR: + return std::make_shared>(type, cap_rows); + case DataType::JSON: + return std::make_shared>(type, cap_rows); + default: + throw NotSupportedDataTypeException( + "InitScalarFieldData not support data type " + + datatype_name(type)); + } +} + +} // namespace milvus diff --git a/internal/core/src/storage/FieldData.h b/internal/core/src/common/FieldData.h similarity index 88% rename from internal/core/src/storage/FieldData.h rename to internal/core/src/common/FieldData.h index 0a30006ab1..d5f89ab6bc 100644 --- a/internal/core/src/storage/FieldData.h +++ b/internal/core/src/common/FieldData.h @@ -21,10 +21,10 @@ #include -#include "storage/FieldDataInterface.h" +#include "common/FieldDataInterface.h" #include "common/Channel.h" -namespace milvus::storage { +namespace milvus { template class FieldData : public FieldDataImpl { @@ -34,6 +34,11 @@ class FieldData : public FieldDataImpl { : FieldDataImpl::FieldDataImpl( 1, data_type, buffered_num_rows) { } + static_assert(IsScalar || std::is_same_v); + explicit FieldData(DataType data_type, FixedVector&& inner_data) + : FieldDataImpl::FieldDataImpl( + 1, data_type, std::move(inner_data)) { + } }; template <> @@ -106,7 +111,10 @@ class FieldData : public FieldDataImpl { }; using FieldDataPtr = std::shared_ptr; -using FieldDataChannel = Channel; +using FieldDataChannel = Channel; using FieldDataChannelPtr = std::shared_ptr; -} // namespace milvus::storage \ No newline at end of file +FieldDataPtr +InitScalarFieldData(const DataType& type, int64_t cap_rows); + +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/storage/FieldDataInterface.h b/internal/core/src/common/FieldDataInterface.h similarity index 96% rename from internal/core/src/storage/FieldDataInterface.h rename to internal/core/src/common/FieldDataInterface.h index b2a490271f..ffc8045db3 100644 --- a/internal/core/src/storage/FieldDataInterface.h +++ b/internal/core/src/common/FieldDataInterface.h @@ -33,7 +33,7 @@ #include "common/EasyAssert.h" #include "common/Array.h" -namespace milvus::storage { +namespace milvus { using DataType = milvus::DataType; @@ -49,8 +49,8 @@ class FieldDataBase { virtual void FillFieldData(const std::shared_ptr array) = 0; - virtual const void* - Data() const = 0; + virtual void* + Data() = 0; virtual const void* RawValue(ssize_t offset) const = 0; @@ -109,6 +109,12 @@ class FieldDataImpl : public FieldDataBase { field_data_.resize(num_rows_ * dim_); } + explicit FieldDataImpl(size_t dim, DataType type, Chunk&& field_data) + : FieldDataBase(type), dim_(is_scalar ? 1 : dim) { + field_data_ = std::move(field_data); + num_rows_ = field_data.size() / dim; + } + void FillFieldData(const void* source, ssize_t element_count) override; @@ -126,8 +132,8 @@ class FieldDataImpl : public FieldDataBase { return "FieldDataImpl"; } - const void* - Data() const override { + void* + Data() override { return field_data_.data(); } @@ -332,4 +338,4 @@ class FieldDataArrayImpl : public FieldDataImpl { } }; -} // namespace milvus::storage +} // namespace milvus diff --git a/internal/core/src/common/Promise.h b/internal/core/src/common/Promise.h new file mode 100644 index 0000000000..2d919b97b5 --- /dev/null +++ b/internal/core/src/common/Promise.h @@ -0,0 +1,75 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "log/Log.h" + +namespace milvus { + +template +class MilvusPromise : public folly::Promise { + public: + MilvusPromise() : folly::Promise() { + } + + explicit MilvusPromise(const std::string& context) + : folly::Promise(), context_(context) { + } + + MilvusPromise(folly::futures::detail::EmptyConstruct, + const std::string& context) noexcept + : folly::Promise(folly::Promise::makeEmpty()), context_(context) { + } + + ~MilvusPromise() { + if (!this->isFulfilled()) { + LOG_SEGCORE_WARNING_ + << "PROMISE: Unfulfilled promise is being deleted. Context: " + << context_; + } + } + + explicit MilvusPromise(MilvusPromise&& other) + : folly::Promise(std::move(other)), + context_(std::move(other.context_)) { + } + + MilvusPromise& + operator=(MilvusPromise&& other) noexcept { + folly::Promise::operator=(std::move(other)); + context_ = std::move(other.context_); + return *this; + } + + static MilvusPromise + MakeEmpty(const std::string& context = "") noexcept { + return MilvusPromise(folly::futures::detail::EmptyConstruct{}, + context); + } + + private: + /// Optional parameter to understand where this promise was created. + std::string context_; +}; + +using ContinuePromise = MilvusPromise; +using ContinueFuture = folly::SemiFuture; + +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 2db86a0390..78bdc498ad 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -95,6 +95,10 @@ enum class DataType { ARRAY = 22, JSON = 23, + // Some special Data type, start from after 50 + // just for internal use now, may sync proto in future + ROW = 50, + VECTOR_BINARY = 100, VECTOR_FLOAT = 101, VECTOR_FLOAT16 = 102, @@ -182,8 +186,138 @@ using MayConstRef = std::conditional_t || const T&, T>; static_assert(std::is_same_v>); + +template +struct TypeTraits {}; + +template <> +struct TypeTraits { + static constexpr const char* Name = "NONE"; +}; +template <> +struct TypeTraits { + using NativeType = bool; + static constexpr DataType TypeKind = DataType::BOOL; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "BOOL"; +}; + +template <> +struct TypeTraits { + using NativeType = int8_t; + static constexpr DataType TypeKind = DataType::INT8; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "INT8"; +}; + +template <> +struct TypeTraits { + using NativeType = int16_t; + static constexpr DataType TypeKind = DataType::INT16; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "INT16"; +}; + +template <> +struct TypeTraits { + using NativeType = int32_t; + static constexpr DataType TypeKind = DataType::INT32; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "INT32"; +}; + +template <> +struct TypeTraits { + using NativeType = int32_t; + static constexpr DataType TypeKind = DataType::INT64; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "INT64"; +}; + +template <> +struct TypeTraits { + using NativeType = float; + static constexpr DataType TypeKind = DataType::FLOAT; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "FLOAT"; +}; + +template <> +struct TypeTraits { + using NativeType = double; + static constexpr DataType TypeKind = DataType::DOUBLE; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = true; + static constexpr const char* Name = "DOUBLE"; +}; + +template <> +struct TypeTraits { + using NativeType = std::string; + static constexpr DataType TypeKind = DataType::VARCHAR; + static constexpr bool IsPrimitiveType = true; + static constexpr bool IsFixedWidth = false; + static constexpr const char* Name = "VARCHAR"; +}; + +template <> +struct TypeTraits : public TypeTraits { + static constexpr DataType TypeKind = DataType::STRING; + static constexpr const char* Name = "STRING"; +}; + +template <> +struct TypeTraits { + using NativeType = void; + static constexpr DataType TypeKind = DataType::ARRAY; + static constexpr bool IsPrimitiveType = false; + static constexpr bool IsFixedWidth = false; + static constexpr const char* Name = "ARRAY"; +}; + +template <> +struct TypeTraits { + using NativeType = void; + static constexpr DataType TypeKind = DataType::JSON; + static constexpr bool IsPrimitiveType = false; + static constexpr bool IsFixedWidth = false; + static constexpr const char* Name = "JSON"; +}; + +template <> +struct TypeTraits { + using NativeType = void; + static constexpr DataType TypeKind = DataType::ROW; + static constexpr bool IsPrimitiveType = false; + static constexpr bool IsFixedWidth = false; + static constexpr const char* Name = "ROW"; +}; + +template <> +struct TypeTraits { + using NativeType = uint8_t; + static constexpr DataType TypeKind = DataType::VECTOR_BINARY; + static constexpr bool IsPrimitiveType = false; + static constexpr bool IsFixedWidth = false; + static constexpr const char* Name = "VECTOR_BINARY"; +}; + +template <> +struct TypeTraits { + using NativeType = float; + static constexpr DataType TypeKind = DataType::VECTOR_FLOAT; + static constexpr bool IsPrimitiveType = false; + static constexpr bool IsFixedWidth = false; + static constexpr const char* Name = "VECTOR_FLOAT"; +}; + } // namespace milvus - // template <> struct fmt::formatter : formatter { auto @@ -226,6 +360,9 @@ struct fmt::formatter : formatter { case milvus::DataType::JSON: name = "JSON"; break; + case milvus::DataType::ROW: + name = "ROW"; + break; case milvus::DataType::VECTOR_BINARY: name = "VECTOR_BINARY"; break; diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index a0166bd2df..886e5bb9b2 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -192,4 +192,17 @@ is_in_disk_list(const IndexType& index_type) { return is_in_list(index_type, DISK_INDEX_LIST); } +template +std::string +Join(const std::vector& items, const std::string& delimiter) { + std::stringstream ss; + for (size_t i = 0; i < items.size(); ++i) { + if (i > 0) { + ss << delimiter; + } + ss << items[i]; + } + return ss.str(); +} + } // namespace milvus diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h new file mode 100644 index 0000000000..0abaf9af60 --- /dev/null +++ b/internal/core/src/common/Vector.h @@ -0,0 +1,141 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "common/FieldData.h" + +namespace milvus { + +/** + * @brief base class for different type vector + * @todo implement full null value support + */ + +class BaseVector { + public: + BaseVector(DataType data_type, + size_t length, + std::optional null_count = std::nullopt) + : type_kind_(data_type), length_(length), null_count_(null_count) { + } + virtual ~BaseVector() = default; + + int64_t + size() { + return length_; + } + + DataType + type() { + return type_kind_; + } + + protected: + DataType type_kind_; + size_t length_; + std::optional null_count_; +}; + +using VectorPtr = std::shared_ptr; + +/** + * @brief Single vector for scalar types + * @todo using memory pool && buffer replace FieldData + */ +class ColumnVector final : public BaseVector { + public: + ColumnVector(DataType data_type, + size_t length, + std::optional null_count = std::nullopt) + : BaseVector(data_type, length, null_count) { + values_ = InitScalarFieldData(data_type, length); + } + + ColumnVector(FixedVector&& data) + : BaseVector(DataType::BOOL, data.size()) { + values_ = + std::make_shared>(DataType::BOOL, std::move(data)); + } + + virtual ~ColumnVector() override { + values_.reset(); + } + + void* + GetRawData() { + return values_->Data(); + } + + template + const As* + RawAsValues() const { + return reinterpret_cast(values_->Data()); + } + + private: + FieldDataPtr values_; +}; + +using ColumnVectorPtr = std::shared_ptr; + +/** + * @brief Multi vectors for scalar types + * mainly using it to pass internal result in segcore scalar engine system + */ +class RowVector : public BaseVector { + public: + RowVector(std::vector& data_types, + size_t length, + std::optional null_count = std::nullopt) + : BaseVector(DataType::ROW, length, null_count) { + for (auto& type : data_types) { + children_values_.emplace_back( + std::make_shared(type, length)); + } + } + + RowVector(const std::vector& children) + : BaseVector(DataType::ROW, 0) { + for (auto& child : children) { + children_values_.push_back(child); + if (child->size() > length_) { + length_ = child->size(); + } + } + } + + const std::vector& + childrens() { + return children_values_; + } + + VectorPtr + child(int index) { + assert(index < children_values_.size()); + return children_values_[index]; + } + + private: + std::vector children_values_; +}; + +using RowVectorPtr = std::shared_ptr; + +} // namespace milvus diff --git a/internal/core/src/common/init_c.cpp b/internal/core/src/common/init_c.cpp index 0f166d9422..91b32ca695 100644 --- a/internal/core/src/common/init_c.cpp +++ b/internal/core/src/common/init_c.cpp @@ -25,7 +25,7 @@ #include "common/Tracer.h" #include "log/Log.h" -std::once_flag flag1, flag2, flag3, flag4, flag5; +std::once_flag flag1, flag2, flag3, flag4, flag5, flag6; std::once_flag traceFlag; void @@ -70,6 +70,14 @@ InitCpuNum(const int value) { flag3, [](int value) { milvus::SetCpuNum(value); }, value); } +void +InitDefaultExprEvalBatchSize(int64_t val) { + std::call_once( + flag6, + [](int val) { milvus::SetDefaultExecEvalExprBatchSize(val); }, + val); +} + void InitTrace(CTraceConfig* config) { auto traceConfig = milvus::tracer::TraceConfig{config->exporter, diff --git a/internal/core/src/common/init_c.h b/internal/core/src/common/init_c.h index cc1e17cb28..ab872d4eef 100644 --- a/internal/core/src/common/init_c.h +++ b/internal/core/src/common/init_c.h @@ -36,6 +36,9 @@ InitMiddlePriorityThreadCoreCoefficient(const int64_t); void InitLowPriorityThreadCoreCoefficient(const int64_t); +void +InitDefaultExprEvalBatchSize(int64_t val); + void InitCpuNum(const int); diff --git a/internal/core/src/exec/CMakeLists.txt b/internal/core/src/exec/CMakeLists.txt new file mode 100644 index 0000000000..1573cf3e57 --- /dev/null +++ b/internal/core/src/exec/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License + +set(MILVUS_EXEC_SRCS + expression/Expr.cpp + expression/UnaryExpr.cpp + expression/ConjunctExpr.cpp + expression/LogicalUnaryExpr.cpp + expression/LogicalBinaryExpr.cpp + expression/TermExpr.cpp + expression/BinaryArithOpEvalRangeExpr.cpp + expression/BinaryRangeExpr.cpp + expression/AlwaysTrueExpr.cpp + expression/CompareExpr.cpp + expression/JsonContainsExpr.cpp + expression/ExistsExpr.cpp + operator/FilterBits.cpp + operator/Operator.cpp + Driver.cpp + Task.cpp + ) + +add_library(milvus_exec STATIC ${MILVUS_EXEC_SRCS}) +if(USE_DYNAMIC_SIMD) + target_link_libraries(milvus_exec milvus_common milvus_simd milvus-storage ${CONAN_LIBS}) +else() + target_link_libraries(milvus_exec milvus_common milvus-storage ${CONAN_LIBS}) +endif() diff --git a/internal/core/src/exec/Driver.cpp b/internal/core/src/exec/Driver.cpp new file mode 100644 index 0000000000..126df3cc20 --- /dev/null +++ b/internal/core/src/exec/Driver.cpp @@ -0,0 +1,355 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Driver.h" + +#include +#include + +#include "exec/operator/CallbackSink.h" +#include "exec/operator/FilterBits.h" +#include "exec/operator/Operator.h" +#include "exec/Task.h" + +#include "common/EasyAssert.h" + +namespace milvus { +namespace exec { + +std::atomic_uint64_t BlockingState::num_blocked_drivers_{0}; + +std::shared_ptr +DriverContext::GetQueryConfig() { + return task_->query_context()->query_config(); +} + +std::shared_ptr +DriverFactory::CreateDriver(std::unique_ptr ctx, + std::function num_drivers) { + auto driver = std::shared_ptr(new Driver()); + ctx->driver_ = driver.get(); + std::vector> operators; + operators.reserve(plannodes_.size()); + + for (size_t i = 0; i < plannodes_.size(); ++i) { + auto id = operators.size(); + auto plannode = plannodes_[i]; + if (auto filternode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back( + std::make_unique(id, ctx.get(), filternode)); + } + // TODO: add more operators + } + + if (consumer_supplier_) { + operators.push_back(consumer_supplier_(operators.size(), ctx.get())); + } + + driver->Init(std::move(ctx), std::move(operators)); + + return driver; +} + +void +Driver::Enqueue(std::shared_ptr driver) { + if (driver->closed_) { + return; + } + + driver->get_task()->query_context()->executor()->add( + [driver]() { Driver::Run(driver); }); +} + +void +Driver::Run(std::shared_ptr self) { + std::shared_ptr blocking_state; + RowVectorPtr result; + auto reason = self->RunInternal(self, blocking_state, result); + + AssertInfo(result == nullptr, + "The last operator (sink) must not produce any results."); + + if (reason == StopReason::kBlock) { + return; + } + + switch (reason) { + case StopReason::kBlock: + BlockingState::SetResume(blocking_state); + return; + case StopReason::kYield: + Enqueue(self); + case StopReason::kPause: + case StopReason::kTerminate: + case StopReason::kAlreadyTerminated: + case StopReason::kAtEnd: + return; + default: + AssertInfo(false, "Unhandled stop reason"); + } +} + +void +Driver::Init(std::unique_ptr ctx, + std::vector> operators) { + assert(ctx != nullptr); + ctx_ = std::move(ctx); + AssertInfo(operators.size() != 0, "operators in driver must not empty"); + operators_ = std::move(operators); + current_operator_index_ = operators_.size() - 1; +} + +void +Driver::Close() { + if (closed_) { + return; + } + + for (auto& op : operators_) { + op->Close(); + } + + closed_ = true; + + Task::RemoveDriver(ctx_->task_, this); +} + +RowVectorPtr +Driver::Next(std::shared_ptr& blocking_state) { + auto self = shared_from_this(); + + RowVectorPtr result; + auto stop = RunInternal(self, blocking_state, result); + + Assert(stop == StopReason::kBlock || stop == StopReason::kAtEnd || + stop == StopReason::kAlreadyTerminated); + return result; +} + +#define CALL_OPERATOR(call_func, operator, method_name) \ + try { \ + call_func; \ + } catch (SegcoreError & e) { \ + auto err_msg = fmt::format( \ + "Operator::{} failed for [Operator:{}, plan node id: " \ + "{}] : {}", \ + method_name, \ + operator->get_operator_type(), \ + operator->get_plannode_id(), \ + e.what()); \ + LOG_SEGCORE_ERROR_ << err_msg; \ + throw ExecOperatorException(err_msg); \ + } catch (std::exception & e) { \ + throw ExecOperatorException( \ + fmt::format("Operator::{} failed for [Operator:{}, plan node id: " \ + "{}] : {}", \ + method_name, \ + operator->get_operator_type(), \ + operator->get_plannode_id(), \ + e.what())); \ + } + +StopReason +Driver::RunInternal(std::shared_ptr& self, + std::shared_ptr& blocking_state, + RowVectorPtr& result) { + try { + int num_operators = operators_.size(); + ContinueFuture future; + + for (;;) { + for (int32_t i = num_operators - 1; i >= 0; --i) { + auto op = operators_[i].get(); + + current_operator_index_ = i; + CALL_OPERATOR( + blocking_reason_ = op->IsBlocked(&future), op, "IsBlocked"); + if (blocking_reason_ != BlockingReason::kNotBlocked) { + blocking_state = std::make_shared( + self, std::move(future), op, blocking_reason_); + return StopReason::kBlock; + } + Operator* next_op = nullptr; + + if (i < operators_.size() - 1) { + next_op = operators_[i + 1].get(); + CALL_OPERATOR( + blocking_reason_ = next_op->IsBlocked(&future), + next_op, + "IsBlocked"); + if (blocking_reason_ != BlockingReason::kNotBlocked) { + blocking_state = std::make_shared( + self, std::move(future), next_op, blocking_reason_); + return StopReason::kBlock; + } + + bool needs_input; + CALL_OPERATOR(needs_input = next_op->NeedInput(), + next_op, + "NeedInput"); + if (needs_input) { + RowVectorPtr result; + { + CALL_OPERATOR( + result = op->GetOutput(), op, "GetOutput"); + if (result) { + AssertInfo( + result->size() > 0, + fmt::format( + "GetOutput must return nullptr or " + "a non-empty vector: {}", + op->get_operator_type())); + } + } + if (result) { + CALL_OPERATOR( + next_op->AddInput(result), next_op, "AddInput"); + i += 2; + continue; + } else { + CALL_OPERATOR( + blocking_reason_ = op->IsBlocked(&future), + op, + "IsBlocked"); + if (blocking_reason_ != + BlockingReason::kNotBlocked) { + blocking_state = + std::make_shared( + self, + std::move(future), + next_op, + blocking_reason_); + return StopReason::kBlock; + } + if (op->IsFinished()) { + CALL_OPERATOR(next_op->NoMoreInput(), + next_op, + "NoMoreInput"); + break; + } + } + } + } else { + { + CALL_OPERATOR( + result = op->GetOutput(), op, "GetOutput"); + if (result) { + AssertInfo( + result->size() > 0, + fmt::format("GetOutput must return nullptr or " + "a non-empty vector: {}", + op->get_operator_type())); + blocking_reason_ = BlockingReason::kWaitForConsumer; + return StopReason::kBlock; + } + } + if (op->IsFinished()) { + Close(); + return StopReason::kAtEnd; + } + continue; + } + } + } + } catch (std::exception& e) { + get_task()->SetError(std::current_exception()); + return StopReason::kAlreadyTerminated; + } +} + +static bool +MustStartNewPipeline(std::shared_ptr plannode, + int source_id) { + //TODO: support LocalMerge and other shuffle + return source_id != 0; +} + +OperatorSupplier +MakeConsumerSupplier(ConsumerSupplier supplier) { + if (supplier) { + return [supplier](int32_t operator_id, DriverContext* ctx) { + return std::make_unique(operator_id, ctx, supplier()); + }; + } + return nullptr; +} + +uint32_t +MaxDrivers(const DriverFactory* factory, const QueryConfig& config) { + return 1; +} + +static void +SplitPlan(const std::shared_ptr& plannode, + std::vector>* current_plannodes, + const std::shared_ptr& consumer_node, + OperatorSupplier operator_supplier, + std::vector>* driver_factories) { + if (!current_plannodes) { + driver_factories->push_back(std::make_unique()); + current_plannodes = &driver_factories->back()->plannodes_; + driver_factories->back()->consumer_supplier_ = operator_supplier; + driver_factories->back()->consumer_node_ = consumer_node; + } + + auto sources = plannode->sources(); + if (sources.empty()) { + driver_factories->back()->is_input_driver_ = true; + } else { + for (int i = 0; i < sources.size(); ++i) { + SplitPlan( + sources[i], + MustStartNewPipeline(plannode, i) ? nullptr : current_plannodes, + plannode, + nullptr, + driver_factories); + } + } + current_plannodes->push_back(plannode); +} + +void +LocalPlanner::Plan( + const plan::PlanFragment& fragment, + ConsumerSupplier consumer_supplier, + std::vector>* driver_factories, + const QueryConfig& config, + uint32_t max_drivers) { + SplitPlan(fragment.plan_node_, + nullptr, + nullptr, + MakeConsumerSupplier(consumer_supplier), + driver_factories); + + (*driver_factories)[0]->is_output_driver_ = true; + + for (auto& factory : *driver_factories) { + factory->max_drivers_ = MaxDrivers(factory.get(), config); + factory->num_drivers_ = std::min(factory->max_drivers_, max_drivers); + + if (factory->is_group_execution_) { + factory->num_total_drivers_ = + factory->num_drivers_ * fragment.num_splitgroups_; + } else { + factory->num_total_drivers_ = factory->num_drivers_; + } + } +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/Driver.h b/internal/core/src/exec/Driver.h new file mode 100644 index 0000000000..12d1228a7c --- /dev/null +++ b/internal/core/src/exec/Driver.h @@ -0,0 +1,254 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "common/Types.h" +#include "common/Promise.h" +#include "exec/QueryContext.h" +#include "plan/PlanNode.h" + +namespace milvus { +namespace exec { + +enum class StopReason { + // Keep running. + kNone, + // Go off thread and do not schedule more activity. + kPause, + // Stop and free all. This is returned once and the thread that gets + // this value is responsible for freeing the state associated with + // the thread. Other threads will get kAlreadyTerminated after the + // first thread has received kTerminate. + kTerminate, + kAlreadyTerminated, + // Go off thread and then enqueue to the back of the runnable queue. + kYield, + // Must wait for external events. + kBlock, + // No more data to produce. + kAtEnd, + kAlreadyOnThread +}; + +enum class BlockingReason { + kNotBlocked, + kWaitForConsumer, + kWaitForSplit, + kWaitForExchange, + kWaitForJoinBuild, + /// For a build operator, it is blocked waiting for the probe operators to + /// finish probing before build the next hash table from one of the previously + /// spilled partition data. + /// For a probe operator, it is blocked waiting for all its peer probe + /// operators to finish probing before notifying the build operators to build + /// the next hash table from the previously spilled data. + kWaitForJoinProbe, + kWaitForMemory, + kWaitForConnector, + /// Build operator is blocked waiting for all its peers to stop to run group + /// spill on all of them. + kWaitForSpill, +}; + +class Driver; +class Operator; +class Task; +class BlockingState { + public: + BlockingState(std::shared_ptr driver, + ContinueFuture&& future, + Operator* op, + BlockingReason reason) + : driver_(std::move(driver_)), + future_(std::move(future)), + operator_(op), + reason_(reason) { + num_blocked_drivers_++; + } + + ~BlockingState() { + num_blocked_drivers_--; + } + + static void + SetResume(std::shared_ptr state) { + } + + Operator* + op() { + return operator_; + } + + BlockingReason + reason() { + return reason_; + } + + // Moves out the blocking future stored inside. Can be called only once. Used + // in single-threaded execution. + ContinueFuture + future() { + return std::move(future_); + } + + // Returns total number of drivers process wide that are currently in blocked + // state. + static uint64_t + get_num_blocked_drivers() { + return num_blocked_drivers_; + } + + private: + std::shared_ptr driver_; + ContinueFuture future_; + Operator* operator_; + BlockingReason reason_; + + static std::atomic_uint64_t num_blocked_drivers_; +}; + +struct DriverContext { + int driverid_; + int pipelineid_; + uint32_t split_groupid_; + uint32_t partitionid_; + + std::shared_ptr task_; + Driver* driver_; + + explicit DriverContext(std::shared_ptr task, + int driverid, + int pipilineid, + uint32_t split_group_id, + uint32_t partition_id) + : driverid_(driverid), + pipelineid_(pipilineid), + split_groupid_(split_group_id), + partitionid_(partition_id), + task_(task) { + } + + std::shared_ptr + GetQueryConfig(); +}; +using OperatorSupplier = std::function( + int32_t operatorid, DriverContext* ctx)>; + +struct DriverFactory { + std::vector> plannodes_; + OperatorSupplier consumer_supplier_; + // The (local) node that will consume results supplied by this pipeline. + // Can be null. We use that to determine the max drivers. + std::shared_ptr consumer_node_; + uint32_t max_drivers_; + uint32_t num_drivers_; + uint32_t num_total_drivers_; + + bool is_group_execution_; + bool is_input_driver_; + bool is_output_driver_; + + std::shared_ptr + CreateDriver(std::unique_ptr ctx, + // TODO: support exchange function + // std::shared_ptr exchange_client, + std::function num_driver); + + // TODO: support ditribution compute + bool + SupportSingleThreadExecution() const { + return true; + } +}; + +class Driver : public std::enable_shared_from_this { + public: + static void + Enqueue(std::shared_ptr instance); + + RowVectorPtr + Next(std::shared_ptr& blocking_state); + + DriverContext* + get_driver_context() const { + return ctx_.get(); + } + + const std::shared_ptr& + get_task() const { + return ctx_->task_; + } + + BlockingReason + GetBlockingReason() const { + return blocking_reason_; + } + + void + Init(std::unique_ptr driver_ctx, + std::vector> operators); + + private: + Driver() = default; + + void + EnqueueInternal() { + } + + static void + Run(std::shared_ptr self); + + StopReason + RunInternal(std::shared_ptr& self, + std::shared_ptr& blocking_state, + RowVectorPtr& result); + + void + Close(); + + std::unique_ptr ctx_; + + std::atomic_bool closed_{false}; + + std::vector> operators_; + + size_t current_operator_index_{0}; + + BlockingReason blocking_reason_{BlockingReason::kNotBlocked}; + + friend struct DriverFactory; +}; + +using Consumer = std::function; +using ConsumerSupplier = std::function; +class LocalPlanner { + public: + static void + Plan(const plan::PlanFragment& fragment, + ConsumerSupplier consumer_supplier, + std::vector>* driver_factories, + const QueryConfig& config, + uint32_t max_drivers); +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/QueryContext.h b/internal/core/src/exec/QueryContext.h new file mode 100644 index 0000000000..ad628ecd1c --- /dev/null +++ b/internal/core/src/exec/QueryContext.h @@ -0,0 +1,257 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "common/Common.h" +#include "common/Types.h" +#include "common/Exception.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +enum class ContextScope { GLOBAL = 0, SESSION = 1, QUERY = 2, Executor = 3 }; + +class BaseConfig { + public: + virtual folly::Optional + Get(const std::string& key) const = 0; + + template + folly::Optional + Get(const std::string& key) const { + auto val = Get(key); + if (val.hasValue()) { + return folly::to(val.value()); + } else { + return folly::none; + } + } + + template + T + Get(const std::string& key, const T& default_value) const { + auto val = Get(key); + if (val.hasValue()) { + return folly::to(val.value()); + } else { + return default_value; + } + } + + virtual bool + IsValueExists(const std::string& key) const = 0; + + virtual const std::unordered_map& + values() const { + throw NotImplementedException("method values() is not supported"); + } + + virtual ~BaseConfig() = default; +}; + +class MemConfig : public BaseConfig { + public: + explicit MemConfig( + const std::unordered_map& values) + : values_(values) { + } + + explicit MemConfig() : values_{} { + } + + explicit MemConfig(std::unordered_map&& values) + : values_(std::move(values)) { + } + + folly::Optional + Get(const std::string& key) const override { + folly::Optional val; + auto it = values_.find(key); + if (it != values_.end()) { + val = it->second; + } + return val; + } + + bool + IsValueExists(const std::string& key) const override { + return values_.find(key) != values_.end(); + } + + const std::unordered_map& + values() const override { + return values_; + } + + private: + std::unordered_map values_; +}; + +class QueryConfig : public MemConfig { + public: + // Whether to use the simplified expression evaluation path. False by default. + static constexpr const char* kExprEvalSimplified = + "expression.eval_simplified"; + + static constexpr const char* kExprEvalBatchSize = + "expression.eval_batch_size"; + + QueryConfig(const std::unordered_map& values) + : MemConfig(values) { + } + + QueryConfig() = default; + + bool + get_expr_eval_simplified() const { + return BaseConfig::Get(kExprEvalSimplified, false); + } + + int64_t + get_expr_batch_size() const { + return BaseConfig::Get(kExprEvalBatchSize, + EXEC_EVAL_EXPR_BATCH_SIZE); + } +}; + +class Context { + public: + explicit Context(ContextScope scope, + const std::shared_ptr parent = nullptr) + : scope_(scope), parent_(parent) { + } + + ContextScope + scope() const { + return scope_; + } + + std::shared_ptr + parent() const { + return parent_; + } + // // TODO: support dynamic update + // void + // set_config(const std::shared_ptr& config) { + // std::atomic_exchange(&config_, config); + // } + + // std::shared_ptr + // get_config() { + // return config_; + // } + + private: + ContextScope scope_; + std::shared_ptr parent_; + //std::shared_ptr config_; +}; + +class QueryContext : public Context { + public: + QueryContext(const std::string& query_id, + const milvus::segcore::SegmentInternalInterface* segment, + milvus::Timestamp timestamp, + std::shared_ptr query_config = + std::make_shared(), + folly::Executor* executor = nullptr, + std::unordered_map> + connector_configs = {}) + : Context(ContextScope::QUERY), + query_id_(query_id), + segment_(segment), + query_timestamp_(timestamp), + query_config_(query_config), + executor_(executor) { + } + + folly::Executor* + executor() const { + return executor_; + } + + const std::unordered_map>& + connector_configs() const { + return connector_configs_; + } + + std::shared_ptr + query_config() const { + return query_config_; + } + + std::string + query_id() const { + return query_id_; + } + + const milvus::segcore::SegmentInternalInterface* + get_segment() { + return segment_; + } + + milvus::Timestamp + get_query_timestamp() { + return query_timestamp_; + } + + private: + folly::Executor* executor_; + //folly::Executor::KeepAlive<> executor_keepalive_; + std::unordered_map> connector_configs_; + std::shared_ptr query_config_; + std::string query_id_; + + // current segment that query execute in + const milvus::segcore::SegmentInternalInterface* segment_; + // timestamp this query generate + milvus::Timestamp query_timestamp_; +}; + +// Represent the state of one thread of query execution. +// TODO: add more class member such as memory pool +class ExecContext : public Context { + public: + ExecContext(QueryContext* query_context) + : Context(ContextScope::Executor), query_context_(query_context) { + } + + QueryContext* + get_query_context() const { + return query_context_; + } + + std::shared_ptr + get_query_config() const { + return query_context_->query_config(); + } + + private: + QueryContext* query_context_; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/Task.cpp b/internal/core/src/exec/Task.cpp new file mode 100644 index 0000000000..2b643c71e8 --- /dev/null +++ b/internal/core/src/exec/Task.cpp @@ -0,0 +1,230 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Task.h" + +#include +#include +#include + +namespace milvus { +namespace exec { + +// Special group id to reflect the ungrouped execution. +constexpr uint32_t kUngroupedGroupId{std::numeric_limits::max()}; + +std::string +MakeUuid() { + return boost::lexical_cast(boost::uuids::random_generator()()); +} + +std::shared_ptr +Task::Create(const std::string& task_id, + plan::PlanFragment plan_fragment, + int destination, + std::shared_ptr query_context, + Consumer consumer, + std::function on_error) { + return Task::Create(task_id, + std::move(plan_fragment), + destination, + std::move(query_context), + (consumer ? [c = std::move(consumer)]() { return c; } + : ConsumerSupplier{}), + std::move(on_error)); +} + +std::shared_ptr +Task::Create(const std::string& task_id, + const plan::PlanFragment& plan_fragment, + int destination, + std::shared_ptr query_ctx, + ConsumerSupplier supplier, + std::function on_error) { + return std::shared_ptr(new Task(task_id, + std::move(plan_fragment), + destination, + std::move(query_ctx), + std::move(supplier), + std::move(on_error))); +} + +void +Task::SetError(const std::exception_ptr& exception) { + { + std::lock_guard l(mutex_); + if (!IsRunningLocked()) { + return; + } + + if (exception_ != nullptr) { + return; + } + exception_ = exception; + } + + Terminate(TaskState::kFailed); + + if (on_error_) { + on_error_(exception_); + } +} + +void +Task::SetError(const std::string& message) { + try { + throw std::runtime_error(message); + } catch (const std::runtime_error& e) { + SetError(std::current_exception()); + } +} + +void +Task::CreateDriversLocked(std::shared_ptr& self, + uint32_t split_group_id, + std::vector>& out) { + const bool is_group_execution_drivers = + (split_group_id != kUngroupedGroupId); + const auto num_pipelines = driver_factories_.size(); + + for (auto pipeline = 0; pipeline < num_pipelines; ++pipeline) { + auto& factory = driver_factories_[pipeline]; + + if (factory->is_group_execution_ != is_group_execution_drivers) { + continue; + } + + const uint32_t driverid_offset = + factory->num_drivers_ * + (is_group_execution_drivers ? split_group_id : 0); + + for (uint32_t partition_id = 0; partition_id < factory->num_drivers_; + ++partition_id) { + out.emplace_back(factory->CreateDriver( + std::make_unique(self, + driverid_offset + partition_id, + pipeline, + split_group_id, + partition_id), + [self](size_t i) { + return i < self->driver_factories_.size() + ? self->driver_factories_[i]->num_total_drivers_ + : 0; + })); + } + } +} + +RowVectorPtr +Task::Next(ContinueFuture* future) { + // NOTE: Task::Next is single-threaded execution + AssertInfo(plan_fragment_.execution_strategy_ == + plan::ExecutionStrategy::kUngrouped, + "Single-threaded execution supports only ungrouped execution"); + + AssertInfo(state_ == TaskState::kRunning, + "Task has already finished processing."); + + if (driver_factories_.empty()) { + AssertInfo( + consumer_supplier_ == nullptr, + "Single-threaded execution doesn't support delivering results to a " + "callback"); + + LocalPlanner::Plan(plan_fragment_, + nullptr, + &driver_factories_, + *query_context_->query_config(), + 1); + + for (const auto& factory : driver_factories_) { + assert(factory->SupportSingleThreadExecution()); + num_ungrouped_drivers_ += factory->num_drivers_; + num_total_drivers_ += factory->num_total_drivers_; + } + + auto self = shared_from_this(); + std::vector> drivers; + + drivers.reserve(num_ungrouped_drivers_); + CreateDriversLocked(self, kUngroupedGroupId, drivers); + + drivers_ = std::move(drivers); + } + + const auto num_drivers = drivers_.size(); + + std::vector futures; + futures.resize(num_drivers); + + for (;;) { + int runnable_drivers = 0; + int blocked_drivers = 0; + + for (auto i = 0; i < num_drivers; ++i) { + if (drivers_[i] == nullptr) { + continue; + } + + if (!futures[i].isReady()) { + ++blocked_drivers; + continue; + } + + ++runnable_drivers; + + std::shared_ptr blocking_state; + + auto result = drivers_[i]->Next(blocking_state); + + if (result) { + return result; + } + + if (blocking_state) { + futures[i] = blocking_state->future(); + } + + if (error()) { + std::rethrow_exception(error()); + } + } + + if (runnable_drivers == 0) { + if (blocked_drivers > 0) { + if (!future) { + throw ExecDriverException( + "Cannot make progress as all remaining drivers are " + "blocked and user are not expected to wait."); + } else { + std::vector not_ready_futures; + for (auto& continue_future : futures) { + if (!continue_future.isReady()) { + not_ready_futures.emplace_back( + std::move(continue_future)); + } + } + *future = + folly::collectAll(std::move(not_ready_futures)).unit(); + } + } + return nullptr; + } + } +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/Task.h b/internal/core/src/exec/Task.h new file mode 100644 index 0000000000..77396e01a4 --- /dev/null +++ b/internal/core/src/exec/Task.h @@ -0,0 +1,205 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "common/Types.h" +#include "exec/Driver.h" +#include "exec/QueryContext.h" +#include "plan/PlanNode.h" + +namespace milvus { +namespace exec { + +enum class TaskState { kRunning, kFinished, kCanceled, kAborted, kFailed }; + +std::string +MakeUuid(); +class Task : public std::enable_shared_from_this { + public: + static std::shared_ptr + Create(const std::string& task_id, + plan::PlanFragment plan_fragment, + int destination, + std::shared_ptr query_context, + Consumer consumer = nullptr, + std::function on_error = nullptr); + + static std::shared_ptr + Create(const std::string& task_id, + const plan::PlanFragment& plan_fragment, + int destination, + std::shared_ptr query_ctx, + ConsumerSupplier supplier, + std::function on_error = nullptr); + + Task(const std::string& task_id, + plan::PlanFragment plan_fragment, + int destination, + std::shared_ptr query_ctx, + ConsumerSupplier consumer_supplier, + std::function on_error) + : uuid_{MakeUuid()}, + taskid_(task_id), + plan_fragment_(std::move(plan_fragment)), + destination_(destination), + query_context_(std::move(query_ctx)), + consumer_supplier_(std::move(consumer_supplier)), + on_error_(on_error) { + } + + ~Task() { + } + + const std::string& + uuid() const { + return uuid_; + } + + const std::string& + taskid() const { + return taskid_; + } + + const int + destination() const { + return destination_; + } + + const std::shared_ptr& + query_context() const { + return query_context_; + } + + static void + Start(std::shared_ptr self, + uint32_t max_drivers, + uint32_t concurrent_split_groups = 1); + + static void + RemoveDriver(std::shared_ptr self, Driver* instance) { + std::lock_guard lock(self->mutex_); + for (auto& driver_ptr : self->drivers_) { + if (driver_ptr.get() != instance) { + continue; + } + driver_ptr = nullptr; + self->DriverClosedLocked(); + } + } + + bool + SupportsSingleThreadedExecution() const { + if (consumer_supplier_) { + return false; + } + } + + RowVectorPtr + Next(ContinueFuture* future = nullptr); + + void + CreateDriversLocked(std::shared_ptr& self, + uint32_t split_groupid, + std::vector>& out); + + void + SetError(const std::exception_ptr& exception); + + void + SetError(const std::string& message); + + bool + IsRunning() const { + std::lock_guard l(mutex_); + return (state_ == TaskState::kRunning); + } + + bool + IsFinished() const { + std::lock_guard l(mutex_); + return (state_ == TaskState::kFinished); + } + + bool + IsRunningLocked() const { + return (state_ == TaskState::kRunning); + } + + bool + IsFinishedLocked() const { + return (state_ == TaskState::kFinished); + } + + void + Terminate(TaskState state) { + } + + std::exception_ptr + error() const { + std::lock_guard l(mutex_); + return exception_; + } + + void + DriverClosedLocked() { + if (IsRunningLocked()) { + --num_running_drivers_; + } + + num_finished_drivers_++; + } + + private: + std::string uuid_; + + std::string taskid_; + + plan::PlanFragment plan_fragment_; + + int destination_; + + std::shared_ptr query_context_; + + std::exception_ptr exception_ = nullptr; + + std::function on_error_; + + std::vector> driver_factories_; + + std::vector> drivers_; + + ConsumerSupplier consumer_supplier_; + + mutable std::mutex mutex_; + + TaskState state_ = TaskState::kRunning; + + uint32_t num_running_drivers_{0}; + + uint32_t num_total_drivers_{0}; + + uint32_t num_ungrouped_drivers_{0}; + + uint32_t num_finished_drivers_{0}; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp new file mode 100644 index 0000000000..ae9feb4767 --- /dev/null +++ b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp @@ -0,0 +1,45 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "AlwaysTrueExpr.h" + +namespace milvus { +namespace exec { + +void +PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) { + int64_t real_batch_size = current_pos_ + batch_size_ >= num_rows_ + ? num_rows_ - current_pos_ + : batch_size_; + + if (real_batch_size == 0) { + result = nullptr; + return; + } + + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res_bool = (bool*)res_vec->GetRawData(); + for (size_t i = 0; i < real_batch_size; ++i) { + res_bool[i] = true; + } + + result = res_vec; + current_pos_ += real_batch_size; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.h b/internal/core/src/exec/expression/AlwaysTrueExpr.h new file mode 100644 index 0000000000..826660916e --- /dev/null +++ b/internal/core/src/exec/expression/AlwaysTrueExpr.h @@ -0,0 +1,56 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyAlwaysTrueExpr : public Expr { + public: + PhyAlwaysTrueExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : Expr(DataType::BOOL, std::move(input), name), + expr_(expr), + batch_size_(batch_size) { + num_rows_ = segment->get_active_count(query_timestamp); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + std::shared_ptr expr_; + int64_t num_rows_; + int64_t current_pos_{0}; + int64_t batch_size_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp new file mode 100644 index 0000000000..d8126492cd --- /dev/null +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -0,0 +1,748 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "BinaryArithOpEvalRangeExpr.h" + +namespace milvus { +namespace exec { + +void +PhyBinaryArithOpEvalRangeExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (expr_->column_.data_type_) { + case DataType::BOOL: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT8: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT16: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT32: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT64: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::FLOAT: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::DOUBLE: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::JSON: { + auto value_type = expr_->value_.val_case(); + switch (value_type) { + case proto::plan::GenericValue::ValCase::kBoolVal: { + result = ExecRangeVisitorImplForJson(); + break; + } + case proto::plan::GenericValue::ValCase::kInt64Val: { + result = ExecRangeVisitorImplForJson(); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + result = ExecRangeVisitorImplForJson(); + break; + } + default: { + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported value type {} in expression", + value_type)); + } + } + break; + } + case DataType::ARRAY: { + auto value_type = expr_->value_.val_case(); + switch (value_type) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + result = ExecRangeVisitorImplForArray(); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + result = ExecRangeVisitorImplForArray(); + break; + } + default: { + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported value type {} in expression", + value_type)); + } + } + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", + expr_->column_.data_type_)); + } +} + +template +VectorPtr +PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + auto op_type = expr_->op_type_; + auto arith_type = expr_->arith_op_type_; + auto value = GetValueFromProto(expr_->value_); + auto right_operand = GetValueFromProto(expr_->right_operand_); + +#define BinaryArithRangeJSONCompare(cmp) \ + do { \ + for (size_t i = 0; i < size; ++i) { \ + auto x = data[i].template at(pointer); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = data[i].template at(pointer); \ + res[i] = !x.error() && (cmp); \ + continue; \ + } \ + res[i] = false; \ + continue; \ + } \ + res[i] = (cmp); \ + } \ + } while (false) + +#define BinaryArithRangeJSONCompareNotEqual(cmp) \ + do { \ + for (size_t i = 0; i < size; ++i) { \ + auto x = data[i].template at(pointer); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = data[i].template at(pointer); \ + res[i] = x.error() || (cmp); \ + continue; \ + } \ + res[i] = true; \ + continue; \ + } \ + res[i] = (cmp); \ + } \ + } while (false) + + auto execute_sub_batch = [op_type, arith_type](const milvus::Json* data, + const int size, + bool* res, + ValueType val, + ValueType right_operand, + const std::string& pointer) { + switch (op_type) { + case proto::plan::OpType::Equal: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeJSONCompare(x.value() + right_operand == + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeJSONCompare(x.value() - right_operand == + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeJSONCompare(x.value() * right_operand == + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeJSONCompare(x.value() / right_operand == + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeJSONCompare( + static_cast( + fmod(x.value(), right_operand)) == val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + int array_length = 0; + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + res[i] = array_length == val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::NotEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeJSONCompareNotEqual( + x.value() + right_operand != val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeJSONCompareNotEqual( + x.value() - right_operand != val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeJSONCompareNotEqual( + x.value() * right_operand != val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeJSONCompareNotEqual( + x.value() / right_operand != val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeJSONCompareNotEqual( + static_cast( + fmod(x.value(), right_operand)) != val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + int array_length = 0; + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + res[i] = array_length != val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + default: + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operator type for binary " + "arithmetic eval expr: {}", + op_type)); + } + }; + int64_t processed_size = ProcessDataChunks(execute_sub_batch, + std::nullptr_t{}, + res, + value, + right_operand, + pointer); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + int index = -1; + if (expr_->column_.nested_path_.size() > 0) { + index = std::stoi(expr_->column_.nested_path_[0]); + } + auto op_type = expr_->op_type_; + auto arith_type = expr_->arith_op_type_; + auto value = GetValueFromProto(expr_->value_); + auto right_operand = + arith_type != proto::plan::ArithOpType::ArrayLength + ? GetValueFromProto(expr_->right_operand_) + : ValueType(); + +#define BinaryArithRangeArrayCompare(cmp) \ + do { \ + for (size_t i = 0; i < size; ++i) { \ + if (index >= data[i].length()) { \ + res[i] = false; \ + continue; \ + } \ + auto value = data[i].get_data(index); \ + res[i] = (cmp); \ + } \ + } while (false) + + auto execute_sub_batch = [op_type, arith_type](const ArrayView* data, + const int size, + bool* res, + ValueType val, + ValueType right_operand, + int index) { + switch (op_type) { + case proto::plan::OpType::Equal: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeArrayCompare(value + right_operand == + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeArrayCompare(value - right_operand == + val); + + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeArrayCompare(value * right_operand == + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeArrayCompare(value / right_operand == + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeArrayCompare( + static_cast( + fmod(value, right_operand)) == val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + res[i] = data[i].length() == val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::NotEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + BinaryArithRangeArrayCompare(value + right_operand != + val); + break; + } + case proto::plan::ArithOpType::Sub: { + BinaryArithRangeArrayCompare(value - right_operand != + val); + break; + } + case proto::plan::ArithOpType::Mul: { + BinaryArithRangeArrayCompare(value * right_operand != + val); + break; + } + case proto::plan::ArithOpType::Div: { + BinaryArithRangeArrayCompare(value / right_operand != + val); + break; + } + case proto::plan::ArithOpType::Mod: { + BinaryArithRangeArrayCompare( + static_cast( + fmod(value, right_operand)) != val); + break; + } + case proto::plan::ArithOpType::ArrayLength: { + for (size_t i = 0; i < size; ++i) { + res[i] = data[i].length() != val; + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + default: + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operator type for binary " + "arithmetic eval expr: {}", + op_type)); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, value, right_operand, index); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImpl() { + if (is_index_mode_) { + return ExecRangeVisitorImplForIndex(); + } else { + return ExecRangeVisitorImplForData(); + } +} + +template +VectorPtr +PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() { + using Index = index::ScalarIndex; + typedef std::conditional_t && + !std::is_same_v, + int64_t, + T> + HighPrecisionType; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto value = GetValueFromProto(expr_->value_); + auto right_operand = + GetValueFromProto(expr_->right_operand_); + auto op_type = expr_->op_type_; + auto arith_type = expr_->arith_op_type_; + auto sub_batch_size = size_per_chunk_; + + auto execute_sub_batch = [op_type, arith_type, sub_batch_size]( + Index* index_ptr, + HighPrecisionType value, + HighPrecisionType right_operand) { + FixedVector res; + switch (op_type) { + case proto::plan::OpType::Equal: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::NotEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpIndexFunc + func; + res = std::move(func( + index_ptr, sub_batch_size, value, right_operand)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + default: + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operator type for binary " + "arithmetic eval expr: {}", + op_type)); + } + return res; + }; + auto res = ProcessIndexChunks(execute_sub_batch, value, right_operand); + AssertInfo(res.size() == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + res.size(), + real_batch_size)); + return std::make_shared(std::move(res)); +} + +template +VectorPtr +PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { + typedef std::conditional_t && + !std::is_same_v, + int64_t, + T> + HighPrecisionType; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto value = GetValueFromProto(expr_->value_); + auto right_operand = + GetValueFromProto(expr_->right_operand_); + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + auto op_type = expr_->op_type_; + auto arith_type = expr_->arith_op_type_; + auto execute_sub_batch = [op_type, arith_type]( + const T* data, + const int size, + bool* res, + HighPrecisionType value, + HighPrecisionType right_operand) { + switch (op_type) { + case proto::plan::OpType::Equal: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + case proto::plan::OpType::NotEqual: { + switch (arith_type) { + case proto::plan::ArithOpType::Add: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Sub: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mul: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Div: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + case proto::plan::ArithOpType::Mod: { + ArithOpElementFunc + func; + func(data, size, value, right_operand, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arith type for binary " + "arithmetic eval expr: {}", + arith_type)); + } + break; + } + default: + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operator type for binary " + "arithmetic eval expr: {}", + op_type)); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, value, right_operand); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h new file mode 100644 index 0000000000..66a4dc08ce --- /dev/null +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -0,0 +1,213 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +template +struct ArithOpElementFunc { + typedef std::conditional_t && + !std::is_same_v, + int64_t, + T> + HighPrecisonType; + void + operator()(const T* src, + size_t size, + HighPrecisonType val, + HighPrecisonType right_operand, + bool* res) { + for (int i = 0; i < size; ++i) { + if constexpr (cmp_op == proto::plan::OpType::Equal) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (src[i] + right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (src[i] - right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (src[i] * right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (src[i] / right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = (fmod(src[i], right_operand)) == val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::NotEqual) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (src[i] + right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (src[i] - right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (src[i] * right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (src[i] / right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = (fmod(src[i], right_operand)) != val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } + } + } +}; + +template +struct ArithOpIndexFunc { + typedef std::conditional_t && + !std::is_same_v, + int64_t, + T> + HighPrecisonType; + using Index = index::ScalarIndex; + FixedVector + operator()(Index* index, + size_t size, + HighPrecisonType val, + HighPrecisonType right_operand) { + FixedVector res_vec(size); + bool* res = res_vec.data(); + for (size_t i = 0; i < size; ++i) { + if constexpr (cmp_op == proto::plan::OpType::Equal) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (index->Reverse_Lookup(i) + right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (index->Reverse_Lookup(i) - right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (index->Reverse_Lookup(i) * right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (index->Reverse_Lookup(i) / right_operand) == val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = + (fmod(index->Reverse_Lookup(i), right_operand)) == val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else if constexpr (cmp_op == proto::plan::OpType::NotEqual) { + if constexpr (arith_op == proto::plan::ArithOpType::Add) { + res[i] = (index->Reverse_Lookup(i) + right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Sub) { + res[i] = (index->Reverse_Lookup(i) - right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mul) { + res[i] = (index->Reverse_Lookup(i) * right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Div) { + res[i] = (index->Reverse_Lookup(i) / right_operand) != val; + } else if constexpr (arith_op == + proto::plan::ArithOpType::Mod) { + res[i] = + (fmod(index->Reverse_Lookup(i), right_operand)) != val; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } + } + return res_vec; + } +}; + +class PhyBinaryArithOpEvalRangeExpr : public SegmentExpr { + public: + PhyBinaryArithOpEvalRangeExpr( + const std::vector>& input, + const std::shared_ptr& + expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : SegmentExpr(std::move(input), + name, + segment, + expr->column_.field_id_, + query_timestamp, + batch_size), + expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + template + VectorPtr + ExecRangeVisitorImpl(); + + template + VectorPtr + ExecRangeVisitorImplForIndex(); + + template + VectorPtr + ExecRangeVisitorImplForData(); + + template + VectorPtr + ExecRangeVisitorImplForJson(); + + template + VectorPtr + ExecRangeVisitorImplForArray(); + + private: + std::shared_ptr expr_; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp new file mode 100644 index 0000000000..3db6f73026 --- /dev/null +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -0,0 +1,392 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "BinaryRangeExpr.h" + +#include "query/Utils.h" + +namespace milvus { +namespace exec { + +void +PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (expr_->column_.data_type_) { + case DataType::BOOL: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT8: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT16: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT32: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT64: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::FLOAT: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::DOUBLE: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::VARCHAR: { + if (segment_->type() == SegmentType::Growing) { + result = ExecRangeVisitorImpl(); + } else { + result = ExecRangeVisitorImpl(); + } + break; + } + case DataType::JSON: { + auto value_type = expr_->lower_val_.val_case(); + switch (value_type) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + result = ExecRangeVisitorImplForJson(); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + result = ExecRangeVisitorImplForJson(); + break; + } + case proto::plan::GenericValue::ValCase::kStringVal: { + result = ExecRangeVisitorImplForJson(); + break; + } + default: { + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported value type {} in expression", + value_type)); + } + } + break; + } + case DataType::ARRAY: { + auto value_type = expr_->lower_val_.val_case(); + switch (value_type) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + result = ExecRangeVisitorImplForArray(); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + result = ExecRangeVisitorImplForArray(); + break; + } + case proto::plan::GenericValue::ValCase::kStringVal: { + result = ExecRangeVisitorImplForArray(); + break; + } + default: { + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported value type {} in expression", + value_type)); + } + } + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", + expr_->column_.data_type_)); + } +} + +template +VectorPtr +PhyBinaryRangeFilterExpr::ExecRangeVisitorImpl() { + if (is_index_mode_) { + return ExecRangeVisitorImplForIndex(); + } else { + return ExecRangeVisitorImplForData(); + } +} + +template +ColumnVectorPtr +PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1, + HighPrecisionType& val2, + bool& lower_inclusive, + bool& upper_inclusive) { + lower_inclusive = expr_->lower_inclusive_; + upper_inclusive = expr_->upper_inclusive_; + val1 = GetValueFromProto(expr_->lower_val_); + val2 = GetValueFromProto(expr_->upper_val_); + auto get_next_overflow_batch = [this]() -> ColumnVectorPtr { + int64_t batch_size = overflow_check_pos_ + batch_size_ >= num_rows_ + ? num_rows_ - overflow_check_pos_ + : batch_size_; + overflow_check_pos_ += batch_size; + if (cached_overflow_res_ != nullptr && + cached_overflow_res_->size() == batch_size) { + return cached_overflow_res_; + } + auto res = std::make_shared(DataType::BOOL, batch_size); + return res; + }; + + if constexpr (std::is_integral_v && !std::is_same_v) { + if (milvus::query::gt_ub(val1)) { + return get_next_overflow_batch(); + } else if (milvus::query::lt_lb(val1)) { + val1 = std::numeric_limits::min(); + lower_inclusive = true; + } + + if (milvus::query::gt_ub(val2)) { + val2 = std::numeric_limits::max(); + upper_inclusive = true; + } else if (milvus::query::lt_lb(val2)) { + return get_next_overflow_batch(); + } + } + return nullptr; +} + +template +VectorPtr +PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + typedef std::conditional_t && + !std::is_same_v, + int64_t, + IndexInnerType> + HighPrecisionType; + + HighPrecisionType val1; + HighPrecisionType val2; + bool lower_inclusive = false; + bool upper_inclusive = false; + if (auto res = + PreCheckOverflow(val1, val2, lower_inclusive, upper_inclusive)) { + return res; + } + + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto execute_sub_batch = + [lower_inclusive, upper_inclusive]( + Index* index_ptr, HighPrecisionType val1, HighPrecisionType val2) { + BinaryRangeIndexFunc func; + return std::move( + func(index_ptr, val1, val2, lower_inclusive, upper_inclusive)); + }; + auto res = ProcessIndexChunks(execute_sub_batch, val1, val2); + AssertInfo(res.size() == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + res.size(), + real_batch_size)); + return std::make_shared(std::move(res)); +} + +template +VectorPtr +PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + typedef std::conditional_t && + !std::is_same_v, + int64_t, + IndexInnerType> + HighPrecisionType; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + HighPrecisionType val1; + HighPrecisionType val2; + bool lower_inclusive = false; + bool upper_inclusive = false; + if (auto res = + PreCheckOverflow(val1, val2, lower_inclusive, upper_inclusive)) { + return res; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto execute_sub_batch = [lower_inclusive, upper_inclusive]( + const T* data, + const int size, + bool* res, + HighPrecisionType val1, + HighPrecisionType val2) { + if (lower_inclusive && upper_inclusive) { + BinaryRangeElementFunc func; + func(val1, val2, data, size, res); + } else if (lower_inclusive && !upper_inclusive) { + BinaryRangeElementFunc func; + func(val1, val2, data, size, res); + } else if (!lower_inclusive && upper_inclusive) { + BinaryRangeElementFunc func; + func(val1, val2, data, size, res); + } else { + BinaryRangeElementFunc func; + func(val1, val2, data, size, res); + } + }; + auto skip_index_func = + [val1, val2, lower_inclusive, upper_inclusive]( + const SkipIndex& skip_index, FieldId field_id, int64_t chunk_id) { + if (lower_inclusive && upper_inclusive) { + return skip_index.CanSkipBinaryRange( + field_id, chunk_id, val1, val2, true, true); + } else if (lower_inclusive && !upper_inclusive) { + return skip_index.CanSkipBinaryRange( + field_id, chunk_id, val1, val2, true, false); + } else if (!lower_inclusive && upper_inclusive) { + return skip_index.CanSkipBinaryRange( + field_id, chunk_id, val1, val2, false, true); + } else { + return skip_index.CanSkipBinaryRange( + field_id, chunk_id, val1, val2, false, false); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, skip_index_func, res, val1, val2); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + bool lower_inclusive = expr_->lower_inclusive_; + bool upper_inclusive = expr_->upper_inclusive_; + ValueType val1 = GetValueFromProto(expr_->lower_val_); + ValueType val2 = GetValueFromProto(expr_->upper_val_); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + + auto execute_sub_batch = [lower_inclusive, upper_inclusive, pointer]( + const milvus::Json* data, + const int size, + bool* res, + ValueType val1, + ValueType val2) { + if (lower_inclusive && upper_inclusive) { + BinaryRangeElementFuncForJson func; + func(val1, val2, pointer, data, size, res); + } else if (lower_inclusive && !upper_inclusive) { + BinaryRangeElementFuncForJson func; + func(val1, val2, pointer, data, size, res); + } else if (!lower_inclusive && upper_inclusive) { + BinaryRangeElementFuncForJson func; + func(val1, val2, pointer, data, size, res); + } else { + BinaryRangeElementFuncForJson func; + func(val1, val2, pointer, data, size, res); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, val1, val2); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + bool lower_inclusive = expr_->lower_inclusive_; + bool upper_inclusive = expr_->upper_inclusive_; + ValueType val1 = GetValueFromProto(expr_->lower_val_); + ValueType val2 = GetValueFromProto(expr_->upper_val_); + int index = -1; + if (expr_->column_.nested_path_.size() > 0) { + index = std::stoi(expr_->column_.nested_path_[0]); + } + + auto execute_sub_batch = [lower_inclusive, upper_inclusive]( + const milvus::ArrayView* data, + const int size, + bool* res, + ValueType val1, + ValueType val2, + int index) { + if (lower_inclusive && upper_inclusive) { + BinaryRangeElementFuncForArray func; + func(val1, val2, index, data, size, res); + } else if (lower_inclusive && !upper_inclusive) { + BinaryRangeElementFuncForArray func; + func(val1, val2, index, data, size, res); + } else if (!lower_inclusive && upper_inclusive) { + BinaryRangeElementFuncForArray func; + func(val1, val2, index, data, size, res); + } else { + BinaryRangeElementFuncForArray func; + func(val1, val2, index, data, size, res); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, val1, val2, index); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h new file mode 100644 index 0000000000..e34b035238 --- /dev/null +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -0,0 +1,228 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +template +struct BinaryRangeElementFunc { + typedef std::conditional_t && + !std::is_same_v, + int64_t, + T> + HighPrecisionType; + void + operator()(T val1, T val2, const T* src, size_t n, bool* res) { + for (size_t i = 0; i < n; ++i) { + if constexpr (lower_inclusive && upper_inclusive) { + res[i] = val1 <= src[i] && src[i] <= val2; + } else if constexpr (lower_inclusive && !upper_inclusive) { + res[i] = val1 <= src[i] && src[i] < val2; + } else if constexpr (!lower_inclusive && upper_inclusive) { + res[i] = val1 < src[i] && src[i] <= val2; + } else { + res[i] = val1 < src[i] && src[i] < val2; + } + } + } +}; + +#define BinaryRangeJSONCompare(cmp) \ + do { \ + auto x = src[i].template at(pointer); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = src[i].template at(pointer); \ + if (!x.error()) { \ + auto value = x.value(); \ + res[i] = (cmp); \ + break; \ + } \ + } \ + res[i] = false; \ + break; \ + } \ + auto value = x.value(); \ + res[i] = (cmp); \ + } while (false) + +template +struct BinaryRangeElementFuncForJson { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + void + operator()(ValueType val1, + ValueType val2, + const std::string& pointer, + const milvus::Json* src, + size_t n, + bool* res) { + for (size_t i = 0; i < n; ++i) { + if constexpr (lower_inclusive && upper_inclusive) { + BinaryRangeJSONCompare(val1 <= value && value <= val2); + } else if constexpr (lower_inclusive && !upper_inclusive) { + BinaryRangeJSONCompare(val1 <= value && value < val2); + } else if constexpr (!lower_inclusive && upper_inclusive) { + BinaryRangeJSONCompare(val1 < value && value <= val2); + } else { + BinaryRangeJSONCompare(val1 < value && value < val2); + } + } + } +}; + +template +struct BinaryRangeElementFuncForArray { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + void + operator()(ValueType val1, + ValueType val2, + int index, + const milvus::ArrayView* src, + size_t n, + bool* res) { + for (size_t i = 0; i < n; ++i) { + if constexpr (lower_inclusive && upper_inclusive) { + if (index >= src[i].length()) { + res[i] = false; + continue; + } + auto value = src[i].get_data(index); + res[i] = val1 <= value && value <= val2; + } else if constexpr (lower_inclusive && !upper_inclusive) { + if (index >= src[i].length()) { + res[i] = false; + continue; + } + auto value = src[i].get_data(index); + res[i] = val1 <= value && value < val2; + } else if constexpr (!lower_inclusive && upper_inclusive) { + if (index >= src[i].length()) { + res[i] = false; + continue; + } + auto value = src[i].get_data(index); + res[i] = val1 < value && value <= val2; + } else { + if (index >= src[i].length()) { + res[i] = false; + continue; + } + auto value = src[i].get_data(index); + res[i] = val1 < value && value < val2; + } + } + } +}; + +template +struct BinaryRangeIndexFunc { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + typedef std::conditional_t && + !std::is_same_v, + int64_t, + IndexInnerType> + HighPrecisionType; + FixedVector + operator()(Index* index, + IndexInnerType val1, + IndexInnerType val2, + bool lower_inclusive, + bool upper_inclusive) { + return index->Range(val1, lower_inclusive, val2, upper_inclusive); + } +}; + +class PhyBinaryRangeFilterExpr : public SegmentExpr { + public: + PhyBinaryRangeFilterExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : SegmentExpr(std::move(input), + name, + segment, + expr->column_.field_id_, + query_timestamp, + batch_size), + expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + // Check overflow and cache result for performace + template < + typename T, + typename IndexInnerType = std:: + conditional_t, std::string, T>, + typename HighPrecisionType = std::conditional_t< + std::is_integral_v && !std::is_same_v, + int64_t, + IndexInnerType>> + ColumnVectorPtr + PreCheckOverflow(HighPrecisionType& val1, + HighPrecisionType& val2, + bool& lower_inclusive, + bool& upper_inclusive); + + template + VectorPtr + ExecRangeVisitorImpl(); + + template + VectorPtr + ExecRangeVisitorImplForIndex(); + + template + VectorPtr + ExecRangeVisitorImplForData(); + + template + VectorPtr + ExecRangeVisitorImplForJson(); + + template + VectorPtr + ExecRangeVisitorImplForArray(); + + private: + std::shared_ptr expr_; + ColumnVectorPtr cached_overflow_res_{nullptr}; + int64_t overflow_check_pos_{0}; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp new file mode 100644 index 0000000000..5e8ad71ca8 --- /dev/null +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -0,0 +1,319 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "CompareExpr.h" +#include "query/Relational.h" + +namespace milvus { +namespace exec { + +bool +PhyCompareFilterExpr::IsStringExpr() { + return expr_->left_data_type_ == DataType::VARCHAR || + expr_->right_data_type_ == DataType::VARCHAR; +} + +int64_t +PhyCompareFilterExpr::GetNextBatchSize() { + auto current_rows = + segment_->type() == SegmentType::Growing + ? current_chunk_id_ * size_per_chunk_ + current_chunk_pos_ + : current_chunk_pos_; + return current_rows + batch_size_ >= num_rows_ ? num_rows_ - current_rows + : batch_size_; +} + +template +ChunkDataAccessor +PhyCompareFilterExpr::GetChunkData(FieldId field_id, + int chunk_id, + int data_barrier) { + if (chunk_id >= data_barrier) { + auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); + if (indexing.HasRawData()) { + return [&indexing](int i) -> const number { + return indexing.Reverse_Lookup(i); + }; + } + } + auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); + return [chunk_data](int i) -> const number { return chunk_data[i]; }; +} + +template <> +ChunkDataAccessor +PhyCompareFilterExpr::GetChunkData(FieldId field_id, + int chunk_id, + int data_barrier) { + if (chunk_id >= data_barrier) { + auto& indexing = + segment_->chunk_scalar_index(field_id, chunk_id); + if (indexing.HasRawData()) { + return [&indexing](int i) -> const std::string { + return indexing.Reverse_Lookup(i); + }; + } + } + if (segment_->type() == SegmentType::Growing) { + auto chunk_data = + segment_->chunk_data(field_id, chunk_id).data(); + return [chunk_data](int i) -> const number { return chunk_data[i]; }; + } else { + auto chunk_data = + segment_->chunk_data(field_id, chunk_id).data(); + return [chunk_data](int i) -> const number { + return std::string(chunk_data[i]); + }; + } +} + +ChunkDataAccessor +PhyCompareFilterExpr::GetChunkData(DataType data_type, + FieldId field_id, + int chunk_id, + int data_barrier) { + switch (data_type) { + case DataType::BOOL: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT8: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT16: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT32: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::INT64: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::FLOAT: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::DOUBLE: + return GetChunkData(field_id, chunk_id, data_barrier); + case DataType::VARCHAR: { + return GetChunkData(field_id, chunk_id, data_barrier); + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", data_type)); + } +} + +template +VectorPtr +PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); + auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); + + int64_t processed_rows = 0; + for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; + ++chunk_id) { + auto chunk_size = chunk_id == num_chunk_ - 1 + ? num_rows_ - chunk_id * size_per_chunk_ + : size_per_chunk_; + auto left = GetChunkData(expr_->left_data_type_, + expr_->left_field_id_, + chunk_id, + left_data_barrier); + auto right = GetChunkData(expr_->right_data_type_, + expr_->right_field_id_, + chunk_id, + right_data_barrier); + + for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; + i < chunk_size; + ++i) { + res[processed_rows++] = boost::apply_visitor( + milvus::query::Relational{}, left(i), right(i)); + + if (processed_rows >= batch_size_) { + current_chunk_id_ = chunk_id; + current_chunk_pos_ = i + 1; + return res_vec; + } + } + } + return res_vec; +} + +void +PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + // For segment both fields has no index, can use SIMD to speed up. + // Avoiding too much call stack that blocks SIMD. + if (!is_left_indexed_ && !is_right_indexed_ && !IsStringExpr()) { + result = ExecCompareExprDispatcherForBothDataSegment(); + return; + } + result = ExecCompareExprDispatcherForHybridSegment(); +} + +VectorPtr +PhyCompareFilterExpr::ExecCompareExprDispatcherForHybridSegment() { + switch (expr_->op_type_) { + case OpType::Equal: { + return ExecCompareExprDispatcher(std::equal_to<>{}); + } + case OpType::NotEqual: { + return ExecCompareExprDispatcher(std::not_equal_to<>{}); + } + case OpType::GreaterEqual: { + return ExecCompareExprDispatcher(std::greater_equal<>{}); + } + case OpType::GreaterThan: { + return ExecCompareExprDispatcher(std::greater<>{}); + } + case OpType::LessEqual: { + return ExecCompareExprDispatcher(std::less_equal<>{}); + } + case OpType::LessThan: { + return ExecCompareExprDispatcher(std::less<>{}); + } + case OpType::PrefixMatch: { + return ExecCompareExprDispatcher( + milvus::query::MatchOp{}); + } + // case OpType::PostfixMatch: { + // } + default: { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported optype: {}", expr_->op_type_)); + } + } +} + +VectorPtr +PhyCompareFilterExpr::ExecCompareExprDispatcherForBothDataSegment() { + switch (expr_->left_data_type_) { + case DataType::BOOL: + return ExecCompareLeftType(); + case DataType::INT8: + return ExecCompareLeftType(); + case DataType::INT16: + return ExecCompareLeftType(); + case DataType::INT32: + return ExecCompareLeftType(); + case DataType::INT64: + return ExecCompareLeftType(); + case DataType::FLOAT: + return ExecCompareLeftType(); + case DataType::DOUBLE: + return ExecCompareLeftType(); + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported left datatype:{} of compare expr", + expr_->left_data_type_)); + } +} + +template +VectorPtr +PhyCompareFilterExpr::ExecCompareLeftType() { + switch (expr_->right_data_type_) { + case DataType::BOOL: + return ExecCompareRightType(); + case DataType::INT8: + return ExecCompareRightType(); + case DataType::INT16: + return ExecCompareRightType(); + case DataType::INT32: + return ExecCompareRightType(); + case DataType::INT64: + return ExecCompareRightType(); + case DataType::FLOAT: + return ExecCompareRightType(); + case DataType::DOUBLE: + return ExecCompareRightType(); + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported right datatype:{} of compare expr", + expr_->right_data_type_)); + } +} + +template +VectorPtr +PhyCompareFilterExpr::ExecCompareRightType() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto expr_type = expr_->op_type_; + auto execute_sub_batch = [expr_type](const T* left, + const U* right, + const int size, + bool* res) { + switch (expr_type) { + case proto::plan::GreaterThan: { + CompareElementFunc func; + func(left, right, size, res); + break; + } + case proto::plan::GreaterEqual: { + CompareElementFunc func; + func(left, right, size, res); + break; + } + case proto::plan::LessThan: { + CompareElementFunc func; + func(left, right, size, res); + break; + } + case proto::plan::LessEqual: { + CompareElementFunc func; + func(left, right, size, res); + break; + } + case proto::plan::Equal: { + CompareElementFunc func; + func(left, right, size, res); + break; + } + case proto::plan::NotEqual: { + CompareElementFunc func; + func(left, right, size, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported operator type for compare column expr: {}", + expr_type)); + } + }; + int64_t processed_size = + ProcessBothDataChunks(execute_sub_batch, res); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h new file mode 100644 index 0000000000..07c14a7bce --- /dev/null +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -0,0 +1,186 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +using number = boost::variant; +using ChunkDataAccessor = std::function; + +template +struct CompareElementFunc { + void + operator()(const T* left, const U* right, size_t size, bool* res) { + for (int i = 0; i < size; ++i) { + if constexpr (op == proto::plan::OpType::Equal) { + res[i] = left[i] == right[i]; + } else if constexpr (op == proto::plan::OpType::NotEqual) { + res[i] = left[i] != right[i]; + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + res[i] = left[i] > right[i]; + } else if constexpr (op == proto::plan::OpType::LessThan) { + res[i] = left[i] < right[i]; + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + res[i] = left[i] >= right[i]; + } else if constexpr (op == proto::plan::OpType::LessEqual) { + res[i] = left[i] <= right[i]; + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported op_type:{} for CompareElementFunc", + op)); + } + } + } +}; + +class PhyCompareFilterExpr : public Expr { + public: + PhyCompareFilterExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : Expr(DataType::BOOL, std::move(input), name), + left_field_(expr->left_field_id_), + right_field_(expr->right_field_id_), + segment_(segment), + query_timestamp_(query_timestamp), + batch_size_(batch_size), + expr_(expr) { + is_left_indexed_ = segment_->HasIndex(left_field_); + is_right_indexed_ = segment_->HasIndex(right_field_); + num_rows_ = segment_->get_active_count(query_timestamp_); + num_chunk_ = is_left_indexed_ + ? segment_->num_chunk_index(expr_->left_field_id_) + : segment_->num_chunk_data(expr_->left_field_id_); + size_per_chunk_ = segment_->size_per_chunk(); + AssertInfo( + batch_size_ > 0, + fmt::format("expr batch size should greater than zero, but now: {}", + batch_size_)); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + int64_t + GetNextBatchSize(); + + bool + IsStringExpr(); + + template + ChunkDataAccessor + GetChunkData(FieldId field_id, int chunk_id, int data_barrier); + + template + int64_t + ProcessBothDataChunks(FUNC func, bool* res, ValTypes... values) { + int64_t processed_size = 0; + + for (size_t i = current_chunk_id_; i < num_chunk_; i++) { + auto left_chunk = segment_->chunk_data(left_field_, i); + auto right_chunk = segment_->chunk_data(right_field_, i); + auto data_pos = (i == current_chunk_id_) ? current_chunk_pos_ : 0; + auto size = (i == (num_chunk_ - 1)) + ? (segment_->type() == SegmentType::Growing + ? num_rows_ % size_per_chunk_ - data_pos + : num_rows_ - data_pos) + : size_per_chunk_ - data_pos; + + if (processed_size + size >= batch_size_) { + size = batch_size_ - processed_size; + } + + const T* left_data = left_chunk.data() + data_pos; + const U* right_data = right_chunk.data() + data_pos; + func(left_data, right_data, size, res + processed_size, values...); + processed_size += size; + + if (processed_size >= batch_size_) { + current_chunk_id_ = i; + current_chunk_pos_ = data_pos + size; + break; + } + } + + return processed_size; + } + + ChunkDataAccessor + GetChunkData(DataType data_type, + FieldId field_id, + int chunk_id, + int data_barrier); + + template + VectorPtr + ExecCompareExprDispatcher(OpType op); + + VectorPtr + ExecCompareExprDispatcherForHybridSegment(); + + VectorPtr + ExecCompareExprDispatcherForBothDataSegment(); + + template + VectorPtr + ExecCompareLeftType(); + + template + VectorPtr + ExecCompareRightType(); + + private: + const FieldId left_field_; + const FieldId right_field_; + bool is_left_indexed_; + bool is_right_indexed_; + int64_t num_rows_{0}; + int64_t num_chunk_{0}; + int64_t current_chunk_id_{0}; + int64_t current_chunk_pos_{0}; + int64_t size_per_chunk_{0}; + + const segcore::SegmentInternalInterface* segment_; + Timestamp query_timestamp_; + int64_t batch_size_; + std::shared_ptr expr_; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ConjunctExpr.cpp b/internal/core/src/exec/expression/ConjunctExpr.cpp new file mode 100644 index 0000000000..1c1498b11e --- /dev/null +++ b/internal/core/src/exec/expression/ConjunctExpr.cpp @@ -0,0 +1,131 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ConjunctExpr.h" +#include "simd/hook.h" + +namespace milvus { +namespace exec { + +DataType +PhyConjunctFilterExpr::ResolveType(const std::vector& inputs) { + AssertInfo( + inputs.size() > 0, + fmt::format( + "Conjunct expressions expect at least one argument, received: {}", + inputs.size())); + + for (const auto& type : inputs) { + AssertInfo( + type == DataType::BOOL, + fmt::format("Conjunct expressions expect BOOLEAN, received: {}", + type)); + } + return DataType::BOOL; +} + +static bool +AllTrue(ColumnVectorPtr& vec) { + bool* data = static_cast(vec->GetRawData()); +#if defined(USE_DYNAMIC_SIMD) + return milvus::simd::all_true(data, vec->size()); +#else + for (int i = 0; i < vec->size(); ++i) { + if (!data[i]) { + return false; + } + } + return true; +#endif +} + +static void +AllSet(ColumnVectorPtr& vec) { + bool* data = static_cast(vec->GetRawData()); + for (int i = 0; i < vec->size(); ++i) { + data[i] = true; + } +} + +static void +AllReset(ColumnVectorPtr& vec) { + bool* data = static_cast(vec->GetRawData()); + for (int i = 0; i < vec->size(); ++i) { + data[i] = false; + } +} + +static bool +AllFalse(ColumnVectorPtr& vec) { + bool* data = static_cast(vec->GetRawData()); +#if defined(USE_DYNAMIC_SIMD) + return milvus::simd::all_false(data, vec->size()); +#else + for (int i = 0; i < vec->size(); ++i) { + if (data[i]) { + return false; + } + } + return true; +#endif +} + +int64_t +PhyConjunctFilterExpr::UpdateResult(ColumnVectorPtr& input_result, + EvalCtx& ctx, + ColumnVectorPtr& result) { + if (is_and_) { + ConjunctElementFunc func; + return func(input_result, result); + } else { + ConjunctElementFunc func; + return func(input_result, result); + } +} + +bool +PhyConjunctFilterExpr::CanSkipNextExprs(ColumnVectorPtr& vec) { + if ((is_and_ && AllFalse(vec)) || (!is_and_ && AllTrue(vec))) { + return true; + } + return false; +} + +void +PhyConjunctFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + for (int i = 0; i < inputs_.size(); ++i) { + VectorPtr input_result; + inputs_[i]->Eval(context, input_result); + if (i == 0) { + result = input_result; + auto all_flat_result = GetColumnVector(result); + if (CanSkipNextExprs(all_flat_result)) { + return; + } + continue; + } + auto input_flat_result = GetColumnVector(input_result); + auto all_flat_result = GetColumnVector(result); + auto active_rows = + UpdateResult(input_flat_result, context, all_flat_result); + if (active_rows == 0) { + return; + } + } +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ConjunctExpr.h b/internal/core/src/exec/expression/ConjunctExpr.h new file mode 100644 index 0000000000..6027f56b60 --- /dev/null +++ b/internal/core/src/exec/expression/ConjunctExpr.h @@ -0,0 +1,89 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +template +struct ConjunctElementFunc { + int64_t + operator()(ColumnVectorPtr& input_result, ColumnVectorPtr& result) { + bool* input_data = static_cast(input_result->GetRawData()); + bool* res_data = static_cast(result->GetRawData()); + int64_t activate_rows = 0; + for (int i = 0; i < result->size(); ++i) { + if constexpr (is_and) { + res_data[i] &= input_data[i]; + if (res_data[i]) { + activate_rows++; + } + } else { + res_data[i] |= input_data[i]; + if (!res_data[i]) { + activate_rows++; + } + } + } + return activate_rows; + } +}; + +class PhyConjunctFilterExpr : public Expr { + public: + PhyConjunctFilterExpr(std::vector&& inputs, bool is_and) + : Expr(DataType::BOOL, std::move(inputs), is_and ? "and" : "or"), + is_and_(is_and) { + std::vector input_types; + input_types.reserve(inputs_.size()); + + std::transform(inputs_.begin(), + inputs_.end(), + std::back_inserter(input_types), + [](const ExprPtr& expr) { return expr->type(); }); + + ResolveType(input_types); + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + int64_t + UpdateResult(ColumnVectorPtr& input_result, + EvalCtx& ctx, + ColumnVectorPtr& result); + + static DataType + ResolveType(const std::vector& inputs); + + bool + CanSkipNextExprs(ColumnVectorPtr& vec); + // true if conjunction (and), false if disjunction (or). + bool is_and_; + std::vector input_order_; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/EvalCtx.h b/internal/core/src/exec/expression/EvalCtx.h new file mode 100644 index 0000000000..69992945d1 --- /dev/null +++ b/internal/core/src/exec/expression/EvalCtx.h @@ -0,0 +1,62 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "common/Vector.h" +#include "exec/QueryContext.h" + +namespace milvus { +namespace exec { + +class ExprSet; +class EvalCtx { + public: + EvalCtx(ExecContext* exec_ctx, ExprSet* expr_set, RowVector* row) + : exec_ctx_(exec_ctx), expr_set_(expr_set_), row_(row) { + assert(exec_ctx_ != nullptr); + assert(expr_set_ != nullptr); + // assert(row_ != nullptr); + } + + explicit EvalCtx(ExecContext* exec_ctx) + : exec_ctx_(exec_ctx), expr_set_(nullptr), row_(nullptr) { + } + + ExecContext* + get_exec_context() { + return exec_ctx_; + } + + std::shared_ptr + get_query_config() { + return exec_ctx_->get_query_config(); + } + + private: + ExecContext* exec_ctx_; + ExprSet* expr_set_; + RowVector* row_; + bool input_no_nulls_; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/expression/ExistsExpr.cpp b/internal/core/src/exec/expression/ExistsExpr.cpp new file mode 100644 index 0000000000..2226fde52f --- /dev/null +++ b/internal/core/src/exec/expression/ExistsExpr.cpp @@ -0,0 +1,72 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ExistsExpr.h" +#include "common/Json.h" + +namespace milvus { +namespace exec { + +void +PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (expr_->column_.data_type_) { + case DataType::JSON: { + if (is_index_mode_) { + PanicInfo(ExprInvalid, + "exists expr for json index mode not supportted"); + } + result = EvalJsonExistsForDataSegment(); + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", + expr_->column_.data_type_)); + } +} + +VectorPtr +PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + auto execute_sub_batch = [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer) { + for (int i = 0; i < size; ++i) { + res[i] = data[i].exist(pointer); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/ExistsExpr.h b/internal/core/src/exec/expression/ExistsExpr.h new file mode 100644 index 0000000000..5f5f110126 --- /dev/null +++ b/internal/core/src/exec/expression/ExistsExpr.h @@ -0,0 +1,66 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +template +struct ExistsElementFunc { + void + operator()(const T* src, size_t size, T val, bool* res) { + } +}; + +class PhyExistsFilterExpr : public SegmentExpr { + public: + PhyExistsFilterExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : SegmentExpr(std::move(input), + name, + segment, + expr->column_.field_id_, + query_timestamp, + batch_size), + expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + VectorPtr + EvalJsonExistsForDataSegment(); + + private: + std::shared_ptr expr_; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp new file mode 100644 index 0000000000..81d026022b --- /dev/null +++ b/internal/core/src/exec/expression/Expr.cpp @@ -0,0 +1,255 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Expr.h" + +#include "exec/expression/AlwaysTrueExpr.h" +#include "exec/expression/BinaryArithOpEvalRangeExpr.h" +#include "exec/expression/BinaryRangeExpr.h" +#include "exec/expression/CompareExpr.h" +#include "exec/expression/ConjunctExpr.h" +#include "exec/expression/ExistsExpr.h" +#include "exec/expression/JsonContainsExpr.h" +#include "exec/expression/LogicalBinaryExpr.h" +#include "exec/expression/LogicalUnaryExpr.h" +#include "exec/expression/TermExpr.h" +#include "exec/expression/UnaryExpr.h" +namespace milvus { +namespace exec { + +void +ExprSet::Eval(int32_t begin, + int32_t end, + bool initialize, + EvalCtx& context, + std::vector& results) { + results.resize(exprs_.size()); + + for (size_t i = begin; i < end; ++i) { + exprs_[i]->Eval(context, results[i]); + } +} + +std::vector +CompileExpressions(const std::vector& sources, + ExecContext* context, + const std::unordered_set& flatten_candidate, + bool enable_constant_folding) { + std::vector> exprs; + exprs.reserve(sources.size()); + + for (auto& source : sources) { + exprs.emplace_back(CompileExpression(source, + context->get_query_context(), + flatten_candidate, + enable_constant_folding)); + } + return exprs; +} + +static std::optional +ShouldFlatten(const expr::TypedExprPtr& expr, + const std::unordered_set& flat_candidates = {}) { + if (auto call = + std::dynamic_pointer_cast(expr)) { + if (call->op_type_ == expr::LogicalBinaryExpr::OpType::And || + call->op_type_ == expr::LogicalBinaryExpr::OpType::Or) { + return call->name(); + } + } + return std::nullopt; +} + +static bool +IsCall(const expr::TypedExprPtr& expr, const std::string& name) { + if (auto call = + std::dynamic_pointer_cast(expr)) { + return call->name() == name; + } + return false; +} + +static bool +AllInputTypeEqual(const expr::TypedExprPtr& expr) { + const auto& inputs = expr->inputs(); + for (int i = 1; i < inputs.size(); i++) { + if (inputs[0]->type() != inputs[i]->type()) { + return false; + } + } + return true; +} + +static void +FlattenInput(const expr::TypedExprPtr& input, + const std::string& flatten_call, + std::vector& flat) { + if (IsCall(input, flatten_call) && AllInputTypeEqual(input)) { + for (auto& child : input->inputs()) { + FlattenInput(child, flatten_call, flat); + } + } else { + flat.emplace_back(input); + } +} + +std::vector +CompileInputs(const expr::TypedExprPtr& expr, + QueryContext* context, + const std::unordered_set& flatten_cadidates) { + std::vector compiled_inputs; + auto flatten = ShouldFlatten(expr); + for (auto& input : expr->inputs()) { + if (dynamic_cast(input.get())) { + AssertInfo( + dynamic_cast(expr.get()), + "An InputReference can only occur under a FieldReference"); + } else { + if (flatten.has_value()) { + std::vector flat_exprs; + FlattenInput(input, flatten.value(), flat_exprs); + for (auto& input : flat_exprs) { + compiled_inputs.push_back(CompileExpression( + input, context, flatten_cadidates, false)); + } + } else { + compiled_inputs.push_back(CompileExpression( + input, context, flatten_cadidates, false)); + } + } + } + return compiled_inputs; +} + +ExprPtr +CompileExpression(const expr::TypedExprPtr& expr, + QueryContext* context, + const std::unordered_set& flatten_candidates, + bool enable_constant_folding) { + ExprPtr result; + + auto result_type = expr->type(); + auto compiled_inputs = CompileInputs(expr, context, flatten_candidates); + + auto GetTypes = [](const std::vector& exprs) { + std::vector types; + for (auto& expr : exprs) { + types.push_back(expr->type()); + } + return types; + }; + auto input_types = GetTypes(compiled_inputs); + + if (auto call = dynamic_cast(expr.get())) { + // TODO: support function register and search mode + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::UnaryRangeFilterExpr>(expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyUnaryRangeFilterExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::LogicalUnaryExpr>(expr)) { + result = std::make_shared( + compiled_inputs, casted_expr, "PhyLogicalUnaryExpr"); + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::TermFilterExpr>(expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyTermFilterExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::LogicalBinaryExpr>(expr)) { + if (casted_expr->op_type_ == + milvus::expr::LogicalBinaryExpr::OpType::And || + casted_expr->op_type_ == + milvus::expr::LogicalBinaryExpr::OpType::Or) { + result = std::make_shared( + std::move(compiled_inputs), + casted_expr->op_type_ == + milvus::expr::LogicalBinaryExpr::OpType::And); + } else { + result = std::make_shared( + compiled_inputs, casted_expr, "PhyLogicalBinaryExpr"); + } + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::BinaryRangeFilterExpr>(expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyBinaryRangeFilterExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::AlwaysTrueExpr>(expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyAlwaysTrueExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::BinaryArithOpEvalRangeExpr>(expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyBinaryArithOpEvalRangeExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = + std::dynamic_pointer_cast( + expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyCompareFilterExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = + std::dynamic_pointer_cast( + expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyExistsFilterExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } else if (auto casted_expr = std::dynamic_pointer_cast< + const milvus::expr::JsonContainsExpr>(expr)) { + result = std::make_shared( + compiled_inputs, + casted_expr, + "PhyJsonContainsFilterExpr", + context->get_segment(), + context->get_query_timestamp(), + context->query_config()->get_expr_batch_size()); + } + return result; +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h new file mode 100644 index 0000000000..60555f55e2 --- /dev/null +++ b/internal/core/src/exec/expression/Expr.h @@ -0,0 +1,324 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "common/Types.h" +#include "exec/expression/EvalCtx.h" +#include "exec/expression/VectorFunction.h" +#include "exec/expression/Utils.h" +#include "exec/QueryContext.h" +#include "expr/ITypeExpr.h" +#include "query/PlanProto.h" + +namespace milvus { +namespace exec { + +class Expr { + public: + Expr(DataType type, + const std::vector>&& inputs, + const std::string& name) + : type_(type), + inputs_(std::move(inputs)), + name_(name), + vector_func_(nullptr) { + } + + Expr(DataType type, + const std::vector>&& inputs, + std::shared_ptr vec_func, + const std::string& name) + : type_(type), + inputs_(std::move(inputs)), + name_(name), + vector_func_(vec_func) { + } + virtual ~Expr() = default; + + const DataType& + type() const { + return type_; + } + + std::string + get_name() { + return name_; + } + + virtual void + Eval(EvalCtx& context, VectorPtr& result) { + } + + protected: + DataType type_; + const std::vector> inputs_; + std::string name_; + std::shared_ptr vector_func_; +}; + +using ExprPtr = std::shared_ptr; + +using SkipFunc = bool (*)(const milvus::SkipIndex&, FieldId, int); + +class SegmentExpr : public Expr { + public: + SegmentExpr(const std::vector&& input, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + const FieldId& field_id, + Timestamp query_timestamp, + int64_t batch_size) + : Expr(DataType::BOOL, std::move(input), name), + segment_(segment), + field_id_(field_id), + query_timestamp_(query_timestamp), + batch_size_(batch_size) { + num_rows_ = segment_->get_active_count(query_timestamp_); + size_per_chunk_ = segment_->size_per_chunk(); + AssertInfo( + batch_size_ > 0, + fmt::format("expr batch size should greater than zero, but now: {}", + batch_size_)); + InitSegmentExpr(); + } + + void + InitSegmentExpr() { + auto& schema = segment_->get_schema(); + auto& field_meta = schema[field_id_]; + + if (schema.get_primary_field_id().has_value() && + schema.get_primary_field_id().value() == field_id_ && + IsPrimaryKeyDataType(field_meta.get_data_type())) { + is_pk_field_ = true; + pk_type_ = field_meta.get_data_type(); + } + + is_index_mode_ = segment_->HasIndex(field_id_); + if (is_index_mode_) { + num_index_chunk_ = segment_->num_chunk_index(field_id_); + } else { + num_data_chunk_ = segment_->num_chunk_data(field_id_); + } + } + + int64_t + GetNextBatchSize() { + auto current_chunk = + is_index_mode_ ? current_index_chunk_ : current_data_chunk_; + auto current_chunk_pos = + is_index_mode_ ? current_index_chunk_pos_ : current_data_chunk_pos_; + auto current_rows = current_chunk * size_per_chunk_ + current_chunk_pos; + return current_rows + batch_size_ >= num_rows_ + ? num_rows_ - current_rows + : batch_size_; + } + + template + int64_t + ProcessDataChunks( + FUNC func, + std::function skip_func, + bool* res, + ValTypes... values) { + int64_t processed_size = 0; + + for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) { + auto data_pos = + (i == current_data_chunk_) ? current_data_chunk_pos_ : 0; + auto size = (i == (num_data_chunk_ - 1)) + ? (segment_->type() == SegmentType::Growing + ? num_rows_ % size_per_chunk_ - data_pos + : num_rows_ - data_pos) + : size_per_chunk_ - data_pos; + + size = std::min(size, batch_size_ - processed_size); + + auto& skip_index = segment_->GetSkipIndex(); + if (!skip_func || !skip_func(skip_index, field_id_, i)) { + auto chunk = segment_->chunk_data(field_id_, i); + const T* data = chunk.data() + data_pos; + func(data, size, res + processed_size, values...); + } + + processed_size += size; + if (processed_size >= batch_size_) { + current_data_chunk_ = i; + current_data_chunk_pos_ = data_pos + size; + break; + } + } + + return processed_size; + } + + int + ProcessIndexOneChunk(FixedVector& result, + size_t chunk_id, + const FixedVector& chunk_res, + int processed_rows) { + auto data_pos = + chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0; + auto size = std::min( + std::min(size_per_chunk_ - data_pos, batch_size_ - processed_rows), + int64_t(chunk_res.size())); + + result.insert(result.end(), + chunk_res.begin() + data_pos, + chunk_res.begin() + data_pos + size); + return size; + } + + template + FixedVector + ProcessIndexChunks(FUNC func, ValTypes... values) { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + FixedVector result; + int processed_rows = 0; + + for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { + // This cache result help getting result for every batch loop. + // It avoids indexing execute for evevy batch because indexing + // executing costs quite much time. + if (cached_index_chunk_id_ != i) { + const Index& index = + segment_->chunk_scalar_index(field_id_, i); + auto* index_ptr = const_cast(&index); + cached_index_chunk_res_ = std::move(func(index_ptr, values...)); + cached_index_chunk_id_ = i; + } + + auto size = ProcessIndexOneChunk( + result, i, cached_index_chunk_res_, processed_rows); + + if (processed_rows + size >= batch_size_) { + current_index_chunk_ = i; + current_index_chunk_pos_ = i == current_index_chunk_ + ? current_index_chunk_pos_ + size + : size; + break; + } + processed_rows += size; + } + + return result; + } + + protected: + const segcore::SegmentInternalInterface* segment_; + const FieldId field_id_; + bool is_pk_field_{false}; + DataType pk_type_; + Timestamp query_timestamp_; + int64_t batch_size_; + + // State indicate position that expr computing at + // because expr maybe called for every batch. + bool is_index_mode_{false}; + bool is_data_mode_{false}; + + int64_t num_rows_{0}; + int64_t num_data_chunk_{0}; + int64_t num_index_chunk_{0}; + int64_t current_data_chunk_{0}; + int64_t current_data_chunk_pos_{0}; + int64_t current_index_chunk_{0}; + int64_t current_index_chunk_pos_{0}; + int64_t size_per_chunk_{0}; + + // Cache for index scan to avoid search index every batch + int64_t cached_index_chunk_id_{-1}; + FixedVector cached_index_chunk_res_{}; +}; + +std::vector +CompileExpressions(const std::vector& logical_exprs, + ExecContext* context, + const std::unordered_set& flatten_cadidates = + std::unordered_set(), + bool enable_constant_folding = false); + +std::vector +CompileInputs(const expr::TypedExprPtr& expr, + QueryContext* config, + const std::unordered_set& flatten_cadidates); + +ExprPtr +CompileExpression(const expr::TypedExprPtr& expr, + QueryContext* context, + const std::unordered_set& flatten_cadidates, + bool enable_constant_folding); + +class ExprSet { + public: + explicit ExprSet(const std::vector& logical_exprs, + ExecContext* exec_ctx) { + exprs_ = CompileExpressions(logical_exprs, exec_ctx); + } + + virtual ~ExprSet() = default; + + void + Eval(EvalCtx& ctx, std::vector& results) { + Eval(0, exprs_.size(), true, ctx, results); + } + + virtual void + Eval(int32_t begin, + int32_t end, + bool initialize, + EvalCtx& ctx, + std::vector& result); + + void + Clear() { + exprs_.clear(); + } + + ExecContext* + get_exec_context() const { + return exec_ctx_; + } + + size_t + size() const { + return exprs_.size(); + } + + const std::vector>& + exprs() const { + return exprs_; + } + + const std::shared_ptr& + expr(int32_t index) const { + return exprs_[index]; + } + + private: + std::vector> exprs_; + ExecContext* exec_ctx_; +}; + +} //namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp new file mode 100644 index 0000000000..a14ad3da9d --- /dev/null +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -0,0 +1,740 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "JsonContainsExpr.h" + +namespace milvus { +namespace exec { + +void +PhyJsonContainsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (expr_->column_.data_type_) { + case DataType::ARRAY: + case DataType::JSON: { + if (is_index_mode_) { + PanicInfo( + ExprInvalid, + "exists expr for json or array index mode not supportted"); + } + result = EvalJsonContainsForDataSegment(); + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", + expr_->column_.data_type_)); + } +} + +VectorPtr +PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() { + auto data_type = expr_->column_.data_type_; + switch (expr_->op_) { + case proto::plan::JSONContainsExpr_JSONOp_Contains: + case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: { + if (datatype_is_array(data_type)) { + auto val_type = expr_->vals_[0].val_case(); + switch (val_type) { + case proto::plan::GenericValue::kBoolVal: { + return ExecArrayContains(); + } + case proto::plan::GenericValue::kInt64Val: { + return ExecArrayContains(); + } + case proto::plan::GenericValue::kFloatVal: { + return ExecArrayContains(); + } + case proto::plan::GenericValue::kStringVal: { + return ExecArrayContains(); + } + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported data type {}", val_type)); + } + } else { + if (expr_->same_type_) { + auto val_type = expr_->vals_[0].val_case(); + switch (val_type) { + case proto::plan::GenericValue::kBoolVal: { + return ExecJsonContains(); + } + case proto::plan::GenericValue::kInt64Val: { + return ExecJsonContains(); + } + case proto::plan::GenericValue::kFloatVal: { + return ExecJsonContains(); + } + case proto::plan::GenericValue::kStringVal: { + return ExecJsonContains(); + } + case proto::plan::GenericValue::kArrayVal: { + return ExecJsonContainsArray(); + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type:{}", + val_type)); + } + } else { + return ExecJsonContainsWithDiffType(); + } + } + break; + } + case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { + if (datatype_is_array(data_type)) { + auto val_type = expr_->vals_[0].val_case(); + switch (val_type) { + case proto::plan::GenericValue::kBoolVal: { + return ExecArrayContainsAll(); + } + case proto::plan::GenericValue::kInt64Val: { + return ExecArrayContainsAll(); + } + case proto::plan::GenericValue::kFloatVal: { + return ExecArrayContainsAll(); + } + case proto::plan::GenericValue::kStringVal: { + return ExecArrayContainsAll(); + } + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported data type {}", val_type)); + } + } else { + if (expr_->same_type_) { + auto val_type = expr_->vals_[0].val_case(); + switch (val_type) { + case proto::plan::GenericValue::kBoolVal: { + return ExecJsonContainsAll(); + } + case proto::plan::GenericValue::kInt64Val: { + return ExecJsonContainsAll(); + } + case proto::plan::GenericValue::kFloatVal: { + return ExecJsonContainsAll(); + } + case proto::plan::GenericValue::kStringVal: { + return ExecJsonContainsAll(); + } + case proto::plan::GenericValue::kArrayVal: { + return ExecJsonContainsAllArray(); + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type:{}", + val_type)); + } + } else { + return ExecJsonContainsAllWithDiffType(); + } + } + break; + } + default: + PanicInfo(ExprInvalid, + fmt::format("unsupported json contains type {}", + proto::plan::JSONContainsExpr_JSONOp_Name( + expr_->op_))); + } +} + +template +VectorPtr +PhyJsonContainsFilterExpr::ExecArrayContains() { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + AssertInfo(expr_->column_.nested_path_.size() == 0, + "[ExecArrayContains]nested path must be null"); + + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + std::unordered_set elements; + for (auto const& element : expr_->vals_) { + elements.insert(GetValueFromProto(element)); + } + auto execute_sub_batch = [](const milvus::ArrayView* data, + const int size, + bool* res, + const std::unordered_set& elements) { + auto executor = [&](size_t i) { + const auto& array = data[i]; + for (int j = 0; j < array.length(); ++j) { + if (elements.count(array.template get_data(j)) > 0) { + return true; + } + } + return false; + }; + for (int i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyJsonContainsFilterExpr::ExecJsonContains() { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + std::unordered_set elements; + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + for (auto const& element : expr_->vals_) { + elements.insert(GetValueFromProto(element)); + } + auto execute_sub_batch = [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer, + const std::unordered_set& elements) { + auto executor = [&](size_t i) { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + for (auto&& it : array) { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (elements.count(val.value()) > 0) { + return true; + } + } + return false; + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +VectorPtr +PhyJsonContainsFilterExpr::ExecJsonContainsArray() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + std::vector elements; + for (auto const& element : expr_->vals_) { + elements.emplace_back(GetValueFromProto(element)); + } + auto execute_sub_batch = + [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer, + const std::vector& elements) { + auto executor = [&](size_t i) -> bool { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + for (auto&& it : array) { + auto val = it.get_array(); + if (val.error()) { + continue; + } + std::vector< + simdjson::simdjson_result> + json_array; + json_array.reserve(val.count_elements()); + for (auto&& e : val) { + json_array.emplace_back(e); + } + for (auto const& element : elements) { + if (CompareTwoJsonArray(json_array, element)) { + return true; + } + } + } + return false; + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyJsonContainsFilterExpr::ExecArrayContainsAll() { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + AssertInfo(expr_->column_.nested_path_.size() == 0, + "[ExecArrayContainsAll]nested path must be null"); + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + std::unordered_set elements; + for (auto const& element : expr_->vals_) { + elements.insert(GetValueFromProto(element)); + } + + auto execute_sub_batch = [](const milvus::ArrayView* data, + const int size, + bool* res, + const std::unordered_set& elements) { + auto executor = [&](size_t i) { + std::unordered_set tmp_elements(elements); + // Note: array can only be iterated once + for (int j = 0; j < data[i].length(); ++j) { + tmp_elements.erase(data[i].template get_data(j)); + if (tmp_elements.size() == 0) { + return true; + } + } + return tmp_elements.size() == 0; + }; + for (int i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyJsonContainsFilterExpr::ExecJsonContainsAll() { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + std::unordered_set elements; + for (auto const& element : expr_->vals_) { + elements.insert(GetValueFromProto(element)); + } + + auto execute_sub_batch = [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer, + const std::unordered_set& elements) { + auto executor = [&](const size_t i) -> bool { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + std::unordered_set tmp_elements(elements); + // Note: array can only be iterated once + for (auto&& it : array) { + auto val = it.template get(); + if (val.error()) { + continue; + } + tmp_elements.erase(val.value()); + if (tmp_elements.size() == 0) { + return true; + } + } + return tmp_elements.size() == 0; + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +VectorPtr +PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + + auto elements = expr_->vals_; + std::unordered_set elements_index; + int i = 0; + for (auto& element : elements) { + elements_index.insert(i); + i++; + } + + auto execute_sub_batch = + [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer, + const std::vector& elements, + const std::unordered_set elements_index) { + auto executor = [&](size_t i) -> bool { + const auto& json = data[i]; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + std::unordered_set tmp_elements_index(elements_index); + for (auto&& it : array) { + int i = -1; + for (auto& element : elements) { + i++; + switch (element.val_case()) { + case proto::plan::GenericValue::kBoolVal: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.bool_val()) { + tmp_elements_index.erase(i); + } + break; + } + case proto::plan::GenericValue::kInt64Val: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.int64_val()) { + tmp_elements_index.erase(i); + } + break; + } + case proto::plan::GenericValue::kFloatVal: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.float_val()) { + tmp_elements_index.erase(i); + } + break; + } + case proto::plan::GenericValue::kStringVal: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.string_val()) { + tmp_elements_index.erase(i); + } + break; + } + case proto::plan::GenericValue::kArrayVal: { + auto val = it.get_array(); + if (val.error()) { + continue; + } + if (CompareTwoJsonArray(val, + element.array_val())) { + tmp_elements_index.erase(i); + } + break; + } + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported data type {}", + element.val_case())); + } + if (tmp_elements_index.size() == 0) { + return true; + } + } + if (tmp_elements_index.size() == 0) { + return true; + } + } + return tmp_elements_index.size() == 0; + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks(execute_sub_batch, + std::nullptr_t{}, + res, + pointer, + elements, + elements_index); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +VectorPtr +PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + + std::vector elements; + for (auto const& element : expr_->vals_) { + elements.emplace_back(GetValueFromProto(element)); + } + auto execute_sub_batch = + [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer, + const std::vector& elements) { + auto executor = [&](const size_t i) { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + std::unordered_set exist_elements_index; + for (auto&& it : array) { + auto val = it.get_array(); + if (val.error()) { + continue; + } + std::vector< + simdjson::simdjson_result> + json_array; + json_array.reserve(val.count_elements()); + for (auto&& e : val) { + json_array.emplace_back(e); + } + for (int index = 0; index < elements.size(); ++index) { + if (CompareTwoJsonArray(json_array, elements[index])) { + exist_elements_index.insert(index); + } + } + if (exist_elements_index.size() == elements.size()) { + return true; + } + } + return exist_elements_index.size() == elements.size(); + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +VectorPtr +PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + + auto elements = expr_->vals_; + std::unordered_set elements_index; + int i = 0; + for (auto& element : elements) { + elements_index.insert(i); + i++; + } + + auto execute_sub_batch = + [](const milvus::Json* data, + const int size, + bool* res, + const std::string& pointer, + const std::vector& elements) { + auto executor = [&](const size_t i) { + auto& json = data[i]; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + // Note: array can only be iterated once + for (auto&& it : array) { + for (auto const& element : elements) { + switch (element.val_case()) { + case proto::plan::GenericValue::kBoolVal: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.bool_val()) { + return true; + } + break; + } + case proto::plan::GenericValue::kInt64Val: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.int64_val()) { + return true; + } + break; + } + case proto::plan::GenericValue::kFloatVal: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.float_val()) { + return true; + } + break; + } + case proto::plan::GenericValue::kStringVal: { + auto val = it.template get(); + if (val.error()) { + continue; + } + if (val.value() == element.string_val()) { + return true; + } + break; + } + case proto::plan::GenericValue::kArrayVal: { + auto val = it.get_array(); + if (val.error()) { + continue; + } + if (CompareTwoJsonArray(val, + element.array_val())) { + return true; + } + break; + } + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported data type {}", + element.val_case())); + } + } + } + return false; + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, elements); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/JsonContainsExpr.h b/internal/core/src/exec/expression/JsonContainsExpr.h new file mode 100644 index 0000000000..59287693b4 --- /dev/null +++ b/internal/core/src/exec/expression/JsonContainsExpr.h @@ -0,0 +1,87 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyJsonContainsFilterExpr : public SegmentExpr { + public: + PhyJsonContainsFilterExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : SegmentExpr(std::move(input), + name, + segment, + expr->column_.field_id_, + query_timestamp, + batch_size), + expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + VectorPtr + EvalJsonContainsForDataSegment(); + + template + VectorPtr + ExecJsonContains(); + + template + VectorPtr + ExecArrayContains(); + + template + VectorPtr + ExecJsonContainsAll(); + + template + VectorPtr + ExecArrayContainsAll(); + + VectorPtr + ExecJsonContainsArray(); + + VectorPtr + ExecJsonContainsAllArray(); + + VectorPtr + ExecJsonContainsAllWithDiffType(); + + VectorPtr + ExecJsonContainsWithDiffType(); + + private: + std::shared_ptr expr_; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp new file mode 100644 index 0000000000..75e59daac6 --- /dev/null +++ b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp @@ -0,0 +1,51 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "LogicalBinaryExpr.h" + +namespace milvus { +namespace exec { + +void +PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) { + AssertInfo(inputs_.size() == 2, + fmt::format("logical binary expr must has two input, but now {}", + inputs_.size())); + VectorPtr left; + inputs_[0]->Eval(context, left); + VectorPtr right; + inputs_[1]->Eval(context, right); + auto lflat = GetColumnVector(left); + auto rflat = GetColumnVector(right); + auto size = left->size(); + bool* ldata = static_cast(lflat->GetRawData()); + bool* rdata = static_cast(rflat->GetRawData()); + if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::And) { + LogicalElementFunc func; + func(ldata, rdata, size); + } else if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::Or) { + LogicalElementFunc func; + func(ldata, rdata, size); + } else { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported logical operator: {}", + expr_->GetOpTypeString())); + } + result = std::move(left); +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.h b/internal/core/src/exec/expression/LogicalBinaryExpr.h new file mode 100644 index 0000000000..a366e57680 --- /dev/null +++ b/internal/core/src/exec/expression/LogicalBinaryExpr.h @@ -0,0 +1,78 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" +#include "simd/hook.h" + +namespace milvus { +namespace exec { + +enum class LogicalOpType { Invalid = 0, And = 1, Or = 2, Xor = 3, Minus = 4 }; + +template +struct LogicalElementFunc { + void + operator()(bool* left, bool* right, int n) { +#if defined(USE_DYNAMIC_SIMD) + if constexpr (op == LogicalOpType::And) { + milvus::simd::and_bool(left, right, n); + } else if constexpr (op == LogicalOpType::Or) { + milvus::simd::or_bool(left, right, n); + } else { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported logical operator: {}", op)); + } +#else + for (size_t i = 0; i < n; ++i) { + if constexpr (op == LogicalOpType::And) { + left[i] &= right[i]; + } else if constexpr (op == LogicalOpType::Or) { + left[i] |= right[i]; + } else { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported logical operator: {}", op)); + } + } +#endif + } +}; + +class PhyLogicalBinaryExpr : public Expr { + public: + PhyLogicalBinaryExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name) + : Expr(DataType::BOOL, std::move(input), name), expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + std::shared_ptr expr_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp new file mode 100644 index 0000000000..d50fbdba57 --- /dev/null +++ b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp @@ -0,0 +1,44 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "LogicalUnaryExpr.h" +#include "simd/hook.h" + +namespace milvus { +namespace exec { + +void +PhyLogicalUnaryExpr::Eval(EvalCtx& context, VectorPtr& result) { + AssertInfo(inputs_.size() == 1, + fmt::format("logical unary expr must has one input, but now {}", + inputs_.size())); + + inputs_[0]->Eval(context, result); + if (expr_->op_type_ == milvus::expr::LogicalUnaryExpr::OpType::LogicalNot) { + auto flat_vec = GetColumnVector(result); + bool* data = static_cast(flat_vec->GetRawData()); +#if defined(USE_DYNAMIC_SIMD) + milvus::simd::invert_bool(data, flat_vec->size()); +#else + for (int i = 0; i < flat_vec->size(); ++i) { + data[i] = !data[i]; + } +#endif + } +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/LogicalUnaryExpr.h b/internal/core/src/exec/expression/LogicalUnaryExpr.h new file mode 100644 index 0000000000..bc7d9a526a --- /dev/null +++ b/internal/core/src/exec/expression/LogicalUnaryExpr.h @@ -0,0 +1,47 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +class PhyLogicalUnaryExpr : public Expr { + public: + PhyLogicalUnaryExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name) + : Expr(DataType::BOOL, std::move(input), name), expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + std::shared_ptr expr_; +}; + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp new file mode 100644 index 0000000000..71930a013c --- /dev/null +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -0,0 +1,540 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "TermExpr.h" +#include "query/Utils.h" +namespace milvus { +namespace exec { + +void +PhyTermFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + if (is_pk_field_) { + result = ExecPkTermImpl(); + return; + } + switch (expr_->column_.data_type_) { + case DataType::BOOL: { + result = ExecVisitorImpl(); + break; + } + case DataType::INT8: { + result = ExecVisitorImpl(); + break; + } + case DataType::INT16: { + result = ExecVisitorImpl(); + break; + } + case DataType::INT32: { + result = ExecVisitorImpl(); + break; + } + case DataType::INT64: { + result = ExecVisitorImpl(); + break; + } + case DataType::FLOAT: { + result = ExecVisitorImpl(); + break; + } + case DataType::DOUBLE: { + result = ExecVisitorImpl(); + break; + } + case DataType::VARCHAR: { + if (segment_->type() == SegmentType::Growing) { + result = ExecVisitorImpl(); + } else { + result = ExecVisitorImpl(); + } + break; + } + case DataType::JSON: { + if (expr_->vals_.size() == 0) { + result = ExecVisitorImplTemplateJson(); + break; + } + auto type = expr_->vals_[0].val_case(); + switch (type) { + case proto::plan::GenericValue::ValCase::kBoolVal: + result = ExecVisitorImplTemplateJson(); + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + result = ExecVisitorImplTemplateJson(); + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + result = ExecVisitorImplTemplateJson(); + break; + case proto::plan::GenericValue::ValCase::kStringVal: + result = ExecVisitorImplTemplateJson(); + break; + default: + PanicInfo(DataTypeInvalid, + fmt::format("unknown data type: {}", type)); + } + break; + } + case DataType::ARRAY: { + if (expr_->vals_.size() == 0) { + result = ExecVisitorImplTemplateArray(); + break; + } + auto type = expr_->vals_[0].val_case(); + switch (type) { + case proto::plan::GenericValue::ValCase::kBoolVal: + result = ExecVisitorImplTemplateArray(); + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + result = ExecVisitorImplTemplateArray(); + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + result = ExecVisitorImplTemplateArray(); + break; + case proto::plan::GenericValue::ValCase::kStringVal: + result = ExecVisitorImplTemplateArray(); + break; + default: + PanicInfo(DataTypeInvalid, + fmt::format("unknown data type: {}", type)); + } + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", + expr_->column_.data_type_)); + } +} + +void +PhyTermFilterExpr::InitPkCacheOffset() { + auto id_array = std::make_unique(); + switch (pk_type_) { + case DataType::INT64: { + auto dst_ids = id_array->mutable_int_id(); + for (const auto& id : expr_->vals_) { + dst_ids->add_data(GetValueFromProto(id)); + } + break; + } + case DataType::VARCHAR: { + auto dst_ids = id_array->mutable_str_id(); + for (const auto& id : expr_->vals_) { + dst_ids->add_data(GetValueFromProto(id)); + } + break; + } + default: { + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", pk_type_)); + } + } + + auto [uids, seg_offsets] = + segment_->search_ids(*id_array, query_timestamp_); + cached_bits_.resize(num_rows_); + cached_offsets_ = + std::make_shared(DataType::INT64, seg_offsets.size()); + int64_t* cached_offsets_ptr = (int64_t*)cached_offsets_->GetRawData(); + int i = 0; + for (const auto& offset : seg_offsets) { + auto _offset = (int64_t)offset.get(); + cached_bits_[_offset] = true; + cached_offsets_ptr[i++] = _offset; + } + cached_offsets_inited_ = true; +} + +VectorPtr +PhyTermFilterExpr::ExecPkTermImpl() { + if (!cached_offsets_inited_) { + InitPkCacheOffset(); + } + + auto real_batch_size = current_data_chunk_pos_ + batch_size_ >= num_rows_ + ? num_rows_ - current_data_chunk_pos_ + : batch_size_; + current_data_chunk_pos_ += real_batch_size; + + if (real_batch_size == 0) { + return nullptr; + } + + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + for (size_t i = 0; i < real_batch_size; ++i) { + res[i] = cached_bits_[i]; + } + + std::vector vecs{res_vec, cached_offsets_}; + return std::make_shared(vecs); +} + +template +VectorPtr +PhyTermFilterExpr::ExecVisitorImplTemplateJson() { + if (expr_->is_in_field_) { + return ExecTermJsonVariableInField(); + } else { + return ExecTermJsonFieldInVariable(); + } +} + +template +VectorPtr +PhyTermFilterExpr::ExecVisitorImplTemplateArray() { + if (expr_->is_in_field_) { + return ExecTermArrayVariableInField(); + } else { + return ExecTermArrayFieldInVariable(); + } +} + +template +VectorPtr +PhyTermFilterExpr::ExecTermArrayVariableInField() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + AssertInfo(expr_->vals_.size() == 1, + "element length in json array must be one"); + ValueType target_val = GetValueFromProto(expr_->vals_[0]); + + auto execute_sub_batch = [](const ArrayView* data, + const int size, + bool* res, + const ValueType& target_val) { + auto executor = [&](size_t i) { + for (int i = 0; i < data[i].length(); i++) { + auto val = data[i].template get_data(i); + if (val == target_val) { + return true; + } + } + return false; + }; + for (int i = 0; i < size; ++i) { + executor(i); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, target_val); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyTermFilterExpr::ExecTermArrayFieldInVariable() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + int index = -1; + if (expr_->column_.nested_path_.size() > 0) { + index = std::stoi(expr_->column_.nested_path_[0]); + } + std::unordered_set term_set; + for (const auto& element : expr_->vals_) { + term_set.insert(GetValueFromProto(element)); + } + + if (term_set.empty()) { + for (size_t i = 0; i < real_batch_size; ++i) { + res[i] = false; + } + return res_vec; + } + + auto execute_sub_batch = [](const ArrayView* data, + const int size, + bool* res, + int index, + const std::unordered_set& term_set) { + if (term_set.empty()) { + for (int i = 0; i < size; ++i) { + res[i] = false; + } + } + for (int i = 0; i < size; ++i) { + if (index >= data[i].length()) { + res[i] = false; + continue; + } + auto value = data[i].get_data(index); + res[i] = term_set.find(ValueType(value)) != term_set.end(); + } + }; + + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, index, term_set); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyTermFilterExpr::ExecTermJsonVariableInField() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + AssertInfo(expr_->vals_.size() == 1, + "element length in json array must be one"); + ValueType val = GetValueFromProto(expr_->vals_[0]); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + + auto execute_sub_batch = [](const Json* data, + const int size, + bool* res, + const std::string pointer, + const ValueType& target_val) { + auto executor = [&](size_t i) { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) + return false; + for (auto it = array.begin(); it != array.end(); ++it) { + auto val = (*it).template get(); + if (val.error()) { + return false; + } + if (val.value() == target_val) { + return true; + } + } + return false; + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, val); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyTermFilterExpr::ExecTermJsonFieldInVariable() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + std::unordered_set term_set; + for (const auto& element : expr_->vals_) { + term_set.insert(GetValueFromProto(element)); + } + + if (term_set.empty()) { + for (size_t i = 0; i < real_batch_size; ++i) { + res[i] = false; + } + return res_vec; + } + + auto execute_sub_batch = [](const Json* data, + const int size, + bool* res, + const std::string pointer, + const std::unordered_set& terms) { + auto executor = [&](size_t i) { + auto x = data[i].template at(pointer); + if (x.error()) { + if constexpr (std::is_same_v) { + auto x = data[i].template at(pointer); + if (x.error()) { + return false; + } + + auto value = x.value(); + // if the term set is {1}, and the value is 1.1, we should not return true. + return std::floor(value) == value && + terms.find(ValueType(value)) != terms.end(); + } + return false; + } + return terms.find(ValueType(x.value())) != terms.end(); + }; + for (size_t i = 0; i < size; ++i) { + res[i] = executor(i); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, pointer, term_set); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyTermFilterExpr::ExecVisitorImpl() { + if (is_index_mode_) { + return ExecVisitorImplForIndex(); + } else { + return ExecVisitorImplForData(); + } +} + +template +VectorPtr +PhyTermFilterExpr::ExecVisitorImplForIndex() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + std::vector vals; + for (auto& val : expr_->vals_) { + auto converted_val = GetValueFromProto(val); + // Integral overflow process + if constexpr (std::is_integral_v) { + if (milvus::query::out_of_range(converted_val)) { + continue; + } + } + vals.emplace_back(converted_val); + } + auto execute_sub_batch = [](Index* index_ptr, + const std::vector& vals) { + TermIndexFunc func; + return func(index_ptr, vals.size(), vals.data()); + }; + auto res = ProcessIndexChunks(execute_sub_batch, vals); + AssertInfo(res.size() == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + res.size(), + real_batch_size)); + return std::make_shared(std::move(res)); +} + +template <> +VectorPtr +PhyTermFilterExpr::ExecVisitorImplForIndex() { + using Index = index::ScalarIndex; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + std::vector vals; + for (auto& val : expr_->vals_) { + vals.emplace_back(GetValueFromProto(val) ? 1 : 0); + } + auto execute_sub_batch = [](Index* index_ptr, + const std::vector& vals) { + TermIndexFunc func; + return std::move(func(index_ptr, vals.size(), (bool*)vals.data())); + }; + auto res = ProcessIndexChunks(execute_sub_batch, vals); + return std::make_shared(std::move(res)); +} + +template +VectorPtr +PhyTermFilterExpr::ExecVisitorImplForData() { + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + std::vector vals; + for (auto& val : expr_->vals_) { + // Integral overflow process + bool overflowed = false; + auto converted_val = GetValueFromProtoWithOverflow(val, overflowed); + if (!overflowed) { + vals.emplace_back(converted_val); + } + } + std::unordered_set vals_set(vals.begin(), vals.end()); + auto execute_sub_batch = [](const T* data, + const int size, + bool* res, + const std::unordered_set& vals) { + TermElementFuncSet func; + for (size_t i = 0; i < size; ++i) { + res[i] = func(vals, data[i]); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, vals_set); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/TermExpr.h b/internal/core/src/exec/expression/TermExpr.h new file mode 100644 index 0000000000..273899963e --- /dev/null +++ b/internal/core/src/exec/expression/TermExpr.h @@ -0,0 +1,134 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace exec { + +template +struct TermElementFuncFlat { + bool + operator()(const T* src, size_t n, T val) { + for (size_t i = 0; i < n; ++i) { + if (src[i] == val) { + return true; + } + } + } +}; + +template +struct TermElementFuncSet { + bool + operator()(const std::unordered_set& srcs, T val) { + return srcs.find(val) != srcs.end(); + } +}; + +template +struct TermIndexFunc { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + FixedVector + operator()(Index* index, size_t n, const IndexInnerType* val) { + return index->In(n, val); + } +}; + +class PhyTermFilterExpr : public SegmentExpr { + public: + PhyTermFilterExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : SegmentExpr(std::move(input), + name, + segment, + expr->column_.field_id_, + query_timestamp, + batch_size), + expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + void + InitPkCacheOffset(); + + VectorPtr + ExecPkTermImpl(); + + template + VectorPtr + ExecVisitorImpl(); + + template + VectorPtr + ExecVisitorImplForIndex(); + + template + VectorPtr + ExecVisitorImplForData(); + + template + VectorPtr + ExecVisitorImplTemplateJson(); + + template + VectorPtr + ExecTermJsonVariableInField(); + + template + VectorPtr + ExecTermJsonFieldInVariable(); + + template + VectorPtr + ExecVisitorImplTemplateArray(); + + template + VectorPtr + ExecTermArrayVariableInField(); + + template + VectorPtr + ExecTermArrayFieldInVariable(); + + private: + std::shared_ptr expr_; + // If expr is like "pk in (..)", can use pk index to optimize + bool cached_offsets_inited_{false}; + ColumnVectorPtr cached_offsets_; + FixedVector cached_bits_; +}; +} //namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp new file mode 100644 index 0000000000..c5c4375d8c --- /dev/null +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -0,0 +1,593 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "UnaryExpr.h" +#include "common/Json.h" + +namespace milvus { +namespace exec { + +void +PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { + switch (expr_->column_.data_type_) { + case DataType::BOOL: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT8: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT16: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT32: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::INT64: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::FLOAT: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::DOUBLE: { + result = ExecRangeVisitorImpl(); + break; + } + case DataType::VARCHAR: { + if (segment_->type() == SegmentType::Growing) { + result = ExecRangeVisitorImpl(); + } else { + result = ExecRangeVisitorImpl(); + } + break; + } + case DataType::JSON: { + auto val_type = expr_->val_.val_case(); + switch (val_type) { + case proto::plan::GenericValue::ValCase::kBoolVal: + result = ExecRangeVisitorImplJson(); + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + result = ExecRangeVisitorImplJson(); + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + result = ExecRangeVisitorImplJson(); + break; + case proto::plan::GenericValue::ValCase::kStringVal: + result = ExecRangeVisitorImplJson(); + break; + case proto::plan::GenericValue::ValCase::kArrayVal: + result = ExecRangeVisitorImplJson(); + break; + default: + PanicInfo(DataTypeInvalid, + fmt::format("unknown data type: {}", val_type)); + } + break; + } + case DataType::ARRAY: { + auto val_type = expr_->val_.val_case(); + switch (val_type) { + case proto::plan::GenericValue::ValCase::kBoolVal: + result = ExecRangeVisitorImplArray(); + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + result = ExecRangeVisitorImplArray(); + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + result = ExecRangeVisitorImplArray(); + break; + case proto::plan::GenericValue::ValCase::kStringVal: + result = ExecRangeVisitorImplArray(); + break; + case proto::plan::GenericValue::ValCase::kArrayVal: + result = ExecRangeVisitorImplArray(); + break; + default: + PanicInfo(DataTypeInvalid, + fmt::format("unknown data type: {}", val_type)); + } + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", + expr_->column_.data_type_)); + } +} + +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + + ValueType val = GetValueFromProto(expr_->val_); + auto op_type = expr_->op_type_; + int index = -1; + if (expr_->column_.nested_path_.size() > 0) { + index = std::stoi(expr_->column_.nested_path_[0]); + } + auto execute_sub_batch = [op_type](const milvus::ArrayView* data, + const int size, + bool* res, + ValueType val, + int index) { + switch (op_type) { + case proto::plan::GreaterThan: { + UnaryElementFuncForArray + func; + func(data, size, val, index, res); + break; + } + case proto::plan::GreaterEqual: { + UnaryElementFuncForArray + func; + func(data, size, val, index, res); + break; + } + case proto::plan::LessThan: { + UnaryElementFuncForArray func; + func(data, size, val, index, res); + break; + } + case proto::plan::LessEqual: { + UnaryElementFuncForArray + func; + func(data, size, val, index, res); + break; + } + case proto::plan::Equal: { + UnaryElementFuncForArray func; + func(data, size, val, index, res); + break; + } + case proto::plan::NotEqual: { + UnaryElementFuncForArray func; + func(data, size, val, index, res); + break; + } + case proto::plan::PrefixMatch: { + UnaryElementFuncForArray + func; + func(data, size, val, index, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported operator type for unary expr: {}", + op_type)); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, val, index); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + ExprValueType val = GetValueFromProto(expr_->val_); + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto op_type = expr_->op_type_; + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); + +#define UnaryRangeJSONCompare(cmp) \ + do { \ + auto x = data[i].template at(pointer); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = data[i].template at(pointer); \ + res[i] = !x.error() && (cmp); \ + break; \ + } \ + res[i] = false; \ + break; \ + } \ + res[i] = (cmp); \ + } while (false) + +#define UnaryRangeJSONCompareNotEqual(cmp) \ + do { \ + auto x = data[i].template at(pointer); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = data[i].template at(pointer); \ + res[i] = x.error() || (cmp); \ + break; \ + } \ + res[i] = true; \ + break; \ + } \ + res[i] = (cmp); \ + } while (false) + + auto execute_sub_batch = [op_type, pointer](const milvus::Json* data, + const int size, + bool* res, + ExprValueType val) { + switch (op_type) { + case proto::plan::GreaterThan: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + res[i] = false; + } else { + UnaryRangeJSONCompare(x.value() > val); + } + } + break; + } + case proto::plan::GreaterEqual: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + res[i] = false; + } else { + UnaryRangeJSONCompare(x.value() >= val); + } + } + break; + } + case proto::plan::LessThan: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + res[i] = false; + } else { + UnaryRangeJSONCompare(x.value() < val); + } + } + break; + } + case proto::plan::LessEqual: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + res[i] = false; + } else { + UnaryRangeJSONCompare(x.value() <= val); + } + } + break; + } + case proto::plan::Equal: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + res[i] = false; + continue; + } + res[i] = CompareTwoJsonArray(array, val); + } else { + UnaryRangeJSONCompare(x.value() == val); + } + } + break; + } + case proto::plan::NotEqual: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + auto doc = data[i].doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + res[i] = false; + continue; + } + res[i] = !CompareTwoJsonArray(array, val); + } else { + UnaryRangeJSONCompareNotEqual(x.value() != val); + } + } + break; + } + case proto::plan::PrefixMatch: { + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + res[i] = false; + } else { + UnaryRangeJSONCompare(milvus::query::Match( + ExprValueType(x.value()), val, op_type)); + } + } + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported operator type for unary expr: {}", + op_type)); + } + }; + int64_t processed_size = ProcessDataChunks( + execute_sub_batch, std::nullptr_t{}, res, val); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImpl() { + if (is_index_mode_) { + return ExecRangeVisitorImplForIndex(); + } else { + return ExecRangeVisitorImplForData(); + } +} + +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + if (auto res = PreCheckOverflow()) { + return res; + } + + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + auto op_type = expr_->op_type_; + auto execute_sub_batch = [op_type](Index* index_ptr, IndexInnerType val) { + FixedVector res; + switch (op_type) { + case proto::plan::GreaterThan: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + case proto::plan::GreaterEqual: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + case proto::plan::LessThan: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + case proto::plan::LessEqual: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + case proto::plan::Equal: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + case proto::plan::NotEqual: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + case proto::plan::PrefixMatch: { + UnaryIndexFunc func; + res = std::move(func(index_ptr, val)); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported operator type for unary expr: {}", + op_type)); + } + return res; + }; + auto val = GetValueFromProto(expr_->val_); + auto res = ProcessIndexChunks(execute_sub_batch, val); + AssertInfo(res.size() == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + res.size(), + real_batch_size)); + return std::make_shared(std::move(res)); +} + +template +ColumnVectorPtr +PhyUnaryRangeFilterExpr::PreCheckOverflow() { + if constexpr (std::is_integral_v && !std::is_same_v) { + int64_t val = GetValueFromProto(expr_->val_); + + if (milvus::query::out_of_range(val)) { + int64_t batch_size = overflow_check_pos_ + batch_size_ >= num_rows_ + ? num_rows_ - overflow_check_pos_ + : batch_size_; + overflow_check_pos_ += batch_size; + if (cached_overflow_res_ != nullptr && + cached_overflow_res_->size() == batch_size) { + return cached_overflow_res_; + } + switch (expr_->op_type_) { + case proto::plan::GreaterThan: + case proto::plan::GreaterEqual: { + auto res_vec = std::make_shared( + DataType::BOOL, batch_size); + cached_overflow_res_ = res_vec; + bool* res = (bool*)res_vec->GetRawData(); + if (milvus::query::lt_lb(val)) { + for (size_t i = 0; i < batch_size; ++i) { + res[i] = true; + } + return res_vec; + } + return res_vec; + } + case proto::plan::LessThan: + case proto::plan::LessEqual: { + auto res_vec = std::make_shared( + DataType::BOOL, batch_size); + cached_overflow_res_ = res_vec; + bool* res = (bool*)res_vec->GetRawData(); + if (milvus::query::gt_ub(val)) { + for (size_t i = 0; i < batch_size; ++i) { + res[i] = true; + } + return res_vec; + } + return res_vec; + } + case proto::plan::Equal: { + auto res_vec = std::make_shared( + DataType::BOOL, batch_size); + cached_overflow_res_ = res_vec; + bool* res = (bool*)res_vec->GetRawData(); + for (size_t i = 0; i < batch_size; ++i) { + res[i] = false; + } + return res_vec; + } + case proto::plan::NotEqual: { + auto res_vec = std::make_shared( + DataType::BOOL, batch_size); + cached_overflow_res_ = res_vec; + bool* res = (bool*)res_vec->GetRawData(); + for (size_t i = 0; i < batch_size; ++i) { + res[i] = true; + } + return res_vec; + } + default: { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported range node {}", + expr_->op_type_)); + } + } + } + } + return nullptr; +} + +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + if (auto res = PreCheckOverflow()) { + return res; + } + + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + IndexInnerType val = GetValueFromProto(expr_->val_); + auto res_vec = + std::make_shared(DataType::BOOL, real_batch_size); + bool* res = (bool*)res_vec->GetRawData(); + auto expr_type = expr_->op_type_; + auto execute_sub_batch = [expr_type](const T* data, + const int size, + bool* res, + IndexInnerType val) { + switch (expr_type) { + case proto::plan::GreaterThan: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + case proto::plan::GreaterEqual: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + case proto::plan::LessThan: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + case proto::plan::LessEqual: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + case proto::plan::Equal: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + case proto::plan::NotEqual: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + case proto::plan::PrefixMatch: { + UnaryElementFunc func; + func(data, size, val, res); + break; + } + default: + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported operator type for unary expr: {}", + expr_type)); + } + }; + auto skip_index_func = [expr_type, val](const SkipIndex& skip_index, + FieldId field_id, + int64_t chunk_id) { + return skip_index.CanSkipUnaryRange( + field_id, chunk_id, expr_type, val); + }; + int64_t processed_size = + ProcessDataChunks(execute_sub_batch, skip_index_func, res, val); + AssertInfo(processed_size == real_batch_size, + fmt::format("internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_size, + real_batch_size)); + return res_vec; +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h new file mode 100644 index 0000000000..82c5d40181 --- /dev/null +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -0,0 +1,220 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "index/Meta.h" +#include "segcore/SegmentInterface.h" +#include "query/Utils.h" + +namespace milvus { +namespace exec { + +template +struct UnaryElementFunc { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + void + operator()(const T* src, size_t size, IndexInnerType val, bool* res) { + for (int i = 0; i < size; ++i) { + if constexpr (op == proto::plan::OpType::Equal) { + res[i] = src[i] == val; + } else if constexpr (op == proto::plan::OpType::NotEqual) { + res[i] = src[i] != val; + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + res[i] = src[i] > val; + } else if constexpr (op == proto::plan::OpType::LessThan) { + res[i] = src[i] < val; + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + res[i] = src[i] >= val; + } else if constexpr (op == proto::plan::OpType::LessEqual) { + res[i] = src[i] <= val; + } else if constexpr (op == proto::plan::OpType::PrefixMatch) { + res[i] = milvus::query::Match( + src[i], val, proto::plan::OpType::PrefixMatch); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported op_type:{} for UnaryElementFunc", + op)); + } + } + } +}; + +#define UnaryArrayCompare(cmp) \ + do { \ + if constexpr (std::is_same_v) { \ + res[i] = false; \ + } else { \ + if (index >= src[i].length()) { \ + res[i] = false; \ + continue; \ + } \ + auto array_data = src[i].template get_data(index); \ + res[i] = (cmp); \ + } \ + } while (false) + +template +struct UnaryElementFuncForArray { + using GetType = std::conditional_t, + std::string_view, + ValueType>; + void + operator()(const ArrayView* src, + size_t size, + ValueType val, + int index, + bool* res) { + for (int i = 0; i < size; ++i) { + if constexpr (op == proto::plan::OpType::Equal) { + if constexpr (std::is_same_v) { + res[i] = src[i].is_same_array(val); + } else { + if (index >= src[i].length()) { + res[i] = false; + continue; + } + auto array_data = src[i].template get_data(index); + res[i] = array_data == val; + } + } else if constexpr (op == proto::plan::OpType::NotEqual) { + if constexpr (std::is_same_v) { + res[i] = !src[i].is_same_array(val); + } else { + if (index >= src[i].length()) { + res[i] = false; + continue; + } + auto array_data = src[i].template get_data(index); + res[i] = array_data != val; + } + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + UnaryArrayCompare(array_data > val); + } else if constexpr (op == proto::plan::OpType::LessThan) { + UnaryArrayCompare(array_data < val); + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + UnaryArrayCompare(array_data >= val); + } else if constexpr (op == proto::plan::OpType::LessEqual) { + UnaryArrayCompare(array_data <= val); + } else if constexpr (op == proto::plan::OpType::PrefixMatch) { + UnaryArrayCompare(milvus::query::Match(array_data, val, op)); + } else { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported op_type:{} for " + "UnaryElementFuncForArray", + op)); + } + } + } +}; + +template +struct UnaryIndexFunc { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + FixedVector + operator()(Index* index, IndexInnerType val) { + if constexpr (op == proto::plan::OpType::Equal) { + return index->In(1, &val); + } else if constexpr (op == proto::plan::OpType::NotEqual) { + return index->NotIn(1, &val); + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + return index->Range(val, OpType::GreaterThan); + } else if constexpr (op == proto::plan::OpType::LessThan) { + return index->Range(val, OpType::LessThan); + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + return index->Range(val, OpType::GreaterEqual); + } else if constexpr (op == proto::plan::OpType::LessEqual) { + return index->Range(val, OpType::LessEqual); + } else if constexpr (op == proto::plan::OpType::PrefixMatch) { + auto dataset = std::make_unique(); + dataset->Set(milvus::index::OPERATOR_TYPE, + proto::plan::OpType::PrefixMatch); + dataset->Set(milvus::index::PREFIX_VALUE, val); + return index->Query(std::move(dataset)); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported op_type:{} for UnaryIndexFunc", op)); + } + } +}; + +class PhyUnaryRangeFilterExpr : public SegmentExpr { + public: + PhyUnaryRangeFilterExpr( + const std::vector>& input, + const std::shared_ptr& expr, + const std::string& name, + const segcore::SegmentInternalInterface* segment, + Timestamp query_timestamp, + int64_t batch_size) + : SegmentExpr(std::move(input), + name, + segment, + expr->column_.field_id_, + query_timestamp, + batch_size), + expr_(expr) { + } + + void + Eval(EvalCtx& context, VectorPtr& result) override; + + private: + template + VectorPtr + ExecRangeVisitorImpl(); + + template + VectorPtr + ExecRangeVisitorImplForIndex(); + + template + VectorPtr + ExecRangeVisitorImplForData(); + + template + VectorPtr + ExecRangeVisitorImplJson(); + + template + VectorPtr + ExecRangeVisitorImplArray(); + + // Check overflow and cache result for performace + template + ColumnVectorPtr + PreCheckOverflow(); + + private: + std::shared_ptr expr_; + ColumnVectorPtr cached_overflow_res_{nullptr}; + int64_t overflow_check_pos_{0}; +}; +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/Utils.h b/internal/core/src/exec/expression/Utils.h new file mode 100644 index 0000000000..96d933d95e --- /dev/null +++ b/internal/core/src/exec/expression/Utils.h @@ -0,0 +1,166 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "common/EasyAssert.h" +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/expression/Expr.h" +#include "segcore/SegmentInterface.h" +#include "query/Utils.h" + +namespace milvus { +namespace exec { + +static ColumnVectorPtr +GetColumnVector(const VectorPtr& result) { + ColumnVectorPtr res; + if (auto convert_vector = std::dynamic_pointer_cast(result)) { + res = convert_vector; + } else if (auto convert_vector = + std::dynamic_pointer_cast(result)) { + if (auto convert_flat_vector = std::dynamic_pointer_cast( + convert_vector->child(0))) { + res = convert_flat_vector; + } else { + PanicInfo( + UnexpectedError, + "RowVector result must have a first ColumnVector children"); + } + } else { + PanicInfo(UnexpectedError, + "expr result must have a ColumnVector or RowVector result"); + } + return res; +} + +template +bool +CompareTwoJsonArray(T arr1, const proto::plan::Array& arr2) { + int json_array_length = 0; + if constexpr (std::is_same_v< + T, + simdjson::simdjson_result>) { + json_array_length = arr1.count_elements(); + } + if constexpr (std::is_same_v>>) { + json_array_length = arr1.size(); + } + if (arr2.array_size() != json_array_length) { + return false; + } + int i = 0; + for (auto&& it : arr1) { + switch (arr2.array(i).val_case()) { + case proto::plan::GenericValue::kBoolVal: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).bool_val()) { + return false; + } + break; + } + case proto::plan::GenericValue::kInt64Val: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).int64_val()) { + return false; + } + break; + } + case proto::plan::GenericValue::kFloatVal: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).float_val()) { + return false; + } + break; + } + case proto::plan::GenericValue::kStringVal: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).string_val()) { + return false; + } + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", + arr2.array(i).val_case())); + } + i++; + } + return true; +} + +template +T +GetValueFromProtoInternal(const milvus::proto::plan::GenericValue& value_proto, + bool& overflowed) { + if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kBoolVal); + return static_cast(value_proto.bool_val()); + } else if constexpr (std::is_integral_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kInt64Val); + auto val = value_proto.int64_val(); + if (milvus::query::out_of_range(val)) { + overflowed = true; + return T(); + } else { + return static_cast(val); + } + } else if constexpr (std::is_floating_point_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kFloatVal); + return static_cast(value_proto.float_val()); + } else if constexpr (std::is_same_v || + std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kStringVal); + return static_cast(value_proto.string_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kArrayVal); + return static_cast(value_proto.array_val()); + } else if constexpr (std::is_same_v) { + return static_cast(value_proto); + } else { + PanicInfo(Unsupported, + fmt::format("unsupported generic value {}", + value_proto.DebugString())); + } +} + +template +T +GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) { + bool dummy_overflowed = false; + return GetValueFromProtoInternal(value_proto, dummy_overflowed); +} + +template +T +GetValueFromProtoWithOverflow( + const milvus::proto::plan::GenericValue& value_proto, bool& overflowed) { + return GetValueFromProtoInternal(value_proto, overflowed); +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/expression/VectorFunction.h b/internal/core/src/exec/expression/VectorFunction.h new file mode 100644 index 0000000000..1e6be5081c --- /dev/null +++ b/internal/core/src/exec/expression/VectorFunction.h @@ -0,0 +1,47 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "common/Vector.h" +#include "exec/QueryContext.h" + +namespace milvus { +namespace exec { + +class VectorFunction { + public: + virtual ~VectorFunction() = default; + + virtual void + Apply(std::vector& args, + DataType output_type, + EvalCtx& context, + VectorPtr& result) const = 0; +}; + +std::shared_ptr +GetVectorFunction(const std::string& name, + const std::vector& input_types, + const QueryConfig& config); + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/CallbackSink.h b/internal/core/src/exec/operator/CallbackSink.h new file mode 100644 index 0000000000..5e5c7479b5 --- /dev/null +++ b/internal/core/src/exec/operator/CallbackSink.h @@ -0,0 +1,89 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "exec/operator/Operator.h" + +namespace milvus { +namespace exec { +class CallbackSink : public Operator { + public: + CallbackSink( + int32_t operator_id, + DriverContext* ctx, + std::function callback) + : Operator(ctx, DataType::NONE, operator_id, "N/A", "CallbackSink"), + callback_(callback) { + } + + void + AddInput(RowVectorPtr& input) override { + blocking_reason_ = callback_(input, &future_); + } + + RowVectorPtr + GetOutput() override { + return nullptr; + } + + void + NoMoreInput() override { + Operator::NoMoreInput(); + Close(); + } + + bool + NeedInput() const override { + return callback_ != nullptr; + } + + bool + IsFilter() override { + return false; + } + + bool + IsFinished() override { + return no_more_input_; + } + + BlockingReason + IsBlocked(ContinueFuture* future) override { + if (blocking_reason_ != BlockingReason::kNotBlocked) { + *future = std::move(future_); + blocking_reason_ = BlockingReason::kNotBlocked; + return BlockingReason::kWaitForConsumer; + } + return BlockingReason::kNotBlocked; + } + + private: + void + Close() override { + if (callback_) { + callback_(nullptr, nullptr); + callback_ = nullptr; + } + } + + ContinueFuture future_; + BlockingReason blocking_reason_{BlockingReason::kNotBlocked}; + std::function callback_; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/FilterBits.cpp b/internal/core/src/exec/operator/FilterBits.cpp new file mode 100644 index 0000000000..a1a06b9013 --- /dev/null +++ b/internal/core/src/exec/operator/FilterBits.cpp @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "FilterBits.h" + +namespace milvus { +namespace exec { +FilterBits::FilterBits( + int32_t operator_id, + DriverContext* driverctx, + const std::shared_ptr& filter) + : Operator(driverctx, + filter->output_type(), + operator_id, + filter->id(), + "FilterBits") { + ExecContext* exec_context = operator_context_->get_exec_context(); + QueryContext* query_context = exec_context->get_query_context(); + std::vector filters; + filters.emplace_back(filter->filter()); + exprs_ = std::make_unique(filters, exec_context); + need_process_rows_ = query_context->get_segment()->get_active_count( + query_context->get_query_timestamp()); + num_processed_rows_ = 0; +} + +void +FilterBits::AddInput(RowVectorPtr& input) { + input_ = std::move(input); +} + +bool +FilterBits::AllInputProcessed() { + if (num_processed_rows_ == need_process_rows_) { + input_ = nullptr; + return true; + } + return false; +} + +bool +FilterBits::IsFinished() { + return AllInputProcessed(); +} + +RowVectorPtr +FilterBits::GetOutput() { + if (AllInputProcessed()) { + return nullptr; + } + + EvalCtx eval_ctx( + operator_context_->get_exec_context(), exprs_.get(), input_.get()); + + exprs_->Eval(0, 1, true, eval_ctx, results_); + + AssertInfo(results_.size() == 1 && results_[0] != nullptr, + "FilterBits result size should be one and not be nullptr"); + + if (results_[0]->type() == DataType::ROW) { + auto row_vec = std::dynamic_pointer_cast(results_[0]); + num_processed_rows_ += row_vec->child(0)->size(); + } else { + num_processed_rows_ += results_[0]->size(); + } + return std::make_shared(results_); +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/FilterBits.h b/internal/core/src/exec/operator/FilterBits.h new file mode 100644 index 0000000000..462c8dc5e5 --- /dev/null +++ b/internal/core/src/exec/operator/FilterBits.h @@ -0,0 +1,74 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "exec/Driver.h" +#include "exec/expression/Expr.h" +#include "exec/operator/Operator.h" +#include "exec/QueryContext.h" + +namespace milvus { +namespace exec { +class FilterBits : public Operator { + public: + FilterBits(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& filter); + + bool + IsFilter() override { + return true; + } + + bool + NeedInput() const override { + return !input_; + } + + void + AddInput(RowVectorPtr& input) override; + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override; + + void + Close() override { + Operator::Close(); + exprs_->Clear(); + } + + BlockingReason + IsBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + bool + AllInputProcessed(); + + private: + std::unique_ptr exprs_; + int64_t num_processed_rows_; + int64_t need_process_rows_; +}; +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/Operator.cpp b/internal/core/src/exec/operator/Operator.cpp new file mode 100644 index 0000000000..972482c797 --- /dev/null +++ b/internal/core/src/exec/operator/Operator.cpp @@ -0,0 +1,21 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Operator.h" + +namespace milvus { +namespace exec {} +} // namespace milvus diff --git a/internal/core/src/exec/operator/Operator.h b/internal/core/src/exec/operator/Operator.h new file mode 100644 index 0000000000..1a258873e7 --- /dev/null +++ b/internal/core/src/exec/operator/Operator.h @@ -0,0 +1,197 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "common/Types.h" +#include "common/Vector.h" +#include "exec/Driver.h" +#include "exec/Task.h" +#include "exec/QueryContext.h" +#include "plan/PlanNode.h" + +namespace milvus { +namespace exec { + +class OperatorContext { + public: + OperatorContext(DriverContext* driverCtx, + const plan::PlanNodeId& plannodeid, + int32_t operator_id, + const std::string& operator_type = "") + : driver_context_(driverCtx), + plannode_id_(plannodeid), + operator_id_(operator_id), + operator_type_(operator_type) { + } + + ExecContext* + get_exec_context() const { + if (!exec_context_) { + exec_context_ = std::make_unique( + driver_context_->task_->query_context().get()); + } + return exec_context_.get(); + } + + const std::shared_ptr& + get_task() const { + return driver_context_->task_; + } + + const std::string& + get_task_id() const { + return driver_context_->task_->taskid(); + } + + DriverContext* + get_driver_context() const { + return driver_context_; + } + + const plan::PlanNodeId& + get_plannode_id() const { + return plannode_id_; + } + + const std::string& + get_operator_type() const { + return operator_type_; + } + + const int32_t + get_operator_id() const { + return operator_id_; + } + + private: + DriverContext* driver_context_; + plan::PlanNodeId plannode_id_; + int32_t operator_id_; + std::string operator_type_; + + mutable std::unique_ptr exec_context_; +}; + +class Operator { + public: + Operator(DriverContext* ctx, + DataType output_type, + int32_t operator_id, + const std::string& plannode_id, + const std::string& operator_type = "") + : operator_context_(std::make_unique( + ctx, plannode_id, operator_id, operator_type)) { + } + + virtual ~Operator() = default; + + virtual bool + NeedInput() const = 0; + + virtual void + AddInput(RowVectorPtr& input) = 0; + + virtual void + NoMoreInput() { + no_more_input_ = true; + } + + virtual RowVectorPtr + GetOutput() = 0; + + virtual bool + IsFinished() = 0; + + virtual bool + IsFilter() = 0; + + virtual BlockingReason + IsBlocked(ContinueFuture* future) = 0; + + virtual void + Close() { + input_ = nullptr; + results_.clear(); + } + + virtual bool + PreserveOrder() const { + return false; + } + + const std::string& + get_operator_type() const { + return operator_context_->get_operator_type(); + } + + const int32_t + get_operator_id() const { + return operator_context_->get_operator_id(); + } + + const plan::PlanNodeId& + get_plannode_id() const { + return operator_context_->get_plannode_id(); + } + + protected: + std::unique_ptr operator_context_; + + DataType output_type_; + + RowVectorPtr input_; + + bool no_more_input_{false}; + + std::vector results_; +}; + +class SourceOperator : public Operator { + public: + SourceOperator(DriverContext* driver_ctx, + DataType out_type, + int32_t operator_id, + const std::string& plannode_id, + const std::string& operator_type) + : Operator( + driver_ctx, out_type, operator_id, plannode_id, operator_type) { + } + + bool + NeedInput() const override { + return false; + } + + void + AddInput(RowVectorPtr& /* unused */) override { + throw NotImplementedException( + "SourceOperator does not support addInput()"); + } + + void + NoMoreInput() override { + throw NotImplementedException( + "SourceOperator does not support noMoreInput()"); + } +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h new file mode 100644 index 0000000000..068e69260f --- /dev/null +++ b/internal/core/src/expr/ITypeExpr.h @@ -0,0 +1,557 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "common/Schema.h" +#include "common/Types.h" +#include "pb/plan.pb.h" + +namespace milvus { +namespace expr { + +struct ColumnInfo { + FieldId field_id_; + DataType data_type_; + std::vector nested_path_; + + ColumnInfo(const proto::plan::ColumnInfo& column_info) + : field_id_(column_info.field_id()), + data_type_(static_cast(column_info.data_type())), + nested_path_(column_info.nested_path().begin(), + column_info.nested_path().end()) { + } + + ColumnInfo(FieldId field_id, + DataType data_type, + std::vector nested_path = {}) + : field_id_(field_id), + data_type_(data_type), + nested_path_(std::move(nested_path)) { + } + + bool + operator==(const ColumnInfo& other) { + if (field_id_ != other.field_id_) { + return false; + } + + if (data_type_ != other.data_type_) { + return false; + } + + for (int i = 0; i < nested_path_.size(); ++i) { + if (nested_path_[i] != other.nested_path_[i]) { + return false; + } + } + + return true; + } + + std::string + ToString() const { + return fmt::format("[FieldId:{}, data_type:{}, nested_path:{}]", + std::to_string(field_id_.get()), + data_type_, + milvus::Join(nested_path_, ",")); + } +}; + +/** + * @brief Base class for all exprs + * a strongly-typed expression, such as literal, function call, etc... + */ +class ITypeExpr { + public: + explicit ITypeExpr(DataType type) : type_(type), inputs_{} { + } + + ITypeExpr(DataType type, + std::vector> inputs) + : type_(type), inputs_{std::move(inputs)} { + } + + virtual ~ITypeExpr() = default; + + const std::vector>& + inputs() const { + return inputs_; + } + + DataType + type() const { + return type_; + } + + virtual std::string + ToString() const = 0; + + const std::vector>& + inputs() { + return inputs_; + } + + protected: + DataType type_; + std::vector> inputs_; +}; + +using TypedExprPtr = std::shared_ptr; + +class InputTypeExpr : public ITypeExpr { + public: + InputTypeExpr(DataType type) : ITypeExpr(type) { + } + + std::string + ToString() const override { + return "ROW"; + } +}; + +using InputTypeExprPtr = std::shared_ptr; + +class CallTypeExpr : public ITypeExpr { + public: + CallTypeExpr(DataType type, + const std::vector& inputs, + std::string fun_name) + : ITypeExpr{type, std::move(inputs)} { + } + + virtual ~CallTypeExpr() = default; + + virtual const std::string& + name() const { + return name_; + } + + std::string + ToString() const override { + std::string str{}; + str += name(); + str += "("; + for (size_t i = 0; i < inputs_.size(); ++i) { + if (i != 0) { + str += ","; + } + str += inputs_[i]->ToString(); + } + str += ")"; + return str; + } + + private: + std::string name_; +}; + +using CallTypeExprPtr = std::shared_ptr; + +class FieldAccessTypeExpr : public ITypeExpr { + public: + FieldAccessTypeExpr(DataType type, const std::string& name) + : ITypeExpr{type}, name_(name), is_input_column_(true) { + } + + FieldAccessTypeExpr(DataType type, + const TypedExprPtr& input, + const std::string& name) + : ITypeExpr{type, {std::move(input)}}, name_(name) { + is_input_column_ = + dynamic_cast(inputs_[0].get()) != nullptr; + } + + bool + is_input_column() const { + return is_input_column_; + } + + std::string + ToString() const override { + if (inputs_.empty()) { + return fmt::format("{}", name_); + } + + return fmt::format("{}[{}]", inputs_[0]->ToString(), name_); + } + + private: + std::string name_; + bool is_input_column_; +}; + +using FieldAccessTypeExprPtr = std::shared_ptr; + +/** + * @brief Base class for all milvus filter exprs, output type must be BOOL + * a strongly-typed expression, such as literal, function call, etc... + */ +class ITypeFilterExpr : public ITypeExpr { + public: + ITypeFilterExpr() : ITypeExpr(DataType::BOOL) { + } + + ITypeFilterExpr(std::vector> inputs) + : ITypeExpr(DataType::BOOL, std::move(inputs)) { + } + + virtual ~ITypeFilterExpr() = default; +}; + +class UnaryRangeFilterExpr : public ITypeFilterExpr { + public: + explicit UnaryRangeFilterExpr(const ColumnInfo& column, + proto::plan::OpType op_type, + const proto::plan::GenericValue& val) + : ITypeFilterExpr(), column_(column), op_type_(op_type), val_(val) { + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "UnaryRangeFilterExpr: {columnInfo:" << column_.ToString() + << " op_type:" << milvus::proto::plan::OpType_Name(op_type_) + << " val:" << val_.DebugString() << "}"; + return ss.str(); + } + + public: + const ColumnInfo column_; + const proto::plan::OpType op_type_; + const proto::plan::GenericValue val_; +}; + +class AlwaysTrueExpr : public ITypeFilterExpr { + public: + explicit AlwaysTrueExpr() { + } + + std::string + ToString() const override { + return "AlwaysTrue expr"; + } +}; + +class ExistsExpr : public ITypeFilterExpr { + public: + explicit ExistsExpr(const ColumnInfo& column) + : ITypeFilterExpr(), column_(column) { + } + + std::string + ToString() const override { + return "{Exists Expression - Column: " + column_.ToString() + "}"; + } + + const ColumnInfo column_; +}; + +class LogicalUnaryExpr : public ITypeFilterExpr { + public: + enum class OpType { Invalid = 0, LogicalNot = 1 }; + + explicit LogicalUnaryExpr(const OpType op_type, const TypedExprPtr& child) + : op_type_(op_type) { + inputs_.emplace_back(child); + } + + std::string + ToString() const override { + std::string opTypeString; + + switch (op_type_) { + case OpType::LogicalNot: + opTypeString = "Logical NOT"; + break; + default: + opTypeString = "Invalid Operator"; + break; + } + + return fmt::format("LogicalUnaryExpr:[{} - Child: {}]", + opTypeString, + inputs_[0]->ToString()); + } + + const OpType op_type_; +}; + +class TermFilterExpr : public ITypeFilterExpr { + public: + explicit TermFilterExpr(const ColumnInfo& column, + const std::vector& vals, + bool is_in_field = false) + : ITypeFilterExpr(), + column_(column), + vals_(vals), + is_in_field_(is_in_field) { + } + + std::string + ToString() const override { + std::string values; + + for (const auto& val : vals_) { + values += val.DebugString() + ", "; + } + + std::stringstream ss; + ss << "TermFilterExpr:[Column: " << column_.ToString() << ", Values: [" + << values << "]" + << ", Is In Field: " << (is_in_field_ ? "true" : "false") << "]"; + + return ss.str(); + } + + public: + const ColumnInfo column_; + const std::vector vals_; + const bool is_in_field_; +}; + +class LogicalBinaryExpr : public ITypeFilterExpr { + public: + enum class OpType { Invalid = 0, And = 1, Or = 2 }; + + explicit LogicalBinaryExpr(OpType op_type, + const TypedExprPtr& left, + const TypedExprPtr& right) + : ITypeFilterExpr(), op_type_(op_type) { + inputs_.emplace_back(left); + inputs_.emplace_back(right); + } + + std::string + GetOpTypeString() const { + switch (op_type_) { + case OpType::Invalid: + return "Invalid"; + case OpType::And: + return "And"; + case OpType::Or: + return "Or"; + default: + return "Unknown"; // Handle the default case if necessary + } + } + + std::string + ToString() const override { + return fmt::format("LogicalBinaryExpr:[{} - Left: {}, Right: {}]", + GetOpTypeString(), + inputs_[0]->ToString(), + inputs_[1]->ToString()); + } + + std::string + name() const { + return GetOpTypeString(); + } + + public: + const OpType op_type_; +}; + +class BinaryRangeFilterExpr : public ITypeFilterExpr { + public: + BinaryRangeFilterExpr(const ColumnInfo& column, + const proto::plan::GenericValue& lower_value, + const proto::plan::GenericValue& upper_value, + bool lower_inclusive, + bool upper_inclusive) + : ITypeFilterExpr(), + column_(column), + lower_val_(lower_value), + upper_val_(upper_value), + lower_inclusive_(lower_inclusive), + upper_inclusive_(upper_inclusive) { + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "BinaryRangeFilterExpr:[Column: " << column_.ToString() + << ", Lower Value: " << lower_val_.DebugString() + << ", Upper Value: " << upper_val_.DebugString() + << ", Lower Inclusive: " << (lower_inclusive_ ? "true" : "false") + << ", Upper Inclusive: " << (upper_inclusive_ ? "true" : "false") + << "]"; + + return ss.str(); + } + + const ColumnInfo column_; + const proto::plan::GenericValue lower_val_; + const proto::plan::GenericValue upper_val_; + const bool lower_inclusive_; + const bool upper_inclusive_; +}; + +class BinaryArithOpEvalRangeExpr : public ITypeFilterExpr { + public: + BinaryArithOpEvalRangeExpr(const ColumnInfo& column, + const proto::plan::OpType op_type, + const proto::plan::ArithOpType arith_op_type, + const proto::plan::GenericValue value, + const proto::plan::GenericValue right_operand) + : column_(column), + op_type_(op_type), + arith_op_type_(arith_op_type), + right_operand_(right_operand), + value_(value) { + } + + std::string + ToString() const override { + std::stringstream ss; + ss << "BinaryArithOpEvalRangeExpr:[Column: " << column_.ToString() + << ", Operator Type: " << milvus::proto::plan::OpType_Name(op_type_) + << ", Arith Operator Type: " + << milvus::proto::plan::ArithOpType_Name(arith_op_type_) + << ", Value: " << value_.DebugString() + << ", Right Operand: " << right_operand_.DebugString() << "]"; + + return ss.str(); + } + + public: + const ColumnInfo column_; + const proto::plan::OpType op_type_; + const proto::plan::ArithOpType arith_op_type_; + const proto::plan::GenericValue right_operand_; + const proto::plan::GenericValue value_; +}; + +class CompareExpr : public ITypeFilterExpr { + public: + CompareExpr(const FieldId& left_field, + const FieldId& right_field, + DataType left_data_type, + DataType right_data_type, + proto::plan::OpType op_type) + : left_field_id_(left_field), + right_field_id_(right_field), + left_data_type_(left_data_type), + right_data_type_(right_data_type), + op_type_(op_type) { + } + + std::string + ToString() const override { + std::string opTypeString; + + return fmt::format( + "CompareExpr:[Left Field ID: {}, Right Field ID: {}, Left Data " + "Type: {}, " + "Operator: {}, Right " + "Data Type: {}]", + left_field_id_.get(), + right_field_id_.get(), + milvus::proto::plan::OpType_Name(op_type_), + left_data_type_, + right_data_type_); + } + + public: + const FieldId left_field_id_; + const FieldId right_field_id_; + const DataType left_data_type_; + const DataType right_data_type_; + const proto::plan::OpType op_type_; +}; + +class JsonContainsExpr : public ITypeFilterExpr { + public: + JsonContainsExpr(ColumnInfo column, + ContainsType op, + const bool same_type, + const std::vector& vals) + : column_(column), + op_(op), + same_type_(same_type), + vals_(std::move(vals)) { + } + + std::string + ToString() const override { + std::string values; + for (const auto& val : vals_) { + values += val.DebugString() + ", "; + } + return fmt::format( + "JsonContainsExpr:[Column: {}, Operator: {}, Same Type: {}, " + "Values: [{}]]", + column_.ToString(), + JSONContainsExpr_JSONOp_Name(op_), + (same_type_ ? "true" : "false"), + values); + } + + public: + const ColumnInfo column_; + ContainsType op_; + bool same_type_; + const std::vector vals_; +}; +} // namespace expr +} // namespace milvus + +template <> +struct fmt::formatter + : formatter { + auto + format(milvus::proto::plan::ArithOpType c, format_context& ctx) const { + using namespace milvus::proto::plan; + string_view name = "unknown"; + switch (c) { + case ArithOpType::Unknown: + name = "Unknown"; + break; + case ArithOpType::Add: + name = "Add"; + break; + case ArithOpType::Sub: + name = "Sub"; + break; + case ArithOpType::Mul: + name = "Mul"; + break; + case ArithOpType::Div: + name = "Div"; + break; + case ArithOpType::Mod: + name = "Mod"; + break; + case ArithOpType::ArrayLength: + name = "ArrayLength"; + break; + case ArithOpType::ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_: + name = "ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_"; + break; + case ArithOpType::ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_: + name = "ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_"; + break; + } + return formatter::format(name, ctx); + } +}; diff --git a/internal/core/src/index/IndexInfo.h b/internal/core/src/index/IndexInfo.h index 44e9306bed..f925de1e4a 100644 --- a/internal/core/src/index/IndexInfo.h +++ b/internal/core/src/index/IndexInfo.h @@ -13,7 +13,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include "common/Types.h" diff --git a/internal/core/src/index/ScalarIndexSort.cpp b/internal/core/src/index/ScalarIndexSort.cpp index 766c1f1e77..a96642e591 100644 --- a/internal/core/src/index/ScalarIndexSort.cpp +++ b/internal/core/src/index/ScalarIndexSort.cpp @@ -68,7 +68,7 @@ ScalarIndexSort::BuildV2(const Config& config) { PanicInfo(S3Error, "failed to create scan iterator"); } auto reader = res.value(); - std::vector field_datas; + std::vector field_datas; for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) { if (!rec.ok()) { PanicInfo(DataFormatBroken, "failed to read data"); @@ -280,7 +280,7 @@ ScalarIndexSort::LoadV2(const Config& config) { index_files.push_back(b.name); } } - std::map index_datas{}; + std::map index_datas{}; for (auto& file_name : index_files) { auto res = space_->GetBlobByteSize(file_name); if (!res.ok()) { diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index 71a3f42999..984dd623a9 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -24,11 +24,12 @@ #include "common/Types.h" #include "common/EasyAssert.h" +#include "common/Exception.h" +#include "common/Utils.h" +#include "common/Slice.h" #include "index/StringIndexMarisa.h" #include "index/Utils.h" #include "index/Index.h" -#include "common/Utils.h" -#include "common/Slice.h" #include "storage/Util.h" #include "storage/space.h" @@ -73,7 +74,7 @@ StringIndexMarisa::BuildV2(const Config& config) { PanicInfo(S3Error, "failed to create scan iterator"); } auto reader = res.value(); - std::vector field_datas; + std::vector field_datas; for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) { if (!rec.ok()) { PanicInfo(DataFormatBroken, "failed to read data"); @@ -315,7 +316,7 @@ StringIndexMarisa::LoadV2(const Config& config) { index_files.push_back(b.name); } } - std::map index_datas{}; + std::map index_datas{}; for (auto& file_name : index_files) { auto res = space_->GetBlobByteSize(file_name); if (!res.ok()) { diff --git a/internal/core/src/index/Utils.cpp b/internal/core/src/index/Utils.cpp index 8193241a96..b1d889e171 100644 --- a/internal/core/src/index/Utils.cpp +++ b/internal/core/src/index/Utils.cpp @@ -24,17 +24,18 @@ #include #include #include +#include +#include +#include "common/EasyAssert.h" +#include "common/Exception.h" +#include "common/File.h" +#include "common/FieldData.h" +#include "common/Slice.h" #include "index/Utils.h" #include "index/Meta.h" -#include -#include -#include "common/EasyAssert.h" -#include "knowhere/comp/index_param.h" -#include "common/Slice.h" -#include "storage/FieldData.h" #include "storage/Util.h" -#include "common/File.h" +#include "knowhere/comp/index_param.h" namespace milvus::index { @@ -205,7 +206,7 @@ ParseConfigFromIndexParams( } void -AssembleIndexDatas(std::map& index_datas) { +AssembleIndexDatas(std::map& index_datas) { if (index_datas.find(INDEX_FILE_SLICE_META) != index_datas.end()) { auto slice_meta = index_datas.at(INDEX_FILE_SLICE_META); Config meta_data = Config::parse(std::string( @@ -237,9 +238,8 @@ AssembleIndexDatas(std::map& index_datas) { } void -AssembleIndexDatas( - std::map& index_datas, - std::unordered_map& result) { +AssembleIndexDatas(std::map& index_datas, + std::unordered_map& result) { if (auto meta_iter = index_datas.find(INDEX_FILE_SLICE_META); meta_iter != index_datas.end()) { auto raw_metadata_array = diff --git a/internal/core/src/index/Utils.h b/internal/core/src/index/Utils.h index adc0b34595..53670dcba2 100644 --- a/internal/core/src/index/Utils.h +++ b/internal/core/src/index/Utils.h @@ -28,9 +28,9 @@ #include #include "common/Types.h" +#include "common/FieldData.h" #include "index/IndexInfo.h" #include "storage/Types.h" -#include "storage/FieldData.h" namespace milvus::index { @@ -114,12 +114,11 @@ ParseConfigFromIndexParams( const std::map& index_params); void -AssembleIndexDatas(std::map& index_datas); +AssembleIndexDatas(std::map& index_datas); void -AssembleIndexDatas( - std::map& index_datas, - std::unordered_map& result); +AssembleIndexDatas(std::map& index_datas, + std::unordered_map& result); // On Linux, read() (and similar system calls) will transfer at most 0x7ffff000 (2,147,479,552) bytes once void diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 88dac19c6d..16b8ec3eb8 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -38,20 +38,20 @@ #include "knowhere/factory.h" #include "knowhere/comp/time_recorder.h" #include "common/BitsetView.h" -#include "common/Slice.h" #include "common/Consts.h" +#include "common/FieldData.h" +#include "common/File.h" +#include "common/Slice.h" +#include "common/Tracer.h" #include "common/RangeSearchHelper.h" #include "common/Utils.h" #include "log/Log.h" #include "mmap/Types.h" #include "storage/DataCodec.h" -#include "storage/FieldData.h" #include "storage/MemFileManagerImpl.h" #include "storage/ThreadPools.h" -#include "storage/Util.h" -#include "common/File.h" -#include "common/Tracer.h" #include "storage/space.h" +#include "storage/Util.h" namespace milvus::index { @@ -189,7 +189,7 @@ VectorMemIndex::LoadV2(const Config& config) { auto slice_meta_file = index_prefix + "/" + INDEX_FILE_SLICE_META; auto res = space_->GetBlobByteSize(std::string(slice_meta_file)); - std::map index_datas{}; + std::map index_datas{}; if (!res.ok() && !res.status().IsFileNotFound()) { PanicInfo(DataFormatBroken, "failed to read blob"); @@ -289,7 +289,7 @@ VectorMemIndex::Load(const Config& config) { auto parallel_degree = static_cast(DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE); - std::map index_datas{}; + std::map index_datas{}; // try to read slice meta first std::string slice_meta_filepath; @@ -414,7 +414,7 @@ VectorMemIndex::BuildV2(const Config& config) { } auto reader = res.value(); - std::vector field_datas; + std::vector field_datas; for (auto rec : *reader) { if (!rec.ok()) { PanicInfo(IndexBuildError, diff --git a/internal/core/src/log/Log.h b/internal/core/src/log/Log.h index 171c542264..306b998b05 100644 --- a/internal/core/src/log/Log.h +++ b/internal/core/src/log/Log.h @@ -53,8 +53,13 @@ __FUNCTION__, \ GetThreadName().c_str()) -#define LOG_SEGCORE_TRACE_ DLOG(INFO) << SEGCORE_MODULE_FUNCTION -#define LOG_SEGCORE_DEBUG_ DLOG(INFO) << SEGCORE_MODULE_FUNCTION +// GLOG has no debug and trace level, +// Using VLOG to implement it. +#define GLOG_DEBUG 5 +#define GLOG_TRACE 6 + +#define LOG_SEGCORE_TRACE_ VLOG(GLOG_TRACE) << SEGCORE_MODULE_FUNCTION +#define LOG_SEGCORE_DEBUG_ VLOG(GLOG_DEBUG) << SEGCORE_MODULE_FUNCTION #define LOG_SEGCORE_INFO_ LOG(INFO) << SEGCORE_MODULE_FUNCTION #define LOG_SEGCORE_WARNING_ LOG(WARNING) << SEGCORE_MODULE_FUNCTION #define LOG_SEGCORE_ERROR_ LOG(ERROR) << SEGCORE_MODULE_FUNCTION diff --git a/internal/core/src/mmap/Column.h b/internal/core/src/mmap/Column.h index 451d86367e..4a5e61c912 100644 --- a/internal/core/src/mmap/Column.h +++ b/internal/core/src/mmap/Column.h @@ -21,15 +21,15 @@ #include #include -#include "common/FieldMeta.h" -#include "common/Span.h" +#include "common/Array.h" #include "common/EasyAssert.h" #include "common/File.h" +#include "common/FieldMeta.h" +#include "common/FieldData.h" +#include "common/Span.h" #include "fmt/format.h" #include "log/Log.h" #include "mmap/Utils.h" -#include "storage/FieldData.h" -#include "common/Array.h" namespace milvus { @@ -156,7 +156,7 @@ class ColumnBase { Span() const = 0; void - AppendBatch(const storage::FieldDataPtr& data) { + AppendBatch(const FieldDataPtr& data) { size_t required_size = size_ + data->Size(); if (required_size > cap_size_) { Expand(required_size * 2 + padding_); diff --git a/internal/core/src/mmap/Types.h b/internal/core/src/mmap/Types.h index fc79b95dd3..c2f8c1a9e4 100644 --- a/internal/core/src/mmap/Types.h +++ b/internal/core/src/mmap/Types.h @@ -19,13 +19,13 @@ #include #include #include -#include "storage/FieldData.h" +#include "common/FieldData.h" namespace milvus { struct FieldDataInfo { FieldDataInfo() { - channel = std::make_shared(); + channel = std::make_shared(); } FieldDataInfo(int64_t field_id, @@ -34,12 +34,12 @@ struct FieldDataInfo { : field_id(field_id), row_count(row_count), mmap_dir_path(std::move(mmap_dir_path)) { - channel = std::make_shared(); + channel = std::make_shared(); } FieldDataInfo(int64_t field_id, size_t row_count, - storage::FieldDataChannelPtr channel) + FieldDataChannelPtr channel) : field_id(field_id), row_count(row_count), channel(std::move(channel)) { @@ -48,7 +48,7 @@ struct FieldDataInfo { FieldDataInfo(int64_t field_id, size_t row_count, std::string mmap_dir_path, - storage::FieldDataChannelPtr channel) + FieldDataChannelPtr channel) : field_id(field_id), row_count(row_count), mmap_dir_path(std::move(mmap_dir_path)), @@ -57,9 +57,9 @@ struct FieldDataInfo { FieldDataInfo(int64_t field_id, size_t row_count, - const std::vector& batch) + const std::vector& batch) : field_id(field_id), row_count(row_count) { - channel = std::make_shared(); + channel = std::make_shared(); for (auto& data : batch) { channel->push(data); } @@ -69,11 +69,11 @@ struct FieldDataInfo { FieldDataInfo(int64_t field_id, size_t row_count, std::string mmap_dir_path, - const std::vector& batch) + const std::vector& batch) : field_id(field_id), row_count(row_count), mmap_dir_path(std::move(mmap_dir_path)) { - channel = std::make_shared(); + channel = std::make_shared(); for (auto& data : batch) { channel->push(data); } @@ -83,6 +83,6 @@ struct FieldDataInfo { int64_t field_id; size_t row_count; std::string mmap_dir_path; - storage::FieldDataChannelPtr channel; + FieldDataChannelPtr channel; }; } // namespace milvus diff --git a/internal/core/src/mmap/Utils.h b/internal/core/src/mmap/Utils.h index e3b718e766..d581823946 100644 --- a/internal/core/src/mmap/Utils.h +++ b/internal/core/src/mmap/Utils.h @@ -32,7 +32,7 @@ namespace milvus { inline size_t -GetDataSize(const std::vector& datas) { +GetDataSize(const std::vector& datas) { size_t total_size{0}; for (auto data : datas) { total_size += data->Size(); @@ -42,7 +42,7 @@ GetDataSize(const std::vector& datas) { } inline void* -FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) { +FillField(DataType data_type, const FieldDataPtr data, void* dst) { char* dest = reinterpret_cast(dst); if (datatype_is_variable(data_type)) { switch (data_type) { @@ -80,7 +80,7 @@ FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) { inline size_t WriteFieldData(File& file, DataType data_type, - const storage::FieldDataPtr& data, + const FieldDataPtr& data, std::vector>& element_indices) { size_t total_written{0}; if (datatype_is_variable(data_type)) { diff --git a/internal/core/src/plan/PlanNode.h b/internal/core/src/plan/PlanNode.h new file mode 100644 index 0000000000..f149b834bd --- /dev/null +++ b/internal/core/src/plan/PlanNode.h @@ -0,0 +1,287 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "common/Types.h" +#include "common/Vector.h" +#include "expr/ITypeExpr.h" +#include "common/EasyAssert.h" +#include "segcore/SegmentInterface.h" + +namespace milvus { +namespace plan { + +typedef std::string PlanNodeId; +/** + * @brief Base class for all logic plan node + * + */ +class PlanNode { + public: + explicit PlanNode(const PlanNodeId& id) : id_(id) { + } + + virtual ~PlanNode() = default; + + const PlanNodeId& + id() const { + return id_; + } + + virtual DataType + output_type() const = 0; + + virtual std::vector> + sources() const = 0; + + virtual bool + RequireSplits() const { + return false; + } + + virtual std::string + ToString() const = 0; + + virtual std::string_view + name() const = 0; + + private: + PlanNodeId id_; +}; + +using PlanNodePtr = std::shared_ptr; + +class SegmentNode : public PlanNode { + public: + SegmentNode( + const PlanNodeId& id, + const std::shared_ptr& + segment) + : PlanNode(id), segment_(segment) { + } + + DataType + output_type() const override { + return DataType::ROW; + } + + std::vector> + sources() const override { + return {}; + } + + std::string_view + name() const override { + return "SegmentNode"; + } + + std::string + ToString() const override { + return "SegmentNode"; + } + + private: + std::shared_ptr segment_; +}; + +class ValuesNode : public PlanNode { + public: + ValuesNode(const PlanNodeId& id, + const std::vector& values, + bool parallelizeable = false) + : PlanNode(id), + values_{std::move(values)}, + output_type_(values[0]->type()) { + AssertInfo(!values.empty(), "ValueNode must has value"); + } + + ValuesNode(const PlanNodeId& id, + std::vector&& values, + bool parallelizeable = false) + : PlanNode(id), + values_{std::move(values)}, + output_type_(values[0]->type()) { + AssertInfo(!values.empty(), "ValueNode must has value"); + } + + DataType + output_type() const override { + return output_type_; + } + + const std::vector& + values() const { + return values_; + } + + std::vector + sources() const override { + return {}; + } + + bool + parallelizable() { + return parallelizable_; + } + + std::string_view + name() const override { + return "Values"; + } + + std::string + ToString() const override { + return "Values"; + } + + private: + DataType output_type_; + const std::vector values_; + bool parallelizable_; +}; + +class FilterNode : public PlanNode { + public: + FilterNode(const PlanNodeId& id, + expr::TypedExprPtr filter, + std::vector sources) + : PlanNode(id), + sources_{std::move(sources)}, + filter_(std::move(filter)) { + AssertInfo( + filter_->type() == DataType::BOOL, + fmt::format("Filter expression must be of type BOOLEAN, Got {}", + filter_->type())); + } + + DataType + output_type() const override { + return sources_[0]->output_type(); + } + + std::vector + sources() const override { + return sources_; + } + + const expr::TypedExprPtr& + filter() const { + return filter_; + } + + std::string_view + name() const override { + return "Filter"; + } + + std::string + ToString() const override { + return ""; + } + + private: + const std::vector sources_; + const expr::TypedExprPtr filter_; +}; + +class FilterBitsNode : public PlanNode { + public: + FilterBitsNode( + const PlanNodeId& id, + expr::TypedExprPtr filter, + std::vector sources = std::vector{}) + : PlanNode(id), + sources_{std::move(sources)}, + filter_(std::move(filter)) { + AssertInfo( + filter_->type() == DataType::BOOL, + fmt::format("Filter expression must be of type BOOLEAN, Got {}", + filter_->type())); + } + + DataType + output_type() const override { + return DataType::BOOL; + } + + std::vector + sources() const override { + return sources_; + } + + const expr::TypedExprPtr& + filter() const { + return filter_; + } + + std::string_view + name() const override { + return "FilterBits"; + } + + std::string + ToString() const override { + return fmt::format("FilterBitsNode:[filter_expr:{}]", + filter_->ToString()); + } + + private: + const std::vector sources_; + const expr::TypedExprPtr filter_; +}; + +enum class ExecutionStrategy { + // Process splits as they come in any available driver. + kUngrouped, + // Process splits from each split group only in one driver. + // It is used when split groups represent separate partitions of the data on + // the grouping keys or join keys. In that case it is sufficient to keep only + // the keys from a single split group in a hash table used by group-by or + // join. + kGrouped, +}; +struct PlanFragment { + std::shared_ptr plan_node_; + ExecutionStrategy execution_strategy_{ExecutionStrategy::kUngrouped}; + int32_t num_splitgroups_{0}; + + PlanFragment() = default; + + inline bool + IsGroupedExecution() const { + return execution_strategy_ == ExecutionStrategy::kGrouped; + } + + explicit PlanFragment(std::shared_ptr top_node, + ExecutionStrategy strategy, + int32_t num_splitgroups) + : plan_node_(std::move(top_node)), + execution_strategy_(strategy), + num_splitgroups_(num_splitgroups) { + } + + explicit PlanFragment(std::shared_ptr top_node) + : plan_node_(std::move(top_node)) { + } +}; + +} // namespace plan +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index 18f7af49e5..3b32868de1 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -20,10 +20,12 @@ #include "common/QueryInfo.h" #include "query/Expr.h" +namespace milvus::plan { +class PlanNode; +}; namespace milvus::query { class PlanNodeVisitor; - // Base of all Nodes struct PlanNode { public: @@ -36,6 +38,7 @@ using PlanNodePtr = std::unique_ptr; struct VectorPlanNode : PlanNode { std::optional predicate_; + std::optional> filter_plannode_; SearchInfo search_info_; std::string placeholder_tag_; }; @@ -64,6 +67,7 @@ struct RetrievePlanNode : PlanNode { accept(PlanNodeVisitor&) override; std::optional predicate_; + std::optional> filter_plannode_; bool is_count_; int64_t limit_; }; diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 021fece0ce..811819d46d 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -185,6 +185,12 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { } }(); + auto expr_parser = [&]() -> plan::PlanNodePtr { + auto expr = ParseExprs(anns_proto.predicates()); + return std::make_shared(DEFAULT_PLANNODE_ID, + expr); + }; + auto& query_info_proto = anns_proto.query_info(); SearchInfo search_info; @@ -210,6 +216,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { }(); plan_node->placeholder_tag_ = anns_proto.placeholder_tag(); plan_node->predicate_ = std::move(expr_opt); + if (anns_proto.has_predicates()) { + plan_node->filter_plannode_ = std::move(expr_parser()); + } plan_node->search_info_ = std::move(search_info); return plan_node; } @@ -227,7 +236,13 @@ ProtoParser::RetrievePlanNodeFromProto( auto expr_opt = [&]() -> ExprPtr { return ParseExpr(predicate_proto); }(); + auto expr_parser = [&]() -> plan::PlanNodePtr { + auto expr = ParseExprs(predicate_proto); + return std::make_shared( + DEFAULT_PLANNODE_ID, expr); + }(); node->predicate_ = std::move(expr_opt); + node->filter_plannode_ = std::move(expr_parser); } else { auto& query = plan_node_proto.query(); if (query.has_predicates()) { @@ -235,7 +250,13 @@ ProtoParser::RetrievePlanNodeFromProto( auto expr_opt = [&]() -> ExprPtr { return ParseExpr(predicate_proto); }(); + auto expr_parser = [&]() -> plan::PlanNodePtr { + auto expr = ParseExprs(predicate_proto); + return std::make_shared( + DEFAULT_PLANNODE_ID, expr); + }(); node->predicate_ = std::move(expr_opt); + node->filter_plannode_ = std::move(expr_parser); } node->is_count_ = query.is_count(); node->limit_ = query.limit(); @@ -284,6 +305,16 @@ ProtoParser::CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto) { return retrieve_plan; } +expr::TypedExprPtr +ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) { + auto& column_info = expr_pb.column_info(); + auto field_id = FieldId(column_info.field_id()); + auto data_type = schema[field_id].get_data_type(); + Assert(data_type == static_cast(column_info.data_type())); + return std::make_shared( + expr::ColumnInfo(column_info), expr_pb.op(), expr_pb.value()); +} + ExprPtr ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) { auto& column_info = expr_pb.column_info(); @@ -352,6 +383,21 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) { return result; } +expr::TypedExprPtr +ProtoParser::ParseBinaryRangeExprs( + const proto::plan::BinaryRangeExpr& expr_pb) { + auto& columnInfo = expr_pb.column_info(); + auto field_id = FieldId(columnInfo.field_id()); + auto data_type = schema[field_id].get_data_type(); + Assert(data_type == (DataType)columnInfo.data_type()); + return std::make_shared( + columnInfo, + expr_pb.lower_value(), + expr_pb.upper_value(), + expr_pb.lower_inclusive(), + expr_pb.upper_inclusive()); +} + ExprPtr ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) { auto& columnInfo = expr_pb.column_info(); @@ -436,6 +482,27 @@ ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) { return result; } +expr::TypedExprPtr +ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) { + auto& left_column_info = expr_pb.left_column_info(); + auto left_field_id = FieldId(left_column_info.field_id()); + auto left_data_type = schema[left_field_id].get_data_type(); + Assert(left_data_type == + static_cast(left_column_info.data_type())); + + auto& right_column_info = expr_pb.right_column_info(); + auto right_field_id = FieldId(right_column_info.field_id()); + auto right_data_type = schema[right_field_id].get_data_type(); + Assert(right_data_type == + static_cast(right_column_info.data_type())); + + return std::make_shared(left_field_id, + right_field_id, + left_data_type, + right_data_type, + expr_pb.op()); +} + ExprPtr ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) { auto& left_column_info = expr_pb.left_column_info(); @@ -461,6 +528,20 @@ ProtoParser::ParseCompareExpr(const proto::plan::CompareExpr& expr_pb) { }(); } +expr::TypedExprPtr +ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) { + auto& columnInfo = expr_pb.column_info(); + auto field_id = FieldId(columnInfo.field_id()); + auto data_type = schema[field_id].get_data_type(); + Assert(data_type == (DataType)columnInfo.data_type()); + std::vector<::milvus::proto::plan::GenericValue> values; + for (size_t i = 0; i < expr_pb.values_size(); i++) { + values.emplace_back(expr_pb.values(i)); + } + return std::make_shared( + columnInfo, values, expr_pb.is_in_field()); +} + ExprPtr ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) { auto& columnInfo = expr_pb.column_info(); @@ -568,6 +649,14 @@ ProtoParser::ParseUnaryExpr(const proto::plan::UnaryExpr& expr_pb) { return std::make_unique(op, expr); } +expr::TypedExprPtr +ProtoParser::ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb) { + auto op = static_cast(expr_pb.op()); + Assert(op == expr::LogicalUnaryExpr::OpType::LogicalNot); + auto child_expr = this->ParseExprs(expr_pb.child()); + return std::make_shared(op, child_expr); +} + ExprPtr ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) { auto op = static_cast(expr_pb.op()); @@ -576,6 +665,14 @@ ProtoParser::ParseBinaryExpr(const proto::plan::BinaryExpr& expr_pb) { return std::make_unique(op, left_expr, right_expr); } +expr::TypedExprPtr +ProtoParser::ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb) { + auto op = static_cast(expr_pb.op()); + auto left_expr = this->ParseExprs(expr_pb.left()); + auto right_expr = this->ParseExprs(expr_pb.right()); + return std::make_shared(op, left_expr, right_expr); +} + ExprPtr ProtoParser::ParseBinaryArithOpEvalRangeExpr( const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) { @@ -642,11 +739,35 @@ ProtoParser::ParseBinaryArithOpEvalRangeExpr( return result; } +expr::TypedExprPtr +ProtoParser::ParseBinaryArithOpEvalRangeExprs( + const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) { + auto& column_info = expr_pb.column_info(); + auto field_id = FieldId(column_info.field_id()); + auto data_type = schema[field_id].get_data_type(); + Assert(data_type == static_cast(column_info.data_type())); + return std::make_shared( + column_info, + expr_pb.op(), + expr_pb.arith_op(), + expr_pb.value(), + expr_pb.right_operand()); +} + std::unique_ptr ExtractExistsExprImpl(const proto::plan::ExistsExpr& expr_proto) { return std::make_unique(expr_proto.info()); } +expr::TypedExprPtr +ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) { + auto& column_info = expr_pb.info(); + auto field_id = FieldId(column_info.field_id()); + auto data_type = schema[field_id].get_data_type(); + Assert(data_type == static_cast(column_info.data_type())); + return std::make_shared(column_info); +} + ExprPtr ProtoParser::ParseExistExpr(const proto::plan::ExistsExpr& expr_pb) { auto& column_info = expr_pb.info(); @@ -718,6 +839,24 @@ ExtractJsonContainsExprImpl(const proto::plan::JSONContainsExpr& expr_proto) { val_case); } +expr::TypedExprPtr +ProtoParser::ParseJsonContainsExprs( + const proto::plan::JSONContainsExpr& expr_pb) { + auto& columnInfo = expr_pb.column_info(); + auto field_id = FieldId(columnInfo.field_id()); + auto data_type = schema[field_id].get_data_type(); + Assert(data_type == (DataType)columnInfo.data_type()); + std::vector<::milvus::proto::plan::GenericValue> values; + for (size_t i = 0; i < expr_pb.elements_size(); i++) { + values.emplace_back(expr_pb.elements(i)); + } + return std::make_shared( + columnInfo, + expr_pb.op(), + expr_pb.elements_same_type(), + std::move(values)); +} + ExprPtr ProtoParser::ParseJsonContainsExpr( const proto::plan::JSONContainsExpr& expr_pb) { @@ -755,6 +894,55 @@ ProtoParser::ParseJsonContainsExpr( return result; } +expr::TypedExprPtr +ProtoParser::CreateAlwaysTrueExprs() { + return std::make_shared(); +} + +expr::TypedExprPtr +ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb) { + using ppe = proto::plan::Expr; + switch (expr_pb.expr_case()) { + case ppe::kUnaryRangeExpr: { + return ParseUnaryRangeExprs(expr_pb.unary_range_expr()); + } + case ppe::kBinaryExpr: { + return ParseBinaryExprs(expr_pb.binary_expr()); + } + case ppe::kUnaryExpr: { + return ParseUnaryExprs(expr_pb.unary_expr()); + } + case ppe::kTermExpr: { + return ParseTermExprs(expr_pb.term_expr()); + } + case ppe::kBinaryRangeExpr: { + return ParseBinaryRangeExprs(expr_pb.binary_range_expr()); + } + case ppe::kCompareExpr: { + return ParseCompareExprs(expr_pb.compare_expr()); + } + case ppe::kBinaryArithOpEvalRangeExpr: { + return ParseBinaryArithOpEvalRangeExprs( + expr_pb.binary_arith_op_eval_range_expr()); + } + case ppe::kExistsExpr: { + return ParseExistExprs(expr_pb.exists_expr()); + } + case ppe::kAlwaysTrueExpr: { + return CreateAlwaysTrueExprs(); + } + case ppe::kJsonContainsExpr: { + return ParseJsonContainsExprs(expr_pb.json_contains_expr()); + } + default: { + std::string s; + google::protobuf::TextFormat::PrintToString(expr_pb, &s); + PanicInfo(ExprInvalid, + std::string("unsupported expr proto node: ") + s); + } + } +} + ExprPtr ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) { using ppe = proto::plan::Expr; diff --git a/internal/core/src/query/PlanProto.h b/internal/core/src/query/PlanProto.h index 806ff62d60..51843d9c57 100644 --- a/internal/core/src/query/PlanProto.h +++ b/internal/core/src/query/PlanProto.h @@ -18,6 +18,7 @@ #include "PlanNode.h" #include "common/Schema.h" #include "pb/plan.pb.h" +#include "plan/PlanNode.h" namespace milvus::query { @@ -72,6 +73,40 @@ class ProtoParser { std::unique_ptr CreateRetrievePlan(const proto::plan::PlanNode& plan_node_proto); + expr::TypedExprPtr + ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb); + + expr::TypedExprPtr + ParseExprs(const proto::plan::Expr& expr_pb); + + expr::TypedExprPtr + ParseBinaryArithOpEvalRangeExprs( + const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb); + + expr::TypedExprPtr + ParseBinaryRangeExprs(const proto::plan::BinaryRangeExpr& expr_pb); + + expr::TypedExprPtr + ParseCompareExprs(const proto::plan::CompareExpr& expr_pb); + + expr::TypedExprPtr + ParseTermExprs(const proto::plan::TermExpr& expr_pb); + + expr::TypedExprPtr + ParseUnaryExprs(const proto::plan::UnaryExpr& expr_pb); + + expr::TypedExprPtr + ParseBinaryExprs(const proto::plan::BinaryExpr& expr_pb); + + expr::TypedExprPtr + ParseExistExprs(const proto::plan::ExistsExpr& expr_pb); + + expr::TypedExprPtr + ParseJsonContainsExprs(const proto::plan::JSONContainsExpr& expr_pb); + + expr::TypedExprPtr + CreateAlwaysTrueExprs(); + private: const Schema& schema; }; diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 8e7ba5170c..49af5a1d18 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -16,6 +16,7 @@ #include "query/Expr.h" #include "common/Utils.h" +#include "simd/hook.h" namespace milvus::query { @@ -70,4 +71,61 @@ inline bool out_of_range(int64_t t) { return gt_ub(t) || lt_lb(t); } + +inline void +AppendOneChunk(BitsetType& result, const bool* chunk_ptr, size_t chunk_len) { + // Append a value once instead of BITSET_BLOCK_BIT_SIZE times. + auto AppendBlock = [&result](const bool* ptr, int n) { + for (int i = 0; i < n; ++i) { +#if defined(USE_DYNAMIC_SIMD) + auto val = milvus::simd::get_bitset_block(ptr); +#else + BitsetBlockType val = 0; + // This can use CPU SIMD optimzation + uint8_t vals[BITSET_BLOCK_SIZE] = {0}; + for (size_t j = 0; j < 8; ++j) { + for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) { + vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j; + } + } + for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) { + val |= BitsetBlockType(vals[j]) << (8 * j); + } +#endif + result.append(val); + ptr += BITSET_BLOCK_SIZE * 8; + } + }; + // Append bit for these bits that can not be union as a block + // Usually n less than BITSET_BLOCK_BIT_SIZE. + auto AppendBit = [&result](const bool* ptr, int n) { + for (int i = 0; i < n; ++i) { + bool bit = *ptr++; + result.push_back(bit); + } + }; + + size_t res_len = result.size(); + + int n_prefix = + res_len % BITSET_BLOCK_BIT_SIZE == 0 + ? 0 + : std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE, + chunk_len); + + AppendBit(chunk_ptr, n_prefix); + + if (n_prefix == chunk_len) + return; + + size_t n_block = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE; + size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE; + + AppendBlock(chunk_ptr + n_prefix, n_block); + + AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE, n_suffix); + + return; +} + } // namespace milvus::query diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h index cd1aa91ce1..5fcffba88f 100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h @@ -19,6 +19,7 @@ #include "PlanNodeVisitor.h" namespace milvus::query { + class ExecPlanNodeVisitor : public PlanNodeVisitor { public: void @@ -96,6 +97,24 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { return expr_use_pk_index_; } + void + ExecuteExprNodeInternal( + const std::shared_ptr& plannode, + const milvus::segcore::SegmentInternalInterface* segment, + BitsetType& result, + bool& cache_offset_getted, + std::vector& cache_offset); + + void + ExecuteExprNode(const std::shared_ptr& plannode, + const milvus::segcore::SegmentInternalInterface* segment, + BitsetType& result) { + bool get_cache_offset; + std::vector cache_offsets; + ExecuteExprNodeInternal( + plannode, segment, result, get_cache_offset, cache_offsets); + } + private: template void diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index 2b1018477f..2df2039c93 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -16,9 +16,12 @@ #include "query/PlanImpl.h" #include "query/SubSearchResult.h" #include "query/generated/ExecExprVisitor.h" +#include "query/Utils.h" #include "segcore/SegmentGrowing.h" #include "common/Json.h" #include "log/Log.h" +#include "plan/PlanNode.h" +#include "exec/Task.h" namespace milvus::query { @@ -73,6 +76,63 @@ empty_search_result(int64_t num_queries, SearchInfo& search_info) { return final_result; } +void +ExecPlanNodeVisitor::ExecuteExprNodeInternal( + const std::shared_ptr& plannode, + const milvus::segcore::SegmentInternalInterface* segment, + BitsetType& bitset_holder, + bool& cache_offset_getted, + std::vector& cache_offset) { + bitset_holder.clear(); + LOG_SEGCORE_INFO_ << "plannode:" << plannode->ToString(); + auto plan = plan::PlanFragment(plannode); + // TODO: get query id from proxy + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, segment, timestamp_); + + auto task = + milvus::exec::Task::Create(DEFAULT_TASK_ID, plan, 0, query_context); + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + auto childrens = result->childrens(); + AssertInfo(childrens.size() == 1, + "expr result vector's children size not equal one"); + LOG_SEGCORE_DEBUG_ << "output result length:" << childrens[0]->size() + << std::endl; + if (auto vec = std::dynamic_pointer_cast(childrens[0])) { + AppendOneChunk(bitset_holder, + static_cast(vec->GetRawData()), + vec->size()); + } else if (auto row = + std::dynamic_pointer_cast(childrens[0])) { + auto bit_vec = + std::dynamic_pointer_cast(row->child(0)); + AppendOneChunk(bitset_holder, + static_cast(bit_vec->GetRawData()), + bit_vec->size()); + if (!cache_offset_getted) { + // offset cache only get once because not support iterator batch + auto cache_offset_vec = + std::dynamic_pointer_cast(row->child(1)); + auto cache_offset_vec_ptr = + (int64_t*)(cache_offset_vec->GetRawData()); + for (size_t i = 0; i < cache_offset_vec->size(); ++i) { + cache_offset.push_back(cache_offset_vec_ptr[i]); + } + cache_offset_getted = true; + } + } else { + PanicInfo(UnexpectedError, "expr return type not matched"); + } + } + // std::string s; + // boost::to_string(*bitset_holder, s); + // std::cout << bitset_holder->size() << " . " << s << std::endl; +} + template void ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { @@ -98,10 +158,10 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { } std::unique_ptr bitset_holder; - if (node.predicate_.has_value()) { - bitset_holder = std::make_unique( - ExecExprVisitor(*segment, this, active_count, timestamp_) - .call_child(*node.predicate_.value())); + if (node.filter_plannode_.has_value()) { + BitsetType expr_res; + ExecuteExprNode(node.filter_plannode_.value(), segment, expr_res); + bitset_holder = std::make_unique(expr_res); bitset_holder->flip(); } else { bitset_holder = std::make_unique(active_count, false); @@ -165,10 +225,16 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { bitset_holder.resize(active_count); } - if (node.predicate_.has_value() && node.predicate_.value() != nullptr) { - bitset_holder = - ExecExprVisitor(*segment, this, active_count, timestamp_) - .call_child(*(node.predicate_.value())); + // This flag used to indicate whether to get offset from expr module that + // speeds up mvcc filter in the next interface: "timestamp_filter" + bool get_cache_offset = false; + std::vector cache_offsets; + if (node.filter_plannode_.has_value()) { + ExecuteExprNodeInternal(node.filter_plannode_.value(), + segment, + bitset_holder, + get_cache_offset, + cache_offsets); bitset_holder.flip(); } @@ -189,9 +255,8 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { } bool false_filtered_out = false; - if (GetExprUsePkIndex() && IsTermExpr(node.predicate_.value().get())) { - segment->timestamp_filter( - bitset_holder, expr_cached_pk_id_offsets_, timestamp_); + if (get_cache_offset) { + segment->timestamp_filter(bitset_holder, cache_offsets, timestamp_); } else { bitset_holder.flip(); false_filtered_out = true; diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index 972e5d2711..d51e47718e 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -42,6 +42,6 @@ set(SEGCORE_FILES SkipIndex.cpp) add_library(milvus_segcore SHARED ${SEGCORE_FILES}) -target_link_libraries(milvus_segcore milvus_query ${OpenMP_CXX_FLAGS} milvus-storage) +target_link_libraries(milvus_segcore milvus_query milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage) install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index 22dc50e08b..d2134fef55 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -25,13 +25,13 @@ #include #include +#include "common/EasyAssert.h" #include "common/FieldMeta.h" +#include "common/FieldData.h" #include "common/Json.h" #include "common/Span.h" #include "common/Types.h" #include "common/Utils.h" -#include "common/EasyAssert.h" -#include "storage/FieldData.h" namespace milvus::segcore { @@ -103,7 +103,7 @@ class VectorBase { virtual void set_data_raw(ssize_t element_offset, - const std::vector& data) = 0; + const std::vector& data) = 0; void set_data_raw(ssize_t element_offset, @@ -112,7 +112,7 @@ class VectorBase { const FieldMeta& field_meta); virtual void - fill_chunk_data(const std::vector& data) = 0; + fill_chunk_data(const std::vector& data) = 0; virtual SpanBase get_span_base(int64_t chunk_id) const = 0; @@ -197,7 +197,7 @@ class ConcurrentVectorImpl : public VectorBase { } void - fill_chunk_data(const std::vector& datas) + fill_chunk_data(const std::vector& datas) override { // used only for sealed segment AssertInfo(chunks_.size() == 0, "no empty concurrent vector"); @@ -217,7 +217,7 @@ class ConcurrentVectorImpl : public VectorBase { void set_data_raw(ssize_t element_offset, - const std::vector& datas) override { + const std::vector& datas) override { for (auto& field_data : datas) { auto num_rows = field_data->get_num_rows(); set_data_raw(element_offset, field_data->Data(), num_rows); diff --git a/internal/core/src/segcore/FieldIndexing.h b/internal/core/src/segcore/FieldIndexing.h index 09613b6040..e5c4279445 100644 --- a/internal/core/src/segcore/FieldIndexing.h +++ b/internal/core/src/segcore/FieldIndexing.h @@ -306,7 +306,7 @@ class IndexingRecord { AppendingIndex(int64_t reserved_offset, int64_t size, FieldId fieldId, - const storage::FieldDataPtr data, + const FieldDataPtr data, const InsertRecord& record) { if (is_in(fieldId)) { auto& indexing = field_indexings_.at(fieldId); diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index b03a09e53e..0178297bfb 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -424,7 +424,7 @@ struct InsertRecord { } void - insert_pks(const std::vector& field_datas) { + insert_pks(const std::vector& field_datas) { std::lock_guard lck(shared_mutex_); int64_t offset = 0; for (auto& data : field_datas) { diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 0be9452704..d201b47238 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -21,6 +21,7 @@ #include "common/Consts.h" #include "common/EasyAssert.h" +#include "common/FieldData.h" #include "common/Types.h" #include "fmt/format.h" #include "log/Log.h" @@ -29,7 +30,6 @@ #include "query/SearchOnSealed.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/Utils.h" -#include "storage/FieldData.h" #include "storage/RemoteChunkManagerSingleton.h" #include "storage/Util.h" #include "storage/ThreadPools.h" @@ -58,8 +58,12 @@ SegmentGrowingImpl::mask_with_delete(BitsetType& bitset, return; } auto& delete_bitset = *bitmap_holder->bitmap_ptr; - AssertInfo(delete_bitset.size() == bitset.size(), - "Deleted bitmap size not equal to filtered bitmap size"); + AssertInfo( + delete_bitset.size() == bitset.size(), + fmt::format( + "Deleted bitmap size:{} not equal to filtered bitmap size:{}", + delete_bitset.size(), + bitset.size())); bitset |= delete_bitset; } @@ -177,12 +181,12 @@ SegmentGrowingImpl::LoadFieldData(const LoadFieldDataInfo& infos) { for (auto& [id, info] : infos.field_infos) { auto field_id = FieldId(id); auto insert_files = info.insert_files; - auto channel = std::make_shared(); + auto channel = std::make_shared(); auto& pool = ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::MIDDLE); auto load_future = pool.Submit(LoadFieldDatasFromRemote, insert_files, channel); - auto field_data = CollectFieldDataChannel(channel); + auto field_data = storage::CollectFieldDataChannel(channel); if (field_id == TimestampFieldID) { // step 2: sort timestamp // query node already guarantees that the timestamp is ordered, avoid field data copy in c++ @@ -263,7 +267,8 @@ SegmentGrowingImpl::LoadFieldDataV2(const LoadFieldDataInfo& infos) { std::shared_ptr space = std::move(res.value()); auto load_future = pool.Submit( LoadFieldDatasFromRemote2, space, schema_, field_data_info); - auto field_data = CollectFieldDataChannel(field_data_info.channel); + auto field_data = + milvus::storage::CollectFieldDataChannel(field_data_info.channel); if (field_id == TimestampFieldID) { // step 2: sort timestamp // query node already guarantees that the timestamp is ordered, avoid field data copy in c++ diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 3aec50fb97..d7fdb32959 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -235,7 +235,13 @@ class SegmentGrowingImpl : public SegmentGrowing { bool HasIndex(FieldId field_id) const override { - return true; + auto& field_meta = schema_->operator[](field_id); + if (datatype_is_vector(field_meta.get_data_type()) && + indexing_record_.SyncDataWithIndex(field_id)) { + return true; + } + + return false; } bool diff --git a/internal/core/src/segcore/SegmentSealed.h b/internal/core/src/segcore/SegmentSealed.h index 3771646cbf..62baa25613 100644 --- a/internal/core/src/segcore/SegmentSealed.h +++ b/internal/core/src/segcore/SegmentSealed.h @@ -47,6 +47,7 @@ class SegmentSealed : public SegmentInternalInterface { } }; -using SegmentSealedPtr = std::unique_ptr; +using SegmentSealedSPtr = std::shared_ptr; +using SegmentSealedUPtr = std::unique_ptr; } // namespace milvus::segcore diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index bf66fe5c7d..e8e16c5edb 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -33,6 +33,7 @@ #include "mmap/Column.h" #include "common/Consts.h" #include "common/FieldMeta.h" +#include "common/FieldData.h" #include "common/Types.h" #include "log/Log.h" #include "pb/schema.pb.h" @@ -40,7 +41,6 @@ #include "query/ScalarIndex.h" #include "query/SearchBruteForce.h" #include "query/SearchOnSealed.h" -#include "storage/FieldData.h" #include "storage/Util.h" #include "storage/ThreadPools.h" #include "storage/ChunkCacheSingleton.h" @@ -279,7 +279,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { if (system_field_type == SystemFieldType::Timestamp) { std::vector timestamps(num_rows); int64_t offset = 0; - auto field_data = CollectFieldDataChannel(data.channel); + auto field_data = storage::CollectFieldDataChannel(data.channel); for (auto& data : field_data) { int64_t row_count = data->get_num_rows(); std::copy_n(static_cast(data->Data()), @@ -307,7 +307,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { AssertInfo(system_field_type == SystemFieldType::RowId, "System field type of id column is not RowId"); - auto field_data = CollectFieldDataChannel(data.channel); + auto field_data = storage::CollectFieldDataChannel(data.channel); // write data under lock std::unique_lock lck(mutex_); @@ -335,7 +335,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { auto var_column = std::make_shared>( num_rows, field_meta); - storage::FieldDataPtr field_data; + FieldDataPtr field_data; while (data.channel->pop(field_data)) { for (auto i = 0; i < field_data->get_num_rows(); i++) { auto str = static_cast( @@ -354,7 +354,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { auto var_column = std::make_shared>( num_rows, field_meta); - storage::FieldDataPtr field_data; + FieldDataPtr field_data; while (data.channel->pop(field_data)) { for (auto i = 0; i < field_data->get_num_rows(); i++) { auto padded_string = @@ -374,7 +374,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { case milvus::DataType::ARRAY: { auto var_column = std::make_shared(num_rows, field_meta); - storage::FieldDataPtr field_data; + FieldDataPtr field_data; while (data.channel->pop(field_data)) { for (auto i = 0; i < field_data->get_num_rows(); i++) { auto rawValue = field_data->RawValue(i); @@ -398,7 +398,7 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { field_id, num_rows, field_data_size); } else { column = std::make_shared(num_rows, field_meta); - storage::FieldDataPtr field_data; + FieldDataPtr field_data; while (data.channel->pop(field_data)) { column->AppendBatch(field_data); } @@ -469,7 +469,7 @@ SegmentSealedImpl::MapFieldData(const FieldId field_id, FieldDataInfo& data) { auto data_size = 0; std::vector indices{}; std::vector> element_indices{}; - storage::FieldDataPtr field_data; + FieldDataPtr field_data; while (data.channel->pop(field_data)) { data_size += field_data->Size(); auto written = @@ -669,8 +669,12 @@ SegmentSealedImpl::mask_with_delete(BitsetType& bitset, return; } auto& delete_bitset = *bitmap_holder->bitmap_ptr; - AssertInfo(delete_bitset.size() == bitset.size(), - "Deleted bitmap size not equal to filtered bitmap size"); + AssertInfo( + delete_bitset.size() == bitset.size(), + fmt::format( + "Deleted bitmap size:{} not equal to filtered bitmap size:{}", + delete_bitset.size(), + bitset.size())); bitset |= delete_bitset; } diff --git a/internal/core/src/segcore/SegmentSealedImpl.h b/internal/core/src/segcore/SegmentSealedImpl.h index 8c4ffbfe88..01e6492e1a 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.h +++ b/internal/core/src/segcore/SegmentSealedImpl.h @@ -299,7 +299,7 @@ class SegmentSealedImpl : public SegmentSealed { vec_binlog_config_; }; -inline SegmentSealedPtr +inline SegmentSealedUPtr CreateSealedSegment( SchemaPtr schema, IndexMetaPtr index_meta = nullptr, diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index 1d82530243..eb3d4155b6 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -14,14 +14,14 @@ #include #include +#include "common/Common.h" +#include "common/FieldData.h" #include "index/ScalarIndex.h" #include "log/Log.h" -#include "storage/FieldData.h" -#include "storage/RemoteChunkManagerSingleton.h" -#include "common/Common.h" -#include "storage/ThreadPool.h" -#include "storage/Util.h" #include "mmap/Utils.h" +#include "storage/ThreadPool.h" +#include "storage/RemoteChunkManagerSingleton.h" +#include "storage/Util.h" namespace milvus::segcore { @@ -50,7 +50,7 @@ ParsePksFromFieldData(std::vector& pks, const DataArray& data) { void ParsePksFromFieldData(DataType data_type, std::vector& pks, - const std::vector& datas) { + const std::vector& datas) { int64_t offset = 0; for (auto& field_data : datas) { @@ -737,7 +737,7 @@ LoadFieldDatasFromRemote2(std::shared_ptr space, // segcore use default remote chunk manager to load data from minio/s3 void LoadFieldDatasFromRemote(std::vector& remote_files, - storage::FieldDataChannelPtr channel) { + FieldDataChannelPtr channel) { try { auto parallel_degree = static_cast( DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE); diff --git a/internal/core/src/segcore/Utils.h b/internal/core/src/segcore/Utils.h index c128f29205..93089d1639 100644 --- a/internal/core/src/segcore/Utils.h +++ b/internal/core/src/segcore/Utils.h @@ -20,13 +20,13 @@ #include #include +#include "common/FieldData.h" #include "common/QueryResult.h" // #include "common/Schema.h" #include "common/Types.h" +#include "index/Index.h" #include "segcore/DeletedRecord.h" #include "segcore/InsertRecord.h" -#include "index/Index.h" -#include "storage/FieldData.h" #include "storage/space.h" namespace milvus::segcore { @@ -37,7 +37,7 @@ ParsePksFromFieldData(std::vector& pks, const DataArray& data); void ParsePksFromFieldData(DataType data_type, std::vector& pks, - const std::vector& datas); + const std::vector& datas); void ParsePksFromIDs(std::vector& pks, @@ -159,7 +159,7 @@ ReverseDataFromIndex(const index::IndexBase* index, void LoadFieldDatasFromRemote(std::vector& remote_files, - storage::FieldDataChannelPtr channel); + FieldDataChannelPtr channel); void LoadFieldDatasFromRemote2(std::shared_ptr space, diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index fd6e63223c..e997eb5b88 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -10,21 +10,22 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include "segcore/segment_c.h" + #include +#include "common/FieldData.h" #include "common/LoadInfo.h" #include "common/Types.h" #include "common/Tracer.h" #include "common/type_c.h" #include "google/protobuf/text_format.h" #include "log/Log.h" +#include "mmap/Types.h" #include "segcore/Collection.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/SegmentSealedImpl.h" #include "segcore/Utils.h" -#include "storage/FieldData.h" #include "storage/Util.h" -#include "mmap/Types.h" #include "storage/space.h" ////////////////////////////// common interfaces ////////////////////////////// @@ -292,8 +293,8 @@ LoadFieldRawData(CSegmentInterface c_segment, } auto field_data = milvus::storage::CreateFieldData(data_type, dim); field_data->FillFieldData(data, row_count); - milvus::storage::FieldDataChannelPtr channel = - std::make_shared(); + milvus::FieldDataChannelPtr channel = + std::make_shared(); channel->push(field_data); channel->close(); auto field_data_info = milvus::FieldDataInfo( diff --git a/internal/core/src/simd/CMakeLists.txt b/internal/core/src/simd/CMakeLists.txt index 64106eba5d..ced8277197 100644 --- a/internal/core/src/simd/CMakeLists.txt +++ b/internal/core/src/simd/CMakeLists.txt @@ -28,6 +28,10 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw") elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*") # TODO: add arm cpu simd + message ("simd using arm mode") + list(APPEND MILVUS_SIMD_SRCS + neon.cpp + ) endif() add_library(milvus_simd ${MILVUS_SIMD_SRCS}) diff --git a/internal/core/src/simd/avx2.cpp b/internal/core/src/simd/avx2.cpp index 0faa120198..08c6a2636d 100644 --- a/internal/core/src/simd/avx2.cpp +++ b/internal/core/src/simd/avx2.cpp @@ -39,7 +39,7 @@ GetBitsetBlockAVX2(const bool* src) { BitsetBlockType res[4]; _mm256_storeu_si256((__m256i*)res, tmpvec); return res[0]; - // __m128i tmpvec = _mm_loadu_si64(tmp); + // __m256i tmpvec = _mm_loadu_si64(tmp); // BitsetBlockType res; // _mm_storeu_si64(&res, tmpvec); // return res; @@ -231,6 +231,80 @@ FindTermAVX2(const double* src, size_t vec_size, double val) { return false; } +bool +AllFalseAVX2(const bool* src, int64_t size) { + int num_chunk = size / 32; + __m256i highbit = _mm256_set1_epi8(0x7F); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m256i data = + _mm256_loadu_si256(reinterpret_cast(src + i)); + __m256i highbits = _mm256_add_epi8(data, highbit); + if (_mm256_movemask_epi8(highbits) != 0) { + return false; + } + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + if (src[i]) { + return false; + } + } + return true; +} + +bool +AllTrueAVX2(const bool* src, int64_t size) { + int num_chunk = size / 16; + __m256i highbit = _mm256_set1_epi8(0x7F); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m256i data = + _mm256_loadu_si256(reinterpret_cast(src + i)); + __m256i highbits = _mm256_add_epi8(data, highbit); + if (_mm256_movemask_epi8(highbits) != 0xFFFF) { + return false; + } + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + if (!src[i]) { + return false; + } + } + return true; +} + +void +AndBoolAVX2(bool* left, bool* right, int64_t size) { + int num_chunk = size / 32; + for (size_t i = 0; i < num_chunk * 32; i += 32) { + __m256i l_reg = + _mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i)); + __m256i r_reg = + _mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i)); + __m256i res = _mm256_and_si256(l_reg, r_reg); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res); + } + for (size_t i = num_chunk * 32; i < size; ++i) { + left[i] &= right[i]; + } +} + +void +OrBoolAVX2(bool* left, bool* right, int64_t size) { + int num_chunk = size / 32; + for (size_t i = 0; i < num_chunk * 32; i += 32) { + __m256i l_reg = + _mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i)); + __m256i r_reg = + _mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i)); + __m256i res = _mm256_or_si256(l_reg, r_reg); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res); + } + for (size_t i = num_chunk * 32; i < size; ++i) { + left[i] |= right[i]; + } +} + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/avx2.h b/internal/core/src/simd/avx2.h index 7e811aaa2b..90d6c7bbb7 100644 --- a/internal/core/src/simd/avx2.h +++ b/internal/core/src/simd/avx2.h @@ -58,5 +58,17 @@ template <> bool FindTermAVX2(const double* src, size_t vec_size, double val); +bool +AllFalseAVX2(const bool* src, int64_t size); + +bool +AllTrueAVX2(const bool* src, int64_t size); + +void +AndBoolAVX2(bool* left, bool* right, int64_t size); + +void +OrBoolAVX2(bool* left, bool* right, int64_t size); + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/avx512.cpp b/internal/core/src/simd/avx512.cpp index 42a7a08c77..3df38319fd 100644 --- a/internal/core/src/simd/avx512.cpp +++ b/internal/core/src/simd/avx512.cpp @@ -183,6 +183,39 @@ FindTermAVX512(const double* src, size_t vec_size, double val) { } return false; } + +void +AndBoolAVX512(bool* left, bool* right, int64_t size) { + int num_chunk = size / 64; + for (size_t i = 0; i < num_chunk * 64; i += 64) { + __m512i l_reg = + _mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i)); + __m512i r_reg = + _mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i)); + __m512i res = _mm512_and_si512(l_reg, r_reg); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res); + } + for (size_t i = num_chunk * 64; i < size; ++i) { + left[i] &= right[i]; + } +} + +void +OrBoolAVX512(bool* left, bool* right, int64_t size) { + int num_chunk = size / 64; + for (size_t i = 0; i < num_chunk * 64; i += 64) { + __m512i l_reg = + _mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i)); + __m512i r_reg = + _mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i)); + __m512i res = _mm512_or_si512(l_reg, r_reg); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res); + } + for (size_t i = num_chunk * 64; i < size; ++i) { + left[i] |= right[i]; + } +} + } // namespace simd } // namespace milvus #endif diff --git a/internal/core/src/simd/avx512.h b/internal/core/src/simd/avx512.h index f09c2c2116..fe24b00bb6 100644 --- a/internal/core/src/simd/avx512.h +++ b/internal/core/src/simd/avx512.h @@ -55,5 +55,11 @@ template <> bool FindTermAVX512(const double* src, size_t vec_size, double val); +void +AndBoolAVX512(bool* left, bool* right, int64_t size); + +void +OrBoolAVX512(bool* left, bool* right, int64_t size); + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/hook.cpp b/internal/core/src/simd/hook.cpp index 0ae5f24266..185933c18c 100644 --- a/internal/core/src/simd/hook.cpp +++ b/internal/core/src/simd/hook.cpp @@ -25,6 +25,8 @@ #include "sse2.h" #include "sse4.h" #include "instruction_set.h" +#elif defined(__ARM_NEON) +#include "neon.h" #endif namespace milvus { @@ -44,6 +46,12 @@ bool use_find_term_avx512; #endif decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef; +decltype(all_false) all_false = AllFalseRef; +decltype(all_true) all_true = AllTrueRef; +decltype(invert_bool) invert_bool = InvertBoolRef; +decltype(and_bool) and_bool = AndBoolRef; +decltype(or_bool) or_bool = OrBoolRef; + FindTermPtr find_term_bool = FindTermRef; FindTermPtr find_term_int8 = FindTermRef; FindTermPtr find_term_int16 = FindTermRef; @@ -161,9 +169,82 @@ find_term_hook() { LOG_SEGCORE_INFO_ << "find term hook simd type: " << simd_type; } +void +all_boolean_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + if (use_sse2 && cpu_support_sse2()) { + simd_type = "SSE2"; + all_false = AllFalseSSE2; + all_true = AllTrueSSE2; + } +#elif defined(__ARM_NEON) + simd_type = "NEON"; + all_false = AllFalseNEON; + all_true = AllTrueNEON; +#endif + // TODO: support arm cpu + LOG_SEGCORE_INFO_ << "AllFalse/AllTrue hook simd type: " << simd_type; +} + +void +invert_boolean_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + if (use_sse2 && cpu_support_sse2()) { + simd_type = "SSE2"; + invert_bool = InvertBoolSSE2; + } +#elif defined(__ARM_NEON) + simd_type = "NEON"; + invert_bool = InvertBoolNEON; +#endif + // TODO: support arm cpu + LOG_SEGCORE_INFO_ << "InvertBoolean hook simd type: " << simd_type; +} + +void +logical_boolean_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + if (use_avx512 && cpu_support_avx512()) { + simd_type = "AVX512"; + and_bool = AndBoolAVX512; + or_bool = OrBoolAVX512; + } else if (use_avx2 && cpu_support_avx2()) { + simd_type = "AVX2"; + and_bool = AndBoolAVX2; + or_bool = OrBoolAVX2; + } else if (use_sse2 && cpu_support_sse2()) { + simd_type = "SSE2"; + and_bool = AndBoolSSE2; + or_bool = OrBoolSSE2; + } +#elif defined(__ARM_NEON) + simd_type = "NEON"; + and_bool = AndBoolNEON; + or_bool = OrBoolNEON; +#endif + // TODO: support arm cpu + LOG_SEGCORE_INFO_ << "InvertBoolean hook simd type: " << simd_type; +} +void +boolean_hook() { + all_boolean_hook(); + invert_boolean_hook(); + logical_boolean_hook(); +} + static int init_hook_ = []() { bitset_hook(); find_term_hook(); + boolean_hook(); return 0; }(); diff --git a/internal/core/src/simd/hook.h b/internal/core/src/simd/hook.h index 050f660a10..98e82853ae 100644 --- a/internal/core/src/simd/hook.h +++ b/internal/core/src/simd/hook.h @@ -19,6 +19,11 @@ namespace milvus { namespace simd { extern BitsetBlockType (*get_bitset_block)(const bool* src); +extern bool (*all_false)(const bool* src, int64_t size); +extern bool (*all_true)(const bool* src, int64_t size); +extern void (*invert_bool)(bool* src, int64_t size); +extern void (*and_bool)(bool* left, bool* right, int64_t size); +extern void (*or_bool)(bool* left, bool* right, int64_t size); template using FindTermPtr = bool (*)(const T* src, size_t size, T val); @@ -63,6 +68,18 @@ bitset_hook(); void find_term_hook(); +void +boolean_hook(); + +void +all_boolean_hook(); + +void +invert_boolean_hook(); + +void +logical_boolean_hook(); + template bool find_term_func(const T* data, size_t size, T val) { diff --git a/internal/core/src/simd/neon.cpp b/internal/core/src/simd/neon.cpp new file mode 100644 index 0000000000..6bdda9138e --- /dev/null +++ b/internal/core/src/simd/neon.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#if defined(__ARM_NEON) + +#include "neon.h" + +#include +#include + +namespace milvus { +namespace simd { + +bool +AllFalseNEON(const bool* src, int64_t size) { + int num_chunk = size / 16; + + const uint8_t* ptr = reinterpret_cast(src); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + uint8x16_t data = vld1q_u8(ptr + i); + if (vmaxvq_u8(data) != 0) { + return false; + } + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + if (src[i]) { + return false; + } + } + + return true; +} + +bool +AllTrueNEON(const bool* src, int64_t size) { + int num_chunk = size / 16; + + const uint8_t* ptr = reinterpret_cast(src); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + uint8x16_t data = vld1q_u8(ptr + i); + if (vminvq_u8(data) == 0) { + return false; + } + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + if (!src[i]) { + return false; + } + } + + return true; +} + +void +InvertBoolNEON(bool* src, int64_t size) { + int num_chunk = size / 16; + uint8x16_t mask = vdupq_n_u8(0x01); + uint8_t* ptr = reinterpret_cast(src); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + uint8x16_t data = vld1q_u8(ptr + i); + + uint8x16_t flipped = veorq_u8(data, mask); + + vst1q_u8(ptr + i, flipped); + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + src[i] = !src[i]; + } +} + +void +AndBoolNEON(bool* left, bool* right, int64_t size) { + int num_chunk = size / 16; + uint8_t* lptr = reinterpret_cast(left); + uint8_t* rptr = reinterpret_cast(right); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + uint8x16_t l_reg = vld1q_u8(lptr + i); + uint8x16_t r_reg = vld1q_u8(rptr + i); + + uint8x16_t res = vandq_u8(l_reg, r_reg); + + vst1q_u8(lptr + i, res); + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + left[i] &= right[i]; + } +} + +void +OrBoolNEON(bool* left, bool* right, int64_t size) { + int num_chunk = size / 16; + uint8_t* lptr = reinterpret_cast(left); + uint8_t* rptr = reinterpret_cast(right); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + uint8x16_t l_reg = vld1q_u8(lptr + i); + uint8x16_t r_reg = vld1q_u8(rptr + i); + + uint8x16_t res = vorrq_u8(l_reg, r_reg); + + vst1q_u8(lptr + i, res); + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + left[i] |= right[i]; + } +} + +} // namespace simd +} // namespace milvus + +#endif \ No newline at end of file diff --git a/internal/core/src/simd/neon.h b/internal/core/src/simd/neon.h new file mode 100644 index 0000000000..2c38a2eb70 --- /dev/null +++ b/internal/core/src/simd/neon.h @@ -0,0 +1,37 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once +#include +#include "common.h" +namespace milvus { +namespace simd { + +BitsetBlockType +GetBitsetBlockSSE2(const bool* src); + +bool +AllFalseNEON(const bool* src, int64_t size); + +bool +AllTrueNEON(const bool* src, int64_t size); + +void +InvertBoolNEON(bool* src, int64_t size); + +void +AndBoolNEON(bool* left, bool* right, int64_t size); + +void +OrBoolNEON(bool* left, bool* right, int64_t size); + +} // namespace simd +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/simd/ref.cpp b/internal/core/src/simd/ref.cpp index 999bfa0458..f858fe97d2 100644 --- a/internal/core/src/simd/ref.cpp +++ b/internal/core/src/simd/ref.cpp @@ -29,5 +29,46 @@ GetBitsetBlockRef(const bool* src) { return val; } +bool +AllTrueRef(const bool* src, int64_t size) { + for (size_t i = 0; i < size; ++i) { + if (!src[i]) { + return false; + } + } + return true; +} + +bool +AllFalseRef(const bool* src, int64_t size) { + for (size_t i = 0; i < size; ++i) { + if (src[i]) { + return false; + } + } + return true; +} + +void +InvertBoolRef(bool* src, int64_t size) { + for (size_t i = 0; i < size; ++i) { + src[i] = !src[i]; + } +} + +void +AndBoolRef(bool* left, bool* right, int64_t size) { + for (size_t i = 0; i < size; ++i) { + left[i] &= right[i]; + } +} + +void +OrBoolRef(bool* left, bool* right, int64_t size) { + for (size_t i = 0; i < size; ++i) { + left[i] |= right[i]; + } +} + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/ref.h b/internal/core/src/simd/ref.h index 604b0aa7c3..6e90c7215a 100644 --- a/internal/core/src/simd/ref.h +++ b/internal/core/src/simd/ref.h @@ -19,6 +19,21 @@ namespace simd { BitsetBlockType GetBitsetBlockRef(const bool* src); +bool +AllTrueRef(const bool* src, int64_t size); + +bool +AllFalseRef(const bool* src, int64_t size); + +void +InvertBoolRef(bool* src, int64_t size); + +void +AndBoolRef(bool* left, bool* right, int64_t size); + +void +OrBoolRef(bool* left, bool* right, int64_t size); + template bool FindTermRef(const T* src, size_t size, T val) { diff --git a/internal/core/src/simd/sse2.cpp b/internal/core/src/simd/sse2.cpp index e7cb207757..40542bf22b 100644 --- a/internal/core/src/simd/sse2.cpp +++ b/internal/core/src/simd/sse2.cpp @@ -256,6 +256,102 @@ FindTermSSE2(const double* src, size_t vec_size, double val) { return false; } +void +print_m128i(__m128i v) { + alignas(16) int result[4]; + _mm_store_si128(reinterpret_cast<__m128i*>(result), v); + + for (int i = 0; i < 4; ++i) { + std::cout << std::hex << result[i] << " "; + } + + std::cout << std::endl; +} + +bool +AllFalseSSE2(const bool* src, int64_t size) { + int num_chunk = size / 16; + __m128i highbit = _mm_set1_epi8(0x7F); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m128i data = + _mm_loadu_si128(reinterpret_cast(src + i)); + __m128i highbits = _mm_add_epi8(data, highbit); + if (_mm_movemask_epi8(highbits) != 0) { + return false; + } + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + if (src[i]) { + return false; + } + } + return true; +} + +bool +AllTrueSSE2(const bool* src, int64_t size) { + int num_chunk = size / 16; + __m128i highbit = _mm_set1_epi8(0x7F); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m128i data = + _mm_loadu_si128(reinterpret_cast(src + i)); + __m128i highbits = _mm_add_epi8(data, highbit); + if (_mm_movemask_epi8(highbits) != 0xFFFF) { + return false; + } + } + + for (size_t i = num_chunk * 16; i < size; ++i) { + if (!src[i]) { + return false; + } + } + return true; +} + +void +InvertBoolSSE2(bool* src, int64_t size) { + int num_chunk = size / 16; + __m128i mask = _mm_set1_epi8(0x01); + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m128i data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src + i)); + __m128i flipped = _mm_xor_si128(data, mask); + _mm_storeu_si128(reinterpret_cast<__m128i*>(src + i), flipped); + } + for (size_t i = num_chunk * 16; i < size; ++i) { + src[i] = !src[i]; + } +} + +void +AndBoolSSE2(bool* left, bool* right, int64_t size) { + int num_chunk = size / 16; + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i)); + __m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i)); + __m128i res = _mm_and_si128(l_reg, r_reg); + _mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res); + } + for (size_t i = num_chunk * 16; i < size; ++i) { + left[i] &= right[i]; + } +} + +void +OrBoolSSE2(bool* left, bool* right, int64_t size) { + int num_chunk = size / 16; + for (size_t i = 0; i < num_chunk * 16; i += 16) { + __m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i)); + __m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i)); + __m128i res = _mm_or_si128(l_reg, r_reg); + _mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res); + } + for (size_t i = num_chunk * 16; i < size; ++i) { + left[i] |= right[i]; + } +} + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/sse2.h b/internal/core/src/simd/sse2.h index b7bbde86c0..9b53bb869f 100644 --- a/internal/core/src/simd/sse2.h +++ b/internal/core/src/simd/sse2.h @@ -24,6 +24,21 @@ namespace simd { BitsetBlockType GetBitsetBlockSSE2(const bool* src); +bool +AllFalseSSE2(const bool* src, int64_t size); + +bool +AllTrueSSE2(const bool* src, int64_t size); + +void +InvertBoolSSE2(bool* src, int64_t size); + +void +AndBoolSSE2(bool* left, bool* right, int64_t size); + +void +OrBoolSSE2(bool* left, bool* right, int64_t size); + template bool FindTermSSE2(const T* src, size_t vec_size, T va) { diff --git a/internal/core/src/storage/CMakeLists.txt b/internal/core/src/storage/CMakeLists.txt index 6c4c6d24e9..10f3c4b4c3 100644 --- a/internal/core/src/storage/CMakeLists.txt +++ b/internal/core/src/storage/CMakeLists.txt @@ -38,7 +38,6 @@ set(STORAGE_FILES PayloadStream.cpp DataCodec.cpp Util.cpp - FieldData.cpp PayloadReader.cpp PayloadWriter.cpp BinlogReader.cpp diff --git a/internal/core/src/storage/DataCodec.h b/internal/core/src/storage/DataCodec.h index 7def219eb9..74fe0a65c4 100644 --- a/internal/core/src/storage/DataCodec.h +++ b/internal/core/src/storage/DataCodec.h @@ -20,8 +20,8 @@ #include #include +#include "common/FieldData.h" #include "storage/Types.h" -#include "storage/FieldData.h" #include "storage/PayloadStream.h" #include "storage/BinlogReader.h" diff --git a/internal/core/src/storage/Event.h b/internal/core/src/storage/Event.h index 826da5cfaf..87a5d0eb4d 100644 --- a/internal/core/src/storage/Event.h +++ b/internal/core/src/storage/Event.h @@ -21,9 +21,9 @@ #include #include +#include "common/FieldData.h" #include "common/Types.h" #include "storage/Types.h" -#include "storage/FieldData.h" #include "storage/BinlogReader.h" namespace milvus::storage { diff --git a/internal/core/src/storage/FileManager.h b/internal/core/src/storage/FileManager.h index 81259d1007..c8618ab4b4 100644 --- a/internal/core/src/storage/FileManager.h +++ b/internal/core/src/storage/FileManager.h @@ -20,11 +20,11 @@ #include #include -#include "knowhere/file_manager.h" #include "common/Consts.h" +#include "knowhere/file_manager.h" +#include "log/Log.h" #include "storage/ChunkManager.h" #include "storage/Types.h" -#include "log/Log.h" #include "storage/space.h" namespace milvus::storage { diff --git a/internal/core/src/storage/LocalChunkManager.cpp b/internal/core/src/storage/LocalChunkManager.cpp index 7baca5e6c0..b2b926223d 100644 --- a/internal/core/src/storage/LocalChunkManager.cpp +++ b/internal/core/src/storage/LocalChunkManager.cpp @@ -22,6 +22,7 @@ #include #include "common/EasyAssert.h" +#include "common/Exception.h" #define THROWLOCALERROR(code, FUNCTION) \ do { \ diff --git a/internal/core/src/storage/MemFileManagerImpl.cpp b/internal/core/src/storage/MemFileManagerImpl.cpp index 72cdfac39c..13639d6963 100644 --- a/internal/core/src/storage/MemFileManagerImpl.cpp +++ b/internal/core/src/storage/MemFileManagerImpl.cpp @@ -18,11 +18,11 @@ #include #include -#include "log/Log.h" -#include "storage/FieldData.h" -#include "storage/FileManager.h" -#include "storage/Util.h" #include "common/Common.h" +#include "common/FieldData.h" +#include "log/Log.h" +#include "storage/Util.h" +#include "storage/FileManager.h" namespace milvus::storage { @@ -140,10 +140,10 @@ MemFileManagerImpl::LoadFile(const std::string& filename) noexcept { return true; } -std::map +std::map MemFileManagerImpl::LoadIndexToMemory( const std::vector& remote_files) { - std::map file_to_index_data; + std::map file_to_index_data; auto parallel_degree = static_cast(DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE); std::vector batch_files; diff --git a/internal/core/src/storage/MemFileManagerImpl.h b/internal/core/src/storage/MemFileManagerImpl.h index 726e6b28ef..1349cbeb41 100644 --- a/internal/core/src/storage/MemFileManagerImpl.h +++ b/internal/core/src/storage/MemFileManagerImpl.h @@ -54,7 +54,7 @@ class MemFileManagerImpl : public FileManagerImpl { return "MemIndexFileManagerImpl"; } - std::map + std::map LoadIndexToMemory(const std::vector& remote_files); std::vector diff --git a/internal/core/src/storage/MinioChunkManager.cpp b/internal/core/src/storage/MinioChunkManager.cpp index d3ff8d1d8b..a145a6dcdb 100644 --- a/internal/core/src/storage/MinioChunkManager.cpp +++ b/internal/core/src/storage/MinioChunkManager.cpp @@ -14,6 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "storage/MinioChunkManager.h" + #include #include #include @@ -28,7 +30,6 @@ #include #include -#include "storage/MinioChunkManager.h" #include "storage/AliyunSTSClient.h" #include "storage/AliyunCredentialsProvider.h" #include "storage/prometheus_client.h" diff --git a/internal/core/src/storage/MinioChunkManager.h b/internal/core/src/storage/MinioChunkManager.h index 348f2dd902..da839e64a7 100644 --- a/internal/core/src/storage/MinioChunkManager.h +++ b/internal/core/src/storage/MinioChunkManager.h @@ -16,6 +16,11 @@ #pragma once +#include +#include +#include +#include + #include #include #include @@ -25,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -32,13 +38,8 @@ #include #include -#include -#include -#include -#include -#include - #include "common/EasyAssert.h" +#include "common/Exception.h" #include "storage/ChunkManager.h" #include "storage/Types.h" diff --git a/internal/core/src/storage/PayloadReader.h b/internal/core/src/storage/PayloadReader.h index 90e63a20ec..b5fb22084d 100644 --- a/internal/core/src/storage/PayloadReader.h +++ b/internal/core/src/storage/PayloadReader.h @@ -19,8 +19,8 @@ #include #include +#include "common/FieldData.h" #include "storage/PayloadStream.h" -#include "storage/FieldData.h" namespace milvus::storage { diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index f7749a6e7f..59b3c014ef 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -15,28 +15,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "storage/Util.h" #include + #include "arrow/array/builder_binary.h" #include "arrow/type_fwd.h" -#include "common/EasyAssert.h" -#include "common/Consts.h" #include "fmt/format.h" #include "log/Log.h" -#include "storage/ChunkManager.h" + +#include "common/Consts.h" +#include "common/EasyAssert.h" +#include "common/FieldData.h" +#include "common/FieldDataInterface.h" #ifdef AZURE_BUILD_DIR #include "storage/AzureChunkManager.h" #endif -#include "storage/FieldData.h" +#include "storage/ChunkManager.h" +#include "storage/DiskFileManagerImpl.h" #include "storage/InsertData.h" -#include "storage/FieldDataInterface.h" -#include "storage/ThreadPools.h" #include "storage/LocalChunkManager.h" +#include "storage/MemFileManagerImpl.h" #include "storage/MinioChunkManager.h" #include "storage/OpenDALChunkManager.h" -#include "storage/MemFileManagerImpl.h" -#include "storage/DiskFileManagerImpl.h" #include "storage/Types.h" +#include "storage/ThreadPools.h" +#include "storage/Util.h" namespace milvus::storage { @@ -727,18 +729,18 @@ GetByteSizeOfFieldDatas(const std::vector& field_datas) { return result; } -std::vector -CollectFieldDataChannel(storage::FieldDataChannelPtr& channel) { - std::vector result; - storage::FieldDataPtr field_data; +std::vector +CollectFieldDataChannel(FieldDataChannelPtr& channel) { + std::vector result; + FieldDataPtr field_data; while (channel->pop(field_data)) { result.push_back(field_data); } return result; } -storage::FieldDataPtr -MergeFieldData(std::vector& data_array) { +FieldDataPtr +MergeFieldData(std::vector& data_array) { if (data_array.size() == 0) { return nullptr; } diff --git a/internal/core/src/storage/Util.h b/internal/core/src/storage/Util.h index eba7d5d366..fbd5fcb541 100644 --- a/internal/core/src/storage/Util.h +++ b/internal/core/src/storage/Util.h @@ -20,15 +20,15 @@ #include #include -#include "storage/FieldData.h" +#include "common/FieldData.h" +#include "common/LoadInfo.h" +#include "knowhere/comp/index_param.h" +#include "parquet/schema.h" #include "storage/PayloadStream.h" #include "storage/FileManager.h" #include "storage/BinlogReader.h" #include "storage/ChunkManager.h" #include "storage/DataCodec.h" -#include "knowhere/comp/index_param.h" -#include "parquet/schema.h" -#include "common/LoadInfo.h" #include "storage/Types.h" #include "storage/space.h" @@ -161,10 +161,10 @@ CreateFieldData(const DataType& type, int64_t GetByteSizeOfFieldDatas(const std::vector& field_datas); -std::vector -CollectFieldDataChannel(storage::FieldDataChannelPtr& channel); +std::vector +CollectFieldDataChannel(FieldDataChannelPtr& channel); -storage::FieldDataPtr -MergeFieldData(std::vector& data_array); +FieldDataPtr +MergeFieldData(std::vector& data_array); } // namespace milvus::storage diff --git a/internal/core/src/storage/parquet_c.cpp b/internal/core/src/storage/parquet_c.cpp index caa7ca5057..ec7e8f1728 100644 --- a/internal/core/src/storage/parquet_c.cpp +++ b/internal/core/src/storage/parquet_c.cpp @@ -17,10 +17,10 @@ #include #include "common/EasyAssert.h" +#include "common/FieldData.h" #include "storage/parquet_c.h" #include "storage/PayloadReader.h" #include "storage/PayloadWriter.h" -#include "storage/FieldData.h" #include "storage/Util.h" using Payload = milvus::storage::Payload; diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 72c3394826..2629262603 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -59,6 +59,7 @@ set(MILVUS_TEST_FILES test_chunk_cache.cpp test_binlog_index.cpp test_storage.cpp + test_exec.cpp ) if ( BUILD_DISK_ANN STREQUAL "ON" ) @@ -80,7 +81,7 @@ endif() if (DEFINED AZURE_BUILD_DIR) set(MILVUS_TEST_FILES ${MILVUS_TEST_FILES} - test_azure_chunk_manager.cpp + #test_azure_chunk_manager.cpp #need update aws-sdk-cpp, see more from https://github.com/aws/aws-sdk-cpp/issues/2119 #test_remote_chunk_manager.cpp ) @@ -129,6 +130,7 @@ target_link_libraries(all_tests milvus_indexbuilder pthread milvus_common + milvus_exec ) install(TARGETS all_tests DESTINATION unittest) diff --git a/internal/core/unittest/bench/bench_search.cpp b/internal/core/unittest/bench/bench_search.cpp index 9f63d61ed6..f40766ce82 100644 --- a/internal/core/unittest/bench/bench_search.cpp +++ b/internal/core/unittest/bench/bench_search.cpp @@ -31,7 +31,7 @@ const auto schema = []() { return schema; }(); -const auto plan = [] { +const auto search_plan = [] { const char* raw_plan = R"(vector_anns: < field_id: 100 query_info: < @@ -50,8 +50,8 @@ const auto plan = [] { auto ph_group = [] { auto num_queries = 10; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, 1024); - auto ph_group = - ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto ph_group = ParsePlaceholderGroup(search_plan.get(), + ph_group_raw.SerializeAsString()); return ph_group; }(); @@ -91,7 +91,7 @@ Search_GrowingIndex(benchmark::State& state) { dataset_.raw_); for (auto _ : state) { - auto qr = segment->Search(plan.get(), ph_group.get()); + auto qr = segment->Search(search_plan.get(), ph_group.get()); } } @@ -124,7 +124,7 @@ Search_Sealed(benchmark::State& state) { segment->LoadIndex(info); } for (auto _ : state) { - auto qr = segment->Search(plan.get(), ph_group.get()); + auto qr = segment->Search(search_plan.get(), ph_group.get()); } } diff --git a/internal/core/unittest/test_always_true_expr.cpp b/internal/core/unittest/test_always_true_expr.cpp index d1228a10b5..cc01d4b30f 100644 --- a/internal/core/unittest/test_always_true_expr.cpp +++ b/internal/core/unittest/test_always_true_expr.cpp @@ -20,6 +20,8 @@ #include "query/generated/ExecExprVisitor.h" #include "segcore/SegmentGrowingImpl.h" #include "test_utils/DataGen.h" +#include "expr/ITypeExpr.h" +#include "plan/PlanNode.h" TEST(Expr, AlwaysTrue) { using namespace milvus; @@ -48,10 +50,12 @@ TEST(Expr, AlwaysTrue) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); - auto expr = CreateAlwaysTrueExpr(); - auto final = visitor.call_child(*expr); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + auto expr = std::make_shared(); + BitsetType final; + std::shared_ptr plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/internal/core/unittest/test_array_expr.cpp b/internal/core/unittest/test_array_expr.cpp index 89798a4417..7461060138 100644 --- a/internal/core/unittest/test_array_expr.cpp +++ b/internal/core/unittest/test_array_expr.cpp @@ -27,6 +27,8 @@ #include "simdjson/padded_string.h" #include "test_utils/DataGen.h" #include "index/IndexFactory.h" +#include "expr/ITypeExpr.h" +#include "plan/PlanNode.h" TEST(Expr, TestArrayRange) { using namespace milvus; @@ -595,8 +597,7 @@ TEST(Expr, TestArrayRange) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, array_type, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -604,7 +605,9 @@ TEST(Expr, TestArrayRange) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -712,8 +715,7 @@ TEST(Expr, TestArrayEqual) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -721,7 +723,9 @@ TEST(Expr, TestArrayEqual) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -890,11 +894,10 @@ TEST(Expr, TestArrayContains) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true, true}, {}}, - {{false, false}, {}}}; + {{false, false}, {}}}; for (auto testcase : bool_testcases) { auto check = [&](const std::vector& values) { @@ -906,15 +909,22 @@ TEST(Expr, TestArrayContains) { } return false; }; - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(bool_array_fid, DataType::ARRAY), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_Contains, - proto::plan::GenericValue::ValCase::kBoolVal); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue gen_val; + gen_val.set_bool_val(val); + values.push_back(gen_val); + } auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + expr::ColumnInfo(bool_array_fid, DataType::ARRAY), + proto::plan::JSONContainsExpr_JSONOp_Contains, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -929,7 +939,7 @@ TEST(Expr, TestArrayContains) { for (int j = 0; j < array.length(); ++j) { res.push_back(array.get_data(j)); } - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, check(res)) << "@" << i; } } @@ -952,15 +962,23 @@ TEST(Expr, TestArrayContains) { } return false; }; - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(double_array_fid, DataType::ARRAY), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_Contains, - proto::plan::GenericValue::ValCase::kFloatVal); + + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue gen_val; + gen_val.set_float_val(val); + values.push_back(gen_val); + } auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + expr::ColumnInfo(double_array_fid, DataType::ARRAY), + proto::plan::JSONContainsExpr_JSONOp_Contains, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -989,15 +1007,22 @@ TEST(Expr, TestArrayContains) { } return false; }; - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(float_array_fid, DataType::ARRAY), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_Contains, - proto::plan::GenericValue::ValCase::kFloatVal); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue gen_val; + gen_val.set_float_val(val); + values.push_back(gen_val); + } auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + expr::ColumnInfo(float_array_fid, DataType::ARRAY), + proto::plan::JSONContainsExpr_JSONOp_Contains, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1035,15 +1060,23 @@ TEST(Expr, TestArrayContains) { } return true; }; - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(int_array_fid, DataType::ARRAY), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kInt64Val); + + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue gen_val; + gen_val.set_int64_val(val); + values.push_back(gen_val); + } auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + expr::ColumnInfo(int_array_fid, DataType::ARRAY), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1072,15 +1105,23 @@ TEST(Expr, TestArrayContains) { } return true; }; - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(long_array_fid, DataType::ARRAY), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kInt64Val); + + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue gen_val; + gen_val.set_int64_val(val); + values.push_back(gen_val); + } auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + expr::ColumnInfo(long_array_fid, DataType::ARRAY), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1116,15 +1157,23 @@ TEST(Expr, TestArrayContains) { } return true; }; - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(string_array_fid, DataType::ARRAY), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kStringVal); + + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue gen_val; + gen_val.set_string_val(val); + values.push_back(gen_val); + } auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + expr::ColumnInfo(string_array_fid, DataType::ARRAY), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1195,8 +1244,7 @@ TEST(Expr, TestArrayBinaryArith) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector >)", "int", - [](milvus::Array& array) { - return array.length() == 10; - }}, + [](milvus::Array& array) { return array.length() == 10; }}, {R"(binary_arith_op_eval_range_expr: < column_info: < field_id: 101 @@ -1667,9 +1713,7 @@ TEST(Expr, TestArrayBinaryArith) { value: >)", "int", - [](milvus::Array& array) { - return array.length() != 8; - }}, + [](milvus::Array& array) { return array.length() != 8; }}, }; std::string raw_plan_tmp = R"(vector_anns: < @@ -1692,7 +1736,9 @@ TEST(Expr, TestArrayBinaryArith) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1743,8 +1789,7 @@ TEST(Expr, TestArrayStringMatch) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> prefix_testcases{ {OpType::PrefixMatch, @@ -1771,14 +1816,18 @@ TEST(Expr, TestArrayStringMatch) { }; //vector_anns: op:PrefixMatch value: > > query_info:<> placeholder_tag:"$0" > for (auto& testcase : prefix_testcases) { - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(string_array_fid, DataType::ARRAY, testcase.nested_path), - testcase.op_type, - testcase.value, - proto::plan::GenericValue::ValCase::kStringVal); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + proto::plan::GenericValue value; + value.set_string_val(testcase.value); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + string_array_fid, DataType::ARRAY, testcase.nested_path), + testcase.op_type, + value); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -1844,10 +1893,9 @@ TEST(Expr, TestArrayInTerm) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - std::vector>> testcases = { @@ -1860,11 +1908,11 @@ TEST(Expr, TestArrayInTerm) { > values: values: values: >)", - "long", - [](milvus::Array& array) { - auto val = array.get_data(0); - return val == 1 || val ==2 || val == 3; - }}, + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val == 1 || val == 2 || val == 3; + }}, {R"(term_expr: < column_info: < field_id: 101 @@ -1874,9 +1922,7 @@ TEST(Expr, TestArrayInTerm) { > >)", "long", - [](milvus::Array& array) { - return false; - }}, + [](milvus::Array& array) { return false; }}, {R"(term_expr: < column_info: < field_id: 102 @@ -1900,9 +1946,7 @@ TEST(Expr, TestArrayInTerm) { > >)", "bool", - [](milvus::Array& array) { - return false; - }}, + [](milvus::Array& array) { return false; }}, {R"(term_expr: < column_info: < field_id: 103 @@ -1926,9 +1970,7 @@ TEST(Expr, TestArrayInTerm) { > >)", "float", - [](milvus::Array& array) { - return false; - }}, + [](milvus::Array& array) { return false; }}, {R"(term_expr: < column_info: < field_id: 104 @@ -1952,9 +1994,7 @@ TEST(Expr, TestArrayInTerm) { > >)", "string", - [](milvus::Array& array) { - return false; - }}, + [](milvus::Array& array) { return false; }}, {R"(term_expr: < column_info: < field_id: 104 @@ -1995,7 +2035,9 @@ TEST(Expr, TestArrayInTerm) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2036,8 +2078,7 @@ TEST(Expr, TestTermInArray) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); struct TermTestCases { std::vector values; @@ -2070,14 +2111,22 @@ TEST(Expr, TestTermInArray) { }; for (auto& testcase : testcases) { - RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(long_array_fid, DataType::ARRAY, testcase.nested_path), - testcase.values, - proto::plan::GenericValue::ValCase::kInt64Val, - true); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + std::vector values; + for (auto& v : testcase.values) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.emplace_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + long_array_fid, DataType::ARRAY, testcase.nested_path), + values, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) diff --git a/internal/core/unittest/test_binlog_index.cpp b/internal/core/unittest/test_binlog_index.cpp index d96b78776e..4d4c3faf23 100644 --- a/internal/core/unittest/test_binlog_index.cpp +++ b/internal/core/unittest/test_binlog_index.cpp @@ -110,25 +110,22 @@ class BinlogIndexTest : public ::testing::TestWithParam { LoadFieldDataInfo row_id_info; FieldMeta row_id_field_meta( FieldName("RowID"), RowFieldID, DataType::INT64); - auto field_data = std::make_shared>( - DataType::INT64); + auto field_data = + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.row_ids_.data(), data_n); - auto field_data_info = - FieldDataInfo{RowFieldID.get(), - data_n, - std::vector{field_data}}; + auto field_data_info = FieldDataInfo{ + RowFieldID.get(), data_n, std::vector{field_data}}; segment->LoadFieldData(RowFieldID, field_data_info); // load ts LoadFieldDataInfo ts_info; FieldMeta ts_field_meta( FieldName("Timestamp"), TimestampFieldID, DataType::INT64); - field_data = std::make_shared>( - DataType::INT64); + field_data = + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.timestamps_.data(), data_n); - field_data_info = - FieldDataInfo{TimestampFieldID.get(), - data_n, - std::vector{field_data}}; + field_data_info = FieldDataInfo{TimestampFieldID.get(), + data_n, + std::vector{field_data}}; segment->LoadFieldData(TimestampFieldID, field_data_info); } @@ -138,8 +135,8 @@ class BinlogIndexTest : public ::testing::TestWithParam { size_t data_n = 10000; size_t data_d = 128; size_t topk = 10; - milvus::storage::FieldDataPtr vec_field_data = nullptr; - milvus::segcore::SegmentSealedPtr segment = nullptr; + milvus::FieldDataPtr vec_field_data = nullptr; + milvus::segcore::SegmentSealedUPtr segment = nullptr; milvus::FieldId vec_field_id; std::shared_ptr vec_data; }; @@ -159,10 +156,8 @@ TEST_P(BinlogIndexTest, Accuracy) { segcore_config.set_enable_interim_segment_index(true); segcore_config.set_nprobe(32); // 1. load field data, and build binlog index for binlog data - auto field_data_info = - FieldDataInfo{vec_field_id.get(), - data_n, - std::vector{vec_field_data}}; + auto field_data_info = FieldDataInfo{ + vec_field_id.get(), data_n, std::vector{vec_field_data}}; segment->LoadFieldData(vec_field_id, field_data_info); //assert segment has been built binlog index EXPECT_TRUE(segment->HasIndex(vec_field_id)); @@ -249,10 +244,8 @@ TEST_P(BinlogIndexTest, DisableInterimIndex) { LoadOtherFields(); SegcoreSetEnableTempSegmentIndex(false); - auto field_data_info = - FieldDataInfo{vec_field_id.get(), - data_n, - std::vector{vec_field_data}}; + auto field_data_info = FieldDataInfo{ + vec_field_id.get(), data_n, std::vector{vec_field_data}}; segment->LoadFieldData(vec_field_id, field_data_info); EXPECT_FALSE(segment->HasIndex(vec_field_id)); @@ -296,10 +289,8 @@ TEST_P(BinlogIndexTest, LoadBingLogWihIDMAP) { segment = CreateSealedSegment(schema, collection_index_meta); LoadOtherFields(); - auto field_data_info = - FieldDataInfo{vec_field_id.get(), - data_n, - std::vector{vec_field_data}}; + auto field_data_info = FieldDataInfo{ + vec_field_id.get(), data_n, std::vector{vec_field_data}}; segment->LoadFieldData(vec_field_id, field_data_info); EXPECT_FALSE(segment->HasIndex(vec_field_id)); @@ -314,10 +305,8 @@ TEST_P(BinlogIndexTest, LoadBinlogWithoutIndexMeta) { segment = CreateSealedSegment(schema, collection_index_meta); SegcoreSetEnableTempSegmentIndex(true); - auto field_data_info = - FieldDataInfo{vec_field_id.get(), - data_n, - std::vector{vec_field_data}}; + auto field_data_info = FieldDataInfo{ + vec_field_id.get(), data_n, std::vector{vec_field_data}}; segment->LoadFieldData(vec_field_id, field_data_info); EXPECT_FALSE(segment->HasIndex(vec_field_id)); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 1fb0014422..8465b002be 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -39,6 +39,9 @@ #include "test_utils/indexbuilder_test_utils.h" #include "test_utils/storage_test_utils.h" #include "query/generated/ExecExprVisitor.h" +#include "expr/ITypeExpr.h" +#include "plan/PlanNode.h" +#include "exec/expression/Expr.h" namespace chrono = std::chrono; @@ -465,16 +468,21 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { ASSERT_EQ(del_res.error_code, Success); // retrieve pks = {1} - std::vector retrive_pks = {1}; + std::vector retrive_pks; + { + proto::plan::GenericValue value; + value.set_int64_val(1); + retrive_pks.push_back(value); + } auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_pks, - proto::plan::GenericValue::kInt64Val); + retrive_pks); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; auto max_ts = dataset.timestamps_[N - 1] + 10; @@ -490,13 +498,17 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { DeleteRetrieveResult(&retrieve_result); // retrieve pks = {2} - retrive_pks = {2}; - term_expr = std::make_unique>( - milvus::query::ColumnInfo( + { + proto::plan::GenericValue value; + value.set_int64_val(2); + retrive_pks.push_back(value); + } + term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_pks, - proto::plan::GenericValue::kInt64Val); - plan->plan_node_->predicate_ = std::move(term_expr); + retrive_pks); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); suc = query_result->ParseFromArray(retrieve_result.proto_blob, @@ -567,16 +579,22 @@ TEST(CApiTest, MultiDeleteSealedSegment) { ASSERT_EQ(del_res.error_code, Success); // retrieve pks = {1} - std::vector retrive_pks = {1}; + std::vector retrive_pks; + { + proto::plan::GenericValue value; + value.set_int64_val(1); + retrive_pks.push_back(value); + } auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_pks, - proto::plan::GenericValue::kInt64Val); + retrive_pks); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; auto max_ts = dataset.timestamps_[N - 1] + 10; @@ -592,13 +610,17 @@ TEST(CApiTest, MultiDeleteSealedSegment) { DeleteRetrieveResult(&retrieve_result); // retrieve pks = {2} - retrive_pks = {2}; - term_expr = std::make_unique>( - milvus::query::ColumnInfo( + { + proto::plan::GenericValue value; + value.set_int64_val(2); + retrive_pks.push_back(value); + } + term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_pks, - proto::plan::GenericValue::kInt64Val); - plan->plan_node_->predicate_ = std::move(term_expr); + retrive_pks); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); res = CRetrieve(segment, plan.get(), {}, max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); suc = query_result->ParseFromArray(retrieve_result.proto_blob, @@ -674,16 +696,24 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) { ASSERT_EQ(res.error_code, Success); // create retrieve plan pks in {1, 2, 3} - std::vector retrive_row_ids = {1, 2, 3}; + std::vector retrive_row_ids; + { + for (auto v : {1, 2, 3}) { + proto::plan::GenericValue val; + val.set_int64_val(v); + retrive_row_ids.push_back(val); + } + } auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_row_ids, - proto::plan::GenericValue::kInt64Val); + retrive_row_ids); + plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -747,17 +777,24 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) { auto sealed_segment = dynamic_cast(segment_interface); SealedLoadFieldData(dataset, *sealed_segment); + std::vector retrive_row_ids; // create retrieve plan pks in {1, 2, 3} - std::vector retrive_row_ids = {1, 2, 3}; + { + for (auto v : {1, 2, 3}) { + proto::plan::GenericValue val; + val.set_int64_val(v); + retrive_row_ids.push_back(val); + } + } auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_row_ids, - proto::plan::GenericValue::kInt64Val); + retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -851,16 +888,23 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) { ASSERT_EQ(del_res.error_code, Success); // create retrieve plan pks in {1, 2, 3}, timestamp = 9 - std::vector retrive_row_ids = {1, 2, 3}; + std::vector retrive_row_ids; + { + for (auto v : {1, 2, 3}) { + proto::plan::GenericValue val; + val.set_int64_val(v); + retrive_row_ids.push_back(val); + } + } auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_row_ids, - proto::plan::GenericValue::kInt64Val); + retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -941,16 +985,23 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) { ASSERT_EQ(del_res.error_code, Success); // create retrieve plan pks in {1, 2, 3}, timestamp = 9 - std::vector retrive_row_ids = {1, 2, 3}; + std::vector retrive_row_ids; + { + for (auto v : {1, 2, 3}) { + proto::plan::GenericValue val; + val.set_int64_val(v); + retrive_row_ids.push_back(val); + } + } auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - retrive_row_ids, - proto::plan::GenericValue::kInt64Val); + retrive_row_ids); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -1127,15 +1178,21 @@ TEST(CApiTest, RetrieveTestWithExpr) { ASSERT_EQ(ins_res.error_code, Success); // create retrieve plan "age in [0]" - std::vector values(1, 0); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + std::vector values; + { + for (auto v : {1, 0}) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + } + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( FieldId(101), DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); - + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; @@ -4017,13 +4074,16 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) { // create retrieve plan auto plan = std::make_unique(*schema); plan->plan_node_ = std::make_unique(); - std::vector retrive_row_ids = {age64_col[0]}; - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + std::vector retrive_row_ids; + proto::plan::GenericValue val; + val.set_int64_val(age64_col[0]); + retrive_row_ids.push_back(val); + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( i64_fid, DataType::INT64, std::vector()), - retrive_row_ids, - proto::plan::GenericValue::kInt64Val); - plan->plan_node_->predicate_ = std::move(term_expr); + retrive_row_ids); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_field_ids; // retrieve value diff --git a/internal/core/unittest/test_exec.cpp b/internal/core/unittest/test_exec.cpp new file mode 100644 index 0000000000..dc9372a228 --- /dev/null +++ b/internal/core/unittest/test_exec.cpp @@ -0,0 +1,357 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include +#include + +#include "query/Expr.h" +#include "query/PlanImpl.h" +#include "query/PlanNode.h" +#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/generated/ExprVisitor.h" +#include "query/generated/ShowPlanNodeVisitor.h" +#include "segcore/SegmentSealed.h" +#include "test_utils/AssertUtils.h" +#include "test_utils/DataGen.h" +#include "plan/PlanNode.h" +#include "exec/Task.h" +#include "exec/QueryContext.h" +#include "expr/ITypeExpr.h" +#include "exec/expression/Expr.h" + +using namespace milvus; +using namespace milvus::exec; +using namespace milvus::query; +using namespace milvus::segcore; + +class TaskTest : public testing::Test { + protected: + void + SetUp() override { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); + field_map_.insert({"bool", bool_fid}); + auto bool_1_fid = schema->AddDebugField("bool1", DataType::BOOL); + field_map_.insert({"bool1", bool_1_fid}); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + field_map_.insert({"int8", int8_fid}); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + field_map_.insert({"int81", int8_1_fid}); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + field_map_.insert({"int16", int16_fid}); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + field_map_.insert({"int161", int16_1_fid}); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + field_map_.insert({"int32", int32_fid}); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + field_map_.insert({"int321", int32_1_fid}); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + field_map_.insert({"int64", int64_fid}); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + field_map_.insert({"int641", int64_1_fid}); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + field_map_.insert({"float", float_fid}); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + field_map_.insert({"float1", float_1_fid}); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + field_map_.insert({"double", double_fid}); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + field_map_.insert({"double1", double_1_fid}); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + field_map_.insert({"string1", str1_fid}); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + field_map_.insert({"string2", str2_fid}); + auto str3_fid = schema->AddDebugField("string3", DataType::VARCHAR); + field_map_.insert({"string3", str3_fid}); + schema->set_primary_field_id(str1_fid); + + auto segment = CreateSealedSegment(schema); + size_t N = 1000000; + num_rows_ = N; + auto raw_data = DataGen(schema, N); + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + segment->LoadFieldData(FieldId(field_id), info); + } + segment_ = SegmentSealedSPtr(segment.release()); + } + + void + TearDown() override { + } + + public: + SegmentSealedSPtr segment_; + std::map field_map_; + int64_t num_rows_{0}; +}; + +TEST_F(TaskTest, UnaryExpr) { + ::milvus::proto::plan::GenericValue value; + value.set_int64_val(-1); + auto logical_expr = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::LessThan, + value); + std::vector sources; + auto filter_node = std::make_shared( + "plannode id 1", logical_expr, sources); + auto plan = plan::PlanFragment(filter_node); + auto query_context = std::make_shared( + "test1", + segment_.get(), + MAX_TIMESTAMP, + std::make_shared( + std::unordered_map{})); + + auto start = std::chrono::steady_clock::now(); + auto task = Task::Create("task_unary_expr", plan, 0, query_context); + int64_t num_rows = 0; + int i = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + num_rows += result->size(); + } + auto cost = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + std::cout << "cost: " << cost << "us" << std::endl; + EXPECT_EQ(num_rows, num_rows_); +} + +TEST_F(TaskTest, LogicalExpr) { + ::milvus::proto::plan::GenericValue value; + value.set_int64_val(-1); + auto left = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::LessThan, + value); + auto right = std::make_shared( + expr::ColumnInfo(field_map_["int64"], DataType::INT64), + proto::plan::OpType::LessThan, + value); + + auto top = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, left, right); + std::vector sources; + auto filter_node = std::make_shared( + "plannode id 1", top, sources); + auto plan = plan::PlanFragment(filter_node); + auto query_context = std::make_shared( + "test1", + segment_.get(), + MAX_TIMESTAMP, + std::make_shared( + std::unordered_map{})); + + auto start = std::chrono::steady_clock::now(); + auto task = + Task::Create("task_logical_binary_expr", plan, 0, query_context); + int64_t num_rows = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + num_rows += result->size(); + } + auto cost = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + std::cout << "cost: " << cost << "us" << std::endl; + EXPECT_EQ(num_rows, num_rows_); +} + +TEST(CompileInputs, and) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto seg = CreateSealedSegment(schema); + proto::plan::GenericValue val; + val.set_int64_val(10); + // expr: (int64_fid < 10 and int64_fid < 10) and (int64_fid < 10 and int64_fid < 10) + auto expr1 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr2 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr1, expr2); + auto expr4 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr5 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr6 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr1, expr2); + auto expr7 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr3, expr6); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, seg.get(), MAX_TIMESTAMP); + auto exprs = milvus::exec::CompileInputs(expr7, query_context.get(), {}); + EXPECT_EQ(exprs.size(), 4); + for (int i = 0; i < exprs.size(); ++i) { + std::cout << exprs[i]->get_name() << std::endl; + EXPECT_STREQ(exprs[i]->get_name().c_str(), "PhyUnaryRangeFilterExpr"); + } +} + +TEST(CompileInputs, or_with_and) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto seg = CreateSealedSegment(schema); + proto::plan::GenericValue val; + val.set_int64_val(10); + { + // expr: (int64_fid < 10 and int64_fid < 10) or (int64_fid < 10 and int64_fid < 10) + auto expr1 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr2 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr1, expr2); + auto expr4 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr5 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr6 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr1, expr2); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, seg.get(), MAX_TIMESTAMP); + auto expr7 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr6); + auto exprs = + milvus::exec::CompileInputs(expr7, query_context.get(), {}); + EXPECT_EQ(exprs.size(), 2); + for (int i = 0; i < exprs.size(); ++i) { + std::cout << exprs[i]->get_name() << std::endl; + EXPECT_STREQ(exprs[i]->get_name().c_str(), "and"); + } + } + { + // expr: (int64_fid < 10 or int64_fid < 10) or (int64_fid < 10 and int64_fid < 10) + auto expr1 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr2 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + auto expr4 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr5 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr6 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr1, expr2); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, seg.get(), MAX_TIMESTAMP); + auto expr7 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr3, expr6); + auto exprs = + milvus::exec::CompileInputs(expr7, query_context.get(), {}); + std::cout << exprs.size() << std::endl; + EXPECT_EQ(exprs.size(), 3); + for (int i = 0; i < exprs.size() - 1; ++i) { + std::cout << exprs[i]->get_name() << std::endl; + EXPECT_STREQ(exprs[i]->get_name().c_str(), + "PhyUnaryRangeFilterExpr"); + } + EXPECT_STREQ(exprs[2]->get_name().c_str(), "and"); + } + { + // expr: (int64_fid < 10 or int64_fid < 10) and (int64_fid < 10 and int64_fid < 10) + auto expr1 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr2 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr3 = std::make_shared( + expr::LogicalBinaryExpr::OpType::Or, expr1, expr2); + auto expr4 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr5 = std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + auto expr6 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr1, expr2); + auto query_context = std::make_shared( + DEAFULT_QUERY_ID, seg.get(), MAX_TIMESTAMP); + auto expr7 = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, expr3, expr6); + auto exprs = + milvus::exec::CompileInputs(expr7, query_context.get(), {}); + std::cout << exprs.size() << std::endl; + EXPECT_EQ(exprs.size(), 3); + EXPECT_STREQ(exprs[0]->get_name().c_str(), "or"); + for (int i = 1; i < exprs.size(); ++i) { + std::cout << exprs[i]->get_name() << std::endl; + EXPECT_STREQ(exprs[i]->get_name().c_str(), + "PhyUnaryRangeFilterExpr"); + } + } +} \ No newline at end of file diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index eacc3970e6..af6c0b46b0 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -32,6 +32,8 @@ #include "segcore/segment_c.h" #include "test_utils/DataGen.h" #include "index/IndexFactory.h" +#include "exec/expression/Expr.h" +#include "exec/Task.h" TEST(Expr, Range) { SUCCEED(); @@ -500,8 +502,7 @@ TEST(Expr, TestRange) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -509,7 +510,10 @@ TEST(Expr, TestRange) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -569,8 +573,7 @@ TEST(Expr, TestBinaryRangeJSON) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](int64_t value) { int64_t lower = testcase.lower, upper = testcase.upper; @@ -584,14 +587,22 @@ TEST(Expr, TestBinaryRangeJSON) { }; auto pointer = milvus::Json::pointer(testcase.nested_path); RetrievePlanNode plan; - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - proto::plan::GenericValue::ValCase::kInt64Val, + milvus::proto::plan::GenericValue lower_val; + lower_val.set_int64_val(testcase.lower); + milvus::proto::plan::GenericValue upper_val; + upper_val.set_int64_val(testcase.upper); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + lower_val, + upper_val, testcase.lower_inclusive, - testcase.upper_inclusive, - testcase.lower, - testcase.upper); - auto final = visitor.call_child(*plan.predicate_.value()); + testcase.upper_inclusive); + BitsetType final; + plan.filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode( + plan.filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -657,15 +668,19 @@ TEST(Expr, TestExistsJson) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](bool value) { return value; }; RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path)); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = + std::make_shared(milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path)); + BitsetType final; + plan.filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode( + plan.filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -678,6 +693,37 @@ TEST(Expr, TestExistsJson) { } } +template +T +GetValueFromProto(const milvus::proto::plan::GenericValue& value_proto) { + if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kBoolVal); + return static_cast(value_proto.bool_val()); + } else if constexpr (std::is_integral_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kInt64Val); + return static_cast(value_proto.int64_val()); + } else if constexpr (std::is_floating_point_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kFloatVal); + return static_cast(value_proto.float_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kStringVal); + return static_cast(value_proto.string_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == + milvus::proto::plan::GenericValue::kArrayVal); + return static_cast(value_proto.array_val()); + } else if constexpr (std::is_same_v) { + return static_cast(value_proto); + } else { + PanicInfo(milvus::ErrorCode::UnexpectedError, + "unsupported generic value type"); + } +}; + TEST(Expr, TestUnaryRangeJson) { using namespace milvus; using namespace milvus::query; @@ -722,8 +768,7 @@ TEST(Expr, TestUnaryRangeJson) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector ops{ OpType::Equal, OpType::NotEqual, @@ -766,14 +811,19 @@ TEST(Expr, TestUnaryRangeJson) { } } - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + proto::plan::GenericValue value; + value.set_int64_val(testcase.val); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), op, - testcase.val, - proto::plan::GenericValue::ValCase::kInt64Val); - auto final = visitor.call_child(*plan.predicate_.value()); + value); + BitsetType final; + auto plan = std::make_shared( + DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); + EXPECT_EQ(final.size(), N * num_iters); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -798,26 +848,26 @@ TEST(Expr, TestUnaryRangeJson) { } struct TestArrayCase { - proto::plan::Array val; + proto::plan::GenericValue val; std::vector nested_path; }; - proto::plan::Array arr; - arr.set_same_type(true); + proto::plan::GenericValue value; + auto* arr = value.mutable_array_val(); + arr->set_same_type(true); proto::plan::GenericValue int_val1; int_val1.set_int64_val(int64_t(1)); - arr.add_array()->CopyFrom(int_val1); + arr->add_array()->CopyFrom(int_val1); proto::plan::GenericValue int_val2; int_val2.set_int64_val(int64_t(2)); - arr.add_array()->CopyFrom(int_val2); + arr->add_array()->CopyFrom(int_val2); proto::plan::GenericValue int_val3; int_val3.set_int64_val(int64_t(3)); - arr.add_array()->CopyFrom(int_val3); - - std::vector array_cases = {{arr, {"array"}}}; + arr->add_array()->CopyFrom(int_val3); + std::vector array_cases = {{value, {"array"}}}; for (const auto& testcase : array_cases) { auto check = [&](OpType op) { if (testcase.nested_path[0] == "array" && op == OpType::Equal) { @@ -826,21 +876,22 @@ TEST(Expr, TestUnaryRangeJson) { return false; }; for (auto& op : ops) { - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - op, - testcase.val, - proto::plan::GenericValue::ValCase::kArrayVal); - auto final = visitor.call_child(*plan.predicate_.value()); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + op, + testcase.val); + BitsetType final; + auto plan = std::make_shared( + DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; auto ref = check(op); - ASSERT_EQ(ans, ref); + ASSERT_EQ(ans, ref) << "@" << i << "op" << op; } } } @@ -886,21 +937,28 @@ TEST(Expr, TestTermJson) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](int64_t value) { std::unordered_set term_set(testcase.term.begin(), testcase.term.end()); return term_set.find(value) != term_set.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - proto::plan::GenericValue::ValCase::kInt64Val); - auto final = visitor.call_child(*plan.predicate_.value()); + std::vector values; + for (const auto& val : testcase.term) { + proto::plan::GenericValue value; + value.set_int64_val(val); + values.push_back(value); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1016,8 +1074,7 @@ TEST(Expr, TestTerm) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -1025,7 +1082,9 @@ TEST(Expr, TestTerm) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1131,8 +1190,7 @@ TEST(Expr, TestCompare) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -1140,7 +1198,9 @@ TEST(Expr, TestCompare) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -1227,7 +1287,7 @@ TEST(Expr, TestCompareWithScalarIndex) { load_index_info.index = std::move(age64_index); seg->LoadIndex(load_index_info); - ExecExprVisitor visitor(*seg, seg->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % clause % @@ -1238,7 +1298,9 @@ TEST(Expr, TestCompareWithScalarIndex) { auto plan = CreateSearchPlanByExpr( *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), final); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -1294,116 +1356,113 @@ TEST(Expr, TestCompareExpr) { seg->LoadFieldData(FieldId(field_id), info); } - ExecExprVisitor visitor(*seg, seg->get_row_count(), MAX_TIMESTAMP); - auto build_expr = [&](enum DataType type) -> std::shared_ptr { + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + auto build_expr = [&](enum DataType type) -> expr::TypedExprPtr { switch (type) { case DataType::BOOL: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::BOOL; - compare_expr->left_field_id_ = bool_fid; - - compare_expr->right_data_type_ = DataType::BOOL; - compare_expr->right_field_id_ = bool_1_fid; + auto compare_expr = std::make_shared( + bool_fid, + bool_1_fid, + DataType::BOOL, + DataType::BOOL, + proto::plan::OpType::LessThan); return compare_expr; } case DataType::INT8: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::INT8; - compare_expr->left_field_id_ = int8_fid; - - compare_expr->right_data_type_ = DataType::INT8; - compare_expr->right_field_id_ = int8_1_fid; + auto compare_expr = + std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); return compare_expr; } case DataType::INT16: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::INT16; - compare_expr->left_field_id_ = int16_fid; - - compare_expr->right_data_type_ = DataType::INT16; - compare_expr->right_field_id_ = int16_1_fid; + auto compare_expr = + std::make_shared(int16_fid, + int16_1_fid, + DataType::INT16, + DataType::INT16, + OpType::LessThan); return compare_expr; } case DataType::INT32: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::INT32; - compare_expr->left_field_id_ = int32_fid; - - compare_expr->right_data_type_ = DataType::INT32; - compare_expr->right_field_id_ = int32_1_fid; + auto compare_expr = + std::make_shared(int32_fid, + int32_1_fid, + DataType::INT32, + DataType::INT32, + OpType::LessThan); return compare_expr; } case DataType::INT64: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::INT64; - compare_expr->left_field_id_ = int64_fid; - - compare_expr->right_data_type_ = DataType::INT64; - compare_expr->right_field_id_ = int64_1_fid; + auto compare_expr = + std::make_shared(int64_fid, + int64_1_fid, + DataType::INT64, + DataType::INT64, + OpType::LessThan); return compare_expr; } case DataType::FLOAT: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::FLOAT; - compare_expr->left_field_id_ = float_fid; - - compare_expr->right_data_type_ = DataType::FLOAT; - compare_expr->right_field_id_ = float_1_fid; + auto compare_expr = + std::make_shared(float_fid, + float_1_fid, + DataType::FLOAT, + DataType::FLOAT, + OpType::LessThan); return compare_expr; } case DataType::DOUBLE: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::DOUBLE; - compare_expr->left_field_id_ = double_fid; - - compare_expr->right_data_type_ = DataType::DOUBLE; - compare_expr->right_field_id_ = double_1_fid; + auto compare_expr = + std::make_shared(double_fid, + double_1_fid, + DataType::DOUBLE, + DataType::DOUBLE, + OpType::LessThan); return compare_expr; } case DataType::VARCHAR: { - auto compare_expr = std::make_shared(); - compare_expr->op_type_ = OpType::LessThan; - - compare_expr->left_data_type_ = DataType::VARCHAR; - compare_expr->left_field_id_ = str2_fid; - - compare_expr->right_data_type_ = DataType::VARCHAR; - compare_expr->right_field_id_ = str3_fid; + auto compare_expr = + std::make_shared(str2_fid, + str3_fid, + DataType::VARCHAR, + DataType::VARCHAR, + OpType::LessThan); return compare_expr; } default: - return std::make_shared(); + return std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); } }; std::cout << "start compare test" << std::endl; auto expr = build_expr(DataType::BOOL); - auto final = visitor.call_child(*expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); expr = build_expr(DataType::INT8); - final = visitor.call_child(*expr); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); expr = build_expr(DataType::INT16); - final = visitor.call_child(*expr); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); expr = build_expr(DataType::INT32); - final = visitor.call_child(*expr); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); expr = build_expr(DataType::INT64); - final = visitor.call_child(*expr); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); expr = build_expr(DataType::FLOAT); - final = visitor.call_child(*expr); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); expr = build_expr(DataType::DOUBLE); - final = visitor.call_child(*expr); + plan = std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg.get(), final); std::cout << "end compare test" << std::endl; } @@ -1671,6 +1730,789 @@ TEST(Expr, TestExprs) { // test_case(500); } +TEST(Expr, TestSealedSegmentGetBatchSize) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + proto::plan::GenericValue val; + val.set_int64_val(10); + auto expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto plan_node = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + + std::vector test_batch_size = { + 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; + for (const auto& batch_size : test_batch_size) { + EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; + auto plan = plan::PlanFragment(plan_node); + auto query_context = std::make_shared( + "query id", seg.get(), MAX_TIMESTAMP); + + auto task = + milvus::exec::Task::Create("task_expr", plan, 0, query_context); + auto last_num = N % batch_size; + auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; + int iter = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + auto childrens = result->childrens(); + if (++iter != iter_num) { + EXPECT_EQ(childrens[0]->size(), batch_size); + } else { + EXPECT_EQ(childrens[0]->size(), last_num); + } + } + } +} + +TEST(Expr, TestGrowingSegmentGetBatchSize) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000000; + auto raw_data = DataGen(schema, N); + seg->PreInsert(N); + seg->Insert(0, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + + proto::plan::GenericValue val; + val.set_int64_val(10); + auto expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto plan_node = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + + std::vector test_batch_size = { + 8192, 10240, 20480, 30720, 40960, 102400, 204800, 307200}; + + for (const auto& batch_size : test_batch_size) { + EXEC_EVAL_EXPR_BATCH_SIZE = batch_size; + auto plan = plan::PlanFragment(plan_node); + auto query_context = std::make_shared( + "query id", seg.get(), MAX_TIMESTAMP); + + auto task = + milvus::exec::Task::Create("task_expr", plan, 0, query_context); + auto last_num = N % batch_size; + auto iter_num = last_num == 0 ? N / batch_size : N / batch_size + 1; + int iter = 0; + for (;;) { + auto result = task->Next(); + if (!result) { + break; + } + auto childrens = result->childrens(); + if (++iter != iter_num) { + EXPECT_EQ(childrens[0]->size(), batch_size); + } else { + EXPECT_EQ(childrens[0]->size(), last_num); + } + } + } +} + +TEST(Expr, TestUnaryBenchTest) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(10); + } else { + val.set_int64_val(10); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::GreaterThan, + val); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; + } +} + +TEST(Expr, TestBinaryRangeBenchTest) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue lower; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + lower.set_float_val(10); + } else { + lower.set_int64_val(10); + } + proto::plan::GenericValue upper; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + upper.set_float_val(45); + } else { + upper.set_int64_val(45); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + lower, + upper, + true, + true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10.0 << "us" << std::endl; + } +} + +TEST(Expr, TestLogicalUnaryBenchTest) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(10); + } else { + val.set_int64_val(10); + } + auto child_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::GreaterThan, + val); + auto expr = std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST(Expr, TestBinaryLogicalBenchTest) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(-1000000); + } else { + val.set_int64_val(-1000000); + } + proto::plan::GenericValue val1; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val1.set_float_val(-100); + } else { + val1.set_int64_val(-100); + } + auto child1_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::LessThan, + val); + auto child2_expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::NotEqual, + val1); + auto expr = std::make_shared( + expr::LogicalBinaryExpr::OpType::And, child1_expr, child2_expr); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST(Expr, TestBinaryArithOpEvalRangeBenchExpr) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector> test_cases = { + {int8_fid, DataType::INT8}, + {int16_fid, DataType::INT16}, + {int32_fid, DataType::INT32}, + {int64_fid, DataType::INT64}, + {float_fid, DataType::FLOAT}, + {double_fid, DataType::DOUBLE}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.second) << std::endl; + proto::plan::GenericValue val; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + val.set_float_val(100); + } else { + val.set_int64_val(100); + } + proto::plan::GenericValue right; + if (pair.second == DataType::FLOAT || pair.second == DataType::DOUBLE) { + right.set_float_val(10); + } else { + right.set_int64_val(10); + } + auto expr = std::make_shared( + expr::ColumnInfo(pair.first, pair.second), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val, + right); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 50; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 50.0 << "us" << std::endl; + } +} + +TEST(Expr, TestCompareExprBenchTest) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto float_1_fid = schema->AddDebugField("float1", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + auto double_1_fid = schema->AddDebugField("double1", DataType::DOUBLE); + + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + + std::vector< + std::pair, std::pair>> + test_cases = { + {{int8_fid, DataType::INT8}, {int8_1_fid, DataType::INT8}}, + {{int16_fid, DataType::INT16}, {int16_fid, DataType::INT16}}, + {{int32_fid, DataType::INT32}, {int32_1_fid, DataType::INT32}}, + {{int64_fid, DataType::INT64}, {int64_1_fid, DataType::INT64}}, + {{float_fid, DataType::FLOAT}, {float_1_fid, DataType::FLOAT}}, + {{double_fid, DataType::DOUBLE}, {double_1_fid, DataType::DOUBLE}}}; + + for (const auto& pair : test_cases) { + std::cout << "start test type:" << int(pair.first.second) << std::endl; + proto::plan::GenericValue lower; + auto expr = std::make_shared(pair.first.first, + pair.second.first, + pair.first.second, + pair.second.second, + OpType::LessThan); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + int64_t all_cost = 0; + for (int i = 0; i < 10; i++) { + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + all_cost += std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + } + std::cout << " cost: " << all_cost / 10 << "us" << std::endl; + } +} + +TEST(Expr, TestRefactorExprs) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto int8_fid = schema->AddDebugField("int8", DataType::INT8); + auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8); + auto int16_fid = schema->AddDebugField("int16", DataType::INT16); + auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16); + auto int32_fid = schema->AddDebugField("int32", DataType::INT32); + auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32); + auto int64_fid = schema->AddDebugField("int64", DataType::INT64); + auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64); + auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR); + auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR); + auto float_fid = schema->AddDebugField("float", DataType::FLOAT); + auto double_fid = schema->AddDebugField("double", DataType::DOUBLE); + schema->set_primary_field_id(str1_fid); + + auto seg = CreateSealedSegment(schema); + int N = 1000000; + auto raw_data = DataGen(schema, N); + + // load field data + auto fields = schema->get_fields(); + for (auto field_data : raw_data.raw_->fields_data()) { + int64_t field_id = field_data.field_id(); + + auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a"); + auto field_meta = fields.at(FieldId(field_id)); + info.channel->push( + CreateFieldDataFromDataArray(N, &field_data, field_meta)); + info.channel->close(); + + seg->LoadFieldData(FieldId(field_id), info); + } + + enum ExprType { + UnaryRangeExpr = 0, + TermExprImpl = 1, + CompareExpr = 2, + LogicalUnaryExpr = 3, + BinaryRangeExpr = 4, + LogicalBinaryExpr = 5, + BinaryArithOpEvalRangeExpr = 6, + }; + + auto build_expr = [&](enum ExprType test_type, + int n) -> expr::TypedExprPtr { + switch (test_type) { + case UnaryRangeExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + proto::plan::OpType::GreaterThan, + val); + } + case TermExprImpl: { + std::vector retrieve_ints; + // for (int i = 0; i < n; ++i) { + // retrieve_ints.push_back("xxxxxx" + std::to_string(i % 10)); + // } + // return std::make_shared>( + // ColumnInfo(str1_fid, DataType::VARCHAR), + // retrieve_ints, + // proto::plan::GenericValue::ValCase::kStringVal); + for (int i = 0; i < n; ++i) { + proto::plan::GenericValue val; + val.set_float_val(i); + retrieve_ints.push_back(val); + } + return std::make_shared( + expr::ColumnInfo(double_fid, DataType::DOUBLE), + retrieve_ints); + } + case CompareExpr: { + auto compare_expr = + std::make_shared(int8_fid, + int8_1_fid, + DataType::INT8, + DataType::INT8, + OpType::LessThan); + return compare_expr; + } + case BinaryRangeExpr: { + proto::plan::GenericValue lower; + lower.set_int64_val(10); + proto::plan::GenericValue upper; + upper.set_int64_val(45); + return std::make_shared( + expr::ColumnInfo(int64_fid, DataType::INT64), + lower, + upper, + true, + true); + } + case LogicalUnaryExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + auto child_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + return std::make_shared( + expr::LogicalUnaryExpr::OpType::LogicalNot, child_expr); + } + case LogicalBinaryExpr: { + proto::plan::GenericValue val; + val.set_int64_val(10); + auto child1_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + auto child2_expr = std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::NotEqual, + val); + ; + return std::make_shared( + expr::LogicalBinaryExpr::OpType::And, + child1_expr, + child2_expr); + } + case BinaryArithOpEvalRangeExpr: { + proto::plan::GenericValue val; + val.set_int64_val(100); + proto::plan::GenericValue right; + right.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::Equal, + proto::plan::ArithOpType::Add, + val, + right); + } + default: { + proto::plan::GenericValue val; + val.set_int64_val(10); + return std::make_shared( + expr::ColumnInfo(int8_fid, DataType::INT8), + proto::plan::OpType::GreaterThan, + val); + } + } + }; + auto test_case = [&](int n) { + auto expr = build_expr(UnaryRangeExpr, n); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + std::cout << "start test" << std::endl; + auto start = std::chrono::steady_clock::now(); + visitor.ExecuteExprNode(plan, seg.get(), final); + std::cout << n << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; + }; + test_case(3); + test_case(10); + test_case(20); + test_case(30); + test_case(50); + test_case(100); + test_case(200); + // test_case(500); +} + TEST(Expr, TestCompareWithScalarIndexMaris) { using namespace milvus; using namespace milvus::query; @@ -1748,7 +2590,7 @@ TEST(Expr, TestCompareWithScalarIndexMaris) { load_index_info.index = std::move(str2_index); seg->LoadIndex(load_index_info); - ExecExprVisitor visitor(*seg, seg->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto dsl_string = boost::format(serialized_expr_plan) % vec_fid.get() % clause % str1_fid.get() % str2_fid.get(); @@ -1757,7 +2599,9 @@ TEST(Expr, TestCompareWithScalarIndexMaris) { auto plan = CreateSearchPlanByExpr( *schema, binary_plan.data(), binary_plan.size()); // std::cout << ShowPlanNodeVisitor().call_child(*plan->plan_node_) << std::endl; - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg.get(), final); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -2081,8 +2925,7 @@ TEST(Expr, TestBinaryArithOpEvalRange) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto [clause, ref_func, dtype] : testcases) { auto loc = raw_plan_tmp.find("@@@@@"); auto raw_plan = raw_plan_tmp; @@ -2107,7 +2950,9 @@ TEST(Expr, TestBinaryArithOpEvalRange) { auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2115,7 +2960,8 @@ TEST(Expr, TestBinaryArithOpEvalRange) { if (dtype == DataType::INT8) { auto val = age8_col[i]; auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; + ASSERT_EQ(ans, ref) + << clause << "@" << i << "!!" << val << std::endl; } else if (dtype == DataType::INT16) { auto val = age16_col[i]; auto ref = ref_func(val); @@ -2189,8 +3035,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](int64_t value) { if (testcase.op == OpType::Equal) { @@ -2198,17 +3043,22 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { } return value + testcase.right_operand != testcase.value; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - proto::plan::GenericValue::ValCase::kInt64Val, - ArithOpType::Add, - testcase.right_operand, - testcase.op, - testcase.value); - auto final = visitor.call_child(*plan.predicate_.value()); + proto::plan::GenericValue value; + value.set_int64_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_int64_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::Add, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2225,7 +3075,8 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { .template at(pointer) .value(); auto ref = check(val); - ASSERT_EQ(ans, ref) << testcase.value << " " << val; + ASSERT_EQ(ans, ref) << testcase.value << " " << val << " " + << testcase.op << " " << i; } } } @@ -2277,8 +3128,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); for (auto testcase : testcases) { auto check = [&](double value) { if (testcase.op == OpType::Equal) { @@ -2286,17 +3136,22 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { } return value + testcase.right_operand != testcase.value; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - proto::plan::GenericValue::ValCase::kFloatVal, - ArithOpType::Add, - testcase.right_operand, - testcase.op, - testcase.value); - auto final = visitor.call_child(*plan.predicate_.value()); + proto::plan::GenericValue value; + value.set_float_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_float_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::Add, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2306,7 +3161,8 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { .template at(pointer) .value(); auto ref = check(val); - ASSERT_EQ(ans, ref) << testcase.value << " " << val; + ASSERT_EQ(ans, ref) + << testcase.value << " " << val << " " << testcase.op; } } @@ -2322,17 +3178,22 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { } return value != testcase.value; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - proto::plan::GenericValue::ValCase::kInt64Val, - ArithOpType::ArrayLength, - testcase.right_operand, - testcase.op, - testcase.value); - auto final = visitor.call_child(*plan.predicate_.value()); + proto::plan::GenericValue value; + value.set_int64_val(testcase.value); + proto::plan::GenericValue right_operand; + right_operand.set_int64_val(testcase.right_operand); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + testcase.op, + ArithOpType::ArrayLength, + value, + right_operand); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + visitor.ExecuteExprNode(plan, seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2596,8 +3457,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) { seg->LoadIndex(load_index_info); auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -2633,7 +3493,9 @@ TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) { auto plan = CreateSearchPlanByExpr( *schema, binary_plan.data(), binary_plan.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N); for (int i = 0; i < N; ++i) { @@ -2787,8 +3649,7 @@ TEST(Expr, TestUnaryRangeWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -2833,7 +3694,9 @@ TEST(Expr, TestUnaryRangeWithJSON) { auto plan = CreateSearchPlanByExpr( *schema, unary_plan.data(), unary_plan.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -2965,8 +3828,7 @@ TEST(Expr, TestTermWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -3011,7 +3873,9 @@ TEST(Expr, TestTermWithJSON) { auto plan = CreateSearchPlanByExpr( *schema, unary_plan.data(), unary_plan.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -3110,8 +3974,7 @@ TEST(Expr, TestExistsWithJSON) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); int offset = 0; for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); @@ -3163,7 +4026,9 @@ TEST(Expr, TestExistsWithJSON) { auto plan = CreateSearchPlanByExpr( *schema, unary_plan.data(), unary_plan.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -3236,8 +4101,7 @@ TEST(Expr, TestTermInFieldJson) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true}, {"bool"}}, {{false}, {"bool"}}}; @@ -3247,15 +4111,23 @@ TEST(Expr, TestTermInFieldJson) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - proto::plan::GenericValue::ValCase::kBoolVal, + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); // std::cout << "cost" // << std::chrono::duration_cast( // std::chrono::steady_clock::now() - start) @@ -3287,15 +4159,23 @@ TEST(Expr, TestTermInFieldJson) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - proto::plan::GenericValue::ValCase::kFloatVal, + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3327,15 +4207,23 @@ TEST(Expr, TestTermInFieldJson) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - proto::plan::GenericValue::ValCase::kInt64Val, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3367,15 +4255,23 @@ TEST(Expr, TestTermInFieldJson) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - proto::plan::GenericValue::ValCase::kStringVal, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + values, true); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3572,8 +4468,7 @@ TEST(Expr, TestJsonContainsAny) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true}, {"bool"}}, {{false}, {"bool"}}}; @@ -3583,16 +4478,24 @@ TEST(Expr, TestJsonContainsAny) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kBoolVal); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3624,16 +4527,24 @@ TEST(Expr, TestJsonContainsAny) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kFloatVal); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3665,16 +4576,24 @@ TEST(Expr, TestJsonContainsAny) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kInt64Val); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3706,16 +4625,24 @@ TEST(Expr, TestJsonContainsAny) { return std::find(values.begin(), values.end(), testcase.term[0]) != values.end(); }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kStringVal); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3765,8 +4692,7 @@ TEST(Expr, TestJsonContainsAll) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); std::vector> bool_testcases{{{true, true}, {"bool"}}, {{false, false}, {"bool"}}}; @@ -3781,16 +4707,24 @@ TEST(Expr, TestJsonContainsAll) { } return true; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto v : testcase.term) { + proto::plan::GenericValue val; + val.set_bool_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kBoolVal); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3829,16 +4763,24 @@ TEST(Expr, TestJsonContainsAll) { } return true; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_float_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kFloatVal); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3877,16 +4819,24 @@ TEST(Expr, TestJsonContainsAll) { } return true; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_int64_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kInt64Val); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3923,16 +4873,24 @@ TEST(Expr, TestJsonContainsAll) { } return true; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, + std::vector values; + for (auto& v : testcase.term) { + proto::plan::GenericValue val; + val.set_string_val(v); + values.push_back(val); + } + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kStringVal); + true, + values); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -3982,46 +4940,47 @@ TEST(Expr, TestJsonContainsArray) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - proto::plan::Array a; - a.set_same_type(false); + proto::plan::GenericValue generic_a; + auto* a = generic_a.mutable_array_val(); + a->set_same_type(false); for (int i = 0; i < 4; ++i) { if (i % 4 == 0) { proto::plan::GenericValue int_val; int_val.set_int64_val(int64_t(i)); - a.add_array()->CopyFrom(int_val); + a->add_array()->CopyFrom(int_val); } else if ((i - 1) % 4 == 0) { proto::plan::GenericValue bool_val; bool_val.set_bool_val(bool(i)); - a.add_array()->CopyFrom(bool_val); + a->add_array()->CopyFrom(bool_val); } else if ((i - 2) % 4 == 0) { proto::plan::GenericValue float_val; float_val.set_float_val(double(i)); - a.add_array()->CopyFrom(float_val); + a->add_array()->CopyFrom(float_val); } else if ((i - 3) % 4 == 0) { proto::plan::GenericValue string_val; string_val.set_string_val(std::to_string(i)); - a.add_array()->CopyFrom(string_val); + a->add_array()->CopyFrom(string_val); } } - proto::plan::Array b; - b.set_same_type(true); + proto::plan::GenericValue generic_b; + auto* b = generic_b.mutable_array_val(); + b->set_same_type(true); proto::plan::GenericValue int_val1; int_val1.set_int64_val(int64_t(1)); - b.add_array()->CopyFrom(int_val1); + b->add_array()->CopyFrom(int_val1); proto::plan::GenericValue int_val2; int_val2.set_int64_val(int64_t(2)); - b.add_array()->CopyFrom(int_val2); + b->add_array()->CopyFrom(int_val2); proto::plan::GenericValue int_val3; int_val3.set_int64_val(int64_t(3)); - b.add_array()->CopyFrom(int_val3); + b->add_array()->CopyFrom(int_val3); - std::vector> diff_testcases{{{a}, {"string"}}, - {{b}, {"array"}}}; + std::vector> diff_testcases{ + {{generic_a}, {"string"}}, {{generic_b}, {"array"}}}; for (auto& testcase : diff_testcases) { auto check = [&](const std::vector& values, int i) { @@ -4030,17 +4989,18 @@ TEST(Expr, TestJsonContainsArray) { } return false; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + testcase.term); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4062,17 +5022,18 @@ TEST(Expr, TestJsonContainsArray) { } return false; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + testcase.term); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4087,41 +5048,44 @@ TEST(Expr, TestJsonContainsArray) { } } - proto::plan::Array sub_arr1; - sub_arr1.set_same_type(true); + proto::plan::GenericValue g_sub_arr1; + auto* sub_arr1 = g_sub_arr1.mutable_array_val(); + sub_arr1->set_same_type(true); proto::plan::GenericValue int_val11; int_val11.set_int64_val(int64_t(1)); - sub_arr1.add_array()->CopyFrom(int_val11); + sub_arr1->add_array()->CopyFrom(int_val11); proto::plan::GenericValue int_val12; int_val12.set_int64_val(int64_t(2)); - sub_arr1.add_array()->CopyFrom(int_val12); + sub_arr1->add_array()->CopyFrom(int_val12); - proto::plan::Array sub_arr2; - sub_arr2.set_same_type(true); + proto::plan::GenericValue g_sub_arr2; + auto* sub_arr2 = g_sub_arr2.mutable_array_val(); + sub_arr2->set_same_type(true); proto::plan::GenericValue int_val21; int_val21.set_int64_val(int64_t(3)); - sub_arr2.add_array()->CopyFrom(int_val21); + sub_arr2->add_array()->CopyFrom(int_val21); proto::plan::GenericValue int_val22; int_val22.set_int64_val(int64_t(4)); - sub_arr2.add_array()->CopyFrom(int_val22); - std::vector> diff_testcases2{ - {{sub_arr1, sub_arr2}, {"array2"}}}; + sub_arr2->add_array()->CopyFrom(int_val22); + std::vector> diff_testcases2{ + {{g_sub_arr1, g_sub_arr2}, {"array2"}}}; for (auto& testcase : diff_testcases2) { auto check = [&]() { return true; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + testcase.term); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4139,17 +5103,18 @@ TEST(Expr, TestJsonContainsArray) { auto check = [&](const std::vector& values, int i) { return true; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + testcase.term); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4164,43 +5129,46 @@ TEST(Expr, TestJsonContainsArray) { } } - proto::plan::Array sub_arr3; - sub_arr3.set_same_type(true); + proto::plan::GenericValue g_sub_arr3; + auto* sub_arr3 = g_sub_arr3.mutable_array_val(); + sub_arr3->set_same_type(true); proto::plan::GenericValue int_val31; int_val31.set_int64_val(int64_t(5)); - sub_arr3.add_array()->CopyFrom(int_val31); + sub_arr3->add_array()->CopyFrom(int_val31); proto::plan::GenericValue int_val32; int_val32.set_int64_val(int64_t(6)); - sub_arr3.add_array()->CopyFrom(int_val32); + sub_arr3->add_array()->CopyFrom(int_val32); - proto::plan::Array sub_arr4; - sub_arr4.set_same_type(true); + proto::plan::GenericValue g_sub_arr4; + auto* sub_arr4 = g_sub_arr4.mutable_array_val(); + sub_arr4->set_same_type(true); proto::plan::GenericValue int_val41; int_val41.set_int64_val(int64_t(7)); - sub_arr4.add_array()->CopyFrom(int_val41); + sub_arr4->add_array()->CopyFrom(int_val41); proto::plan::GenericValue int_val42; int_val42.set_int64_val(int64_t(8)); - sub_arr4.add_array()->CopyFrom(int_val42); - std::vector> diff_testcases3{ - {{sub_arr3, sub_arr4}, {"array2"}}}; + sub_arr4->add_array()->CopyFrom(int_val42); + std::vector> diff_testcases3{ + {{g_sub_arr3, g_sub_arr4}, {"array2"}}}; for (auto& testcase : diff_testcases3) { auto check = [&](const std::vector& values, int i) { return false; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + true, + testcase.term); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4219,17 +5187,18 @@ TEST(Expr, TestJsonContainsArray) { auto check = [&](const std::vector& values, int i) { return false; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - true, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + true, + testcase.term); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + BitsetType final; auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4304,8 +5273,7 @@ TEST(Expr, TestJsonContainsDiffTypeArray) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); proto::plan::GenericValue int_value; int_value.set_int64_val(1); @@ -4329,17 +5297,18 @@ TEST(Expr, TestJsonContainsDiffTypeArray) { for (auto& testcase : diff_testcases) { auto check = [&]() { return testcase.res; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - false, - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4355,17 +5324,18 @@ TEST(Expr, TestJsonContainsDiffTypeArray) { for (auto& testcase : diff_testcases) { auto check = [&]() { return false; }; - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - false, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::kArrayVal); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4409,8 +5379,7 @@ TEST(Expr, TestJsonContainsDiffType) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); proto::plan::GenericValue int_val; int_val.set_int64_val(int64_t(3)); @@ -4440,17 +5409,18 @@ TEST(Expr, TestJsonContainsDiffType) { }; for (auto& testcase : diff_testcases) { - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - false, - proto::plan::JSONContainsExpr_JSONOp_ContainsAny, - proto::plan::GenericValue::ValCase::VAL_NOT_SET); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) @@ -4465,17 +5435,18 @@ TEST(Expr, TestJsonContainsDiffType) { } for (auto& testcase : diff_testcases) { - RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); - plan.predicate_ = - std::make_unique>( - ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), - testcase.term, - false, - proto::plan::JSONContainsExpr_JSONOp_ContainsAll, - proto::plan::GenericValue::ValCase::VAL_NOT_SET); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + json_fid, DataType::JSON, testcase.nested_path), + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + false, + testcase.term); + BitsetType final; + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); auto start = std::chrono::steady_clock::now(); - auto final = visitor.call_child(*plan.predicate_.value()); + visitor.ExecuteExprNode(plan, seg_promote, final); std::cout << "cost" << std::chrono::duration_cast( std::chrono::steady_clock::now() - start) diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index 4069b8f376..f41f6f9ceb 100644 --- a/internal/core/unittest/test_float16.cpp +++ b/internal/core/unittest/test_float16.cpp @@ -309,17 +309,21 @@ TEST(Float16, RetrieveEmpty) { auto segment = CreateSealedSegment(schema); auto plan = std::make_unique(*schema); - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(choose(i)); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(choose(i)); + values.push_back(val); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; diff --git a/internal/core/unittest/test_integer_overflow.cpp b/internal/core/unittest/test_integer_overflow.cpp index 0ab984efd9..7404af7b72 100644 --- a/internal/core/unittest/test_integer_overflow.cpp +++ b/internal/core/unittest/test_integer_overflow.cpp @@ -615,8 +615,6 @@ binary_arith_op_eval_range_expr: < } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (auto [clause, ref_func] : testcases) { auto loc = raw_plan_tmp.find("@@@@"); auto raw_plan = raw_plan_tmp; @@ -624,7 +622,10 @@ binary_arith_op_eval_range_expr: < auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/internal/core/unittest/test_retrieve.cpp b/internal/core/unittest/test_retrieve.cpp index bac9e76b4c..2115b6086a 100644 --- a/internal/core/unittest/test_retrieve.cpp +++ b/internal/core/unittest/test_retrieve.cpp @@ -17,6 +17,8 @@ #include "query/ExprImpl.h" #include "segcore/ScalarIndex.h" #include "test_utils/DataGen.h" +#include "exec/expression/Expr.h" +#include "plan/PlanNode.h" using namespace milvus; using namespace milvus::segcore; @@ -72,17 +74,21 @@ TEST(Retrieve, AutoID) { auto i64_col = dataset.get_col(fid_64); auto plan = std::make_unique(*schema); - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(i64_col[choose(i)]); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i64_col[choose(i)]); + values.push_back(val); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_fields_id{fid_64, fid_vec}; plan->field_ids_ = target_fields_id; @@ -128,17 +134,21 @@ TEST(Retrieve, AutoID2) { auto i64_col = dataset.get_col(fid_64); auto plan = std::make_unique(*schema); - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(i64_col[choose(i)]); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i64_col[choose(i)]); + values.push_back(val); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -180,19 +190,25 @@ TEST(Retrieve, NotExist) { auto i64_col = dataset.get_col(fid_64); auto plan = std::make_unique(*schema); - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(i64_col[choose(i)]); - values.emplace_back(choose2(i)); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val1; + val1.set_int64_val(i64_col[choose(i)]); + values.push_back(val1); + proto::plan::GenericValue val2; + val2.set_int64_val(choose2(i)); + values.push_back(val2); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -230,17 +246,21 @@ TEST(Retrieve, Empty) { auto segment = CreateSealedSegment(schema); auto plan = std::make_unique(*schema); - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(choose(i)); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(choose(i)); + values.push_back(val); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -270,14 +290,16 @@ TEST(Retrieve, Limit) { SealedLoadFieldData(dataset, *segment); auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + proto::plan::GenericValue unary_val; + unary_val.set_int64_val(0); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), OpType::GreaterEqual, - 0, - proto::plan::GenericValue::kInt64Val); + unary_val); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, expr); // test query results exceed the limit size std::vector target_fields{TimestampFieldID, fid_64, fid_vec}; @@ -310,16 +332,17 @@ TEST(Retrieve, FillEntry) { auto dataset = DataGen(schema, N, 42); auto segment = CreateSealedSegment(schema); SealedLoadFieldData(dataset, *segment); - auto plan = std::make_unique(*schema); - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + proto::plan::GenericValue unary_val; + unary_val.set_int64_val(0); + auto expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), OpType::GreaterEqual, - 0, - proto::plan::GenericValue::kInt64Val); + unary_val); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, expr); // test query results exceed the limit size std::vector target_fields{TimestampFieldID, @@ -356,17 +379,22 @@ TEST(Retrieve, LargeTimestamp) { auto i64_col = dataset.get_col(fid_64); auto plan = std::make_unique(*schema); - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(i64_col[choose(i)]); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i64_col[choose(i)]); + values.push_back(val); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); + ; plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_offsets{fid_64, fid_vec}; plan->field_ids_ = target_offsets; @@ -420,17 +448,21 @@ TEST(Retrieve, Delete) { for (int i = 0; i < req_size; ++i) { timestamps.emplace_back(ts_col[choose(i)]); } - std::vector values; - for (int i = 0; i < req_size; ++i) { - values.emplace_back(i64_col[choose(i)]); + std::vector values; + { + for (int i = 0; i < req_size; ++i) { + proto::plan::GenericValue val; + val.set_int64_val(i64_col[choose(i)]); + values.push_back(val); + } } - auto term_expr = std::make_unique>( - milvus::query::ColumnInfo( + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( fid_64, DataType::INT64, std::vector()), - values, - proto::plan::GenericValue::kInt64Val); + values); plan->plan_node_ = std::make_unique(); - plan->plan_node_->predicate_ = std::move(term_expr); + plan->plan_node_->filter_plannode_ = + std::make_shared(DEFAULT_PLANNODE_ID, term_expr); std::vector target_offsets{fid_ts, fid_64, fid_vec}; plan->field_ids_ = target_offsets; diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index eb8359a851..24c1c12cae 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -683,22 +683,19 @@ TEST(Sealed, LoadScalarIndex) { FieldMeta row_id_field_meta( FieldName("RowID"), RowFieldID, DataType::INT64); auto field_data = - std::make_shared>(DataType::INT64); + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.row_ids_.data(), N); auto field_data_info = FieldDataInfo{ - RowFieldID.get(), N, std::vector{field_data}}; + RowFieldID.get(), N, std::vector{field_data}}; segment->LoadFieldData(RowFieldID, field_data_info); LoadFieldDataInfo ts_info; FieldMeta ts_field_meta( FieldName("Timestamp"), TimestampFieldID, DataType::INT64); - field_data = - std::make_shared>(DataType::INT64); + field_data = std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.timestamps_.data(), N); - field_data_info = - FieldDataInfo{TimestampFieldID.get(), - N, - std::vector{field_data}}; + field_data_info = FieldDataInfo{ + TimestampFieldID.get(), N, std::vector{field_data}}; segment->LoadFieldData(TimestampFieldID, field_data_info); LoadIndexInfo vec_info; @@ -965,8 +962,8 @@ TEST(Sealed, BF) { auto vec_data = GenRandomFloatVecs(N, dim); auto field_data = storage::CreateFieldData(DataType::VECTOR_FLOAT, dim); field_data->FillFieldData(vec_data.data(), N); - auto field_data_info = FieldDataInfo{ - fake_id.get(), N, std::vector{field_data}}; + auto field_data_info = + FieldDataInfo{fake_id.get(), N, std::vector{field_data}}; segment->LoadFieldData(fake_id, field_data_info); auto topK = 1; @@ -1019,8 +1016,8 @@ TEST(Sealed, BF_Overflow) { auto vec_data = GenMaxFloatVecs(N, dim); auto field_data = storage::CreateFieldData(DataType::VECTOR_FLOAT, dim); field_data->FillFieldData(vec_data.data(), N); - auto field_data_info = FieldDataInfo{ - fake_id.get(), N, std::vector{field_data}}; + auto field_data_info = + FieldDataInfo{fake_id.get(), N, std::vector{field_data}}; segment->LoadFieldData(fake_id, field_data_info); auto topK = 1; @@ -1545,10 +1542,8 @@ TEST(Sealed, SkipIndexSkipStringRange) { std::vector strings = {"e", "f", "g", "g", "j"}; auto string_field_data = storage::CreateFieldData(DataType::VARCHAR, 1, N); string_field_data->FillFieldData(strings.data(), N); - auto string_field_data_info = - FieldDataInfo{string_fid.get(), - N, - std::vector{string_field_data}}; + auto string_field_data_info = FieldDataInfo{ + string_fid.get(), N, std::vector{string_field_data}}; segment->LoadFieldData(string_fid, string_field_data_info); auto& skip_index = segment->GetSkipIndex(); ASSERT_TRUE(skip_index.CanSkipUnaryRange( @@ -1707,4 +1702,4 @@ TEST(Sealed, QueryAllFields) { dataset_size); EXPECT_EQ(float_array_result->scalars().array_data().data_size(), dataset_size); -} \ No newline at end of file +} diff --git a/internal/core/unittest/test_simd.cpp b/internal/core/unittest/test_simd.cpp index b8a3606394..edfc410c23 100644 --- a/internal/core/unittest/test_simd.cpp +++ b/internal/core/unittest/test_simd.cpp @@ -20,15 +20,7 @@ #include #include -#if defined(__x86_64__) -#include "simd/hook.h" -#include "simd/sse2.h" -#include "simd/sse4.h" -#include "simd/avx2.h" -#include "simd/avx512.h" - using namespace std; -using namespace milvus::simd; template using FixedVector = boost::container::vector; @@ -39,6 +31,15 @@ using FixedVector = boost::container::vector; << ::testing::UnitTest::GetInstance()->current_test_info()->name() \ << std::endl; +#if defined(__x86_64__) +#include "simd/hook.h" +#include "simd/ref.h" +#include "simd/sse2.h" +#include "simd/sse4.h" +#include "simd/avx2.h" +#include "simd/avx512.h" + +using namespace milvus::simd; TEST(GetBitSetBlock, base_test_sse) { FixedVector src; for (int i = 0; i < 64; ++i) { @@ -750,6 +751,469 @@ TEST(FindTermAVX512, double_type) { ASSERT_EQ(res, true); } +TEST(AllBooleanSSE2, function) { + FixedVector src; + for (int i = 0; i < 8192; ++i) { + src.push_back(false); + } + auto res = AllFalseSSE2(src.data(), src.size()); + EXPECT_EQ(res, true); + res = AllTrueSSE2(src.data(), src.size()); + EXPECT_EQ(res, false); + + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + res = AllFalseSSE2(src.data(), src.size()); + EXPECT_EQ(res, false); + res = AllTrueSSE2(src.data(), src.size()); + EXPECT_EQ(res, false); + + src.clear(); + for (int i = 0; i < 8192; ++i) { + src.push_back(true); + } + res = AllTrueSSE2(src.data(), src.size()); + EXPECT_EQ(res, true); +} + +TEST(AllBooleanSSE2, performance) { + FixedVector src; + + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + std::cout << "sse2" << std::endl; + for (int j = 0; j < 10; j++) { + auto start = std::chrono::system_clock::now(); + auto res = AllFalseSSE2(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + start = std::chrono::system_clock::now(); + res = AllTrueSSE2(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + + std::cout << "avx2" << std::endl; + for (int j = 0; j < 10; j++) { + auto start = std::chrono::system_clock::now(); + auto res = AllFalseAVX2(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + start = std::chrono::system_clock::now(); + res = AllTrueAVX2(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + + for (int j = 0; j < 10; j++) { + auto start = std::chrono::system_clock::now(); + auto res = AllFalseRef(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + start = std::chrono::system_clock::now(); + res = AllTrueRef(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } +} + +TEST(InvertBool, function) { + FixedVector src; + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + InvertBoolSSE2(src.data(), src.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(src[i], (i % 2) != 0); + } + + src.clear(); + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 3 == 0 ? true : false); + } + InvertBoolSSE2(src.data(), src.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(src[i], (i % 3) != 0); + } +} + +TEST(InvertBool, performance) { + FixedVector src; + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + InvertBoolSSE2(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + InvertBoolRef(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } +} + +TEST(LogicalBool, function) { + FixedVector left; + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + FixedVector right; + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 2 == 0 ? true : false); + } + AndBoolSSE2(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], i % 2 == 0); + } + OrBoolSSE2(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], i % 2 == 0); + } + + left.clear(); + right.clear(); + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 5 == 0 ? true : false); + } + AndBoolSSE2(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], (i % 2 == 0) && (i % 5 == 0)); + } + + left.clear(); + right.clear(); + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 5 == 0 ? true : false); + } + OrBoolSSE2(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], (i % 2 == 0) || (i % 5 == 0)); + } + + left.clear(); + right.clear(); + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 5 == 0 ? true : false); + } + AndBoolAVX2(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], (i % 2 == 0) && (i % 5 == 0)); + } + + left.clear(); + right.clear(); + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 5 == 0 ? true : false); + } + OrBoolAVX2(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], (i % 2 == 0) || (i % 5 == 0)); + } +} + +TEST(LogicalBool, performance) { + FixedVector left; + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + FixedVector right; + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 2 == 0 ? true : false); + } + std::cout << "sse2" << std::endl; + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + AndBoolSSE2(left.data(), right.data(), left.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + std::cout << "avx2" << std::endl; + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + AndBoolAVX2(left.data(), right.data(), left.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + AndBoolRef(left.data(), right.data(), left.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } +} + +#endif + +#if defined(__ARM_NEON) +#include "simd/ref.h" +#include "simd/neon.h" +using namespace milvus::simd; + +#include +#include + +void +print_uint8x16(uint8x16_t vec) { + uint8_t tmp[16]; + vst1q_u8(tmp, vec); + + std::cout << "Vector contents: "; + for (int i = 0; i < 16; ++i) { + std::cout << static_cast(tmp[i]) << " "; + } + std::cout << std::endl; +} + +void +print_uint8x8(uint8x8_t vec) { + uint8_t tmp[8]; + vst1_u8(tmp, vec); + + std::cout << "Vector contents: "; + for (int i = 0; i < 8; ++i) { + std::cout << static_cast(tmp[i]) << " "; + } + std::cout << std::endl; +} + +void +print_uint16x8(uint16x8_t vec) { + uint16_t tmp[8]; + vst1q_u16(tmp, vec); + + std::cout << "Vector contents: "; + for (int i = 0; i < 8; ++i) { + std::cout << static_cast(tmp[i]) << " "; + } + std::cout << std::endl; +} + +TEST(InvertBool, function) { + FixedVector src; + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + InvertBoolNEON(src.data(), src.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(src[i], (i % 2) != 0); + } + + src.clear(); + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 3 == 0 ? true : false); + } + InvertBoolNEON(src.data(), src.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(src[i], (i % 3) != 0); + } +} + +TEST(InvertBool, performance) { + FixedVector src; + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + InvertBoolNEON(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + InvertBoolRef(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } +} + +TEST(LogicalBool, function) { + FixedVector left; + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + FixedVector right; + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 2 == 0 ? true : false); + } + AndBoolNEON(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], i % 2 == 0); + } + OrBoolNEON(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], i % 2 == 0); + } + + left.clear(); + right.clear(); + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 5 == 0 ? true : false); + } + AndBoolNEON(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], (i % 2 == 0) && (i % 5 == 0)); + } + + left.clear(); + right.clear(); + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 5 == 0 ? true : false); + } + OrBoolNEON(left.data(), right.data(), right.size()); + for (int i = 0; i < 8192; ++i) { + EXPECT_EQ(left[i], (i % 2 == 0) || (i % 5 == 0)); + } +} + +TEST(LogicalBool, performance) { + FixedVector left; + for (int i = 0; i < 8192; ++i) { + left.push_back(i % 2 == 0 ? true : false); + } + FixedVector right; + for (int i = 0; i < 8192; ++i) { + right.push_back(i % 2 == 0 ? true : false); + } + std::cout << "NEON" << std::endl; + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + AndBoolNEON(left.data(), right.data(), left.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + std::cout << "ref" << std::endl; + + for (int i = 0; i < 10; ++i) { + auto start = std::chrono::system_clock::now(); + AndBoolRef(left.data(), right.data(), left.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } +} + +TEST(AllBooleanNeon, function) { + FixedVector src; + for (int i = 0; i < 8192; ++i) { + src.push_back(false); + } + auto res = AllFalseNEON(src.data(), src.size()); + EXPECT_EQ(res, true); + res = AllTrueNEON(src.data(), src.size()); + EXPECT_EQ(res, false); + + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + res = AllFalseNEON(src.data(), src.size()); + EXPECT_EQ(res, false); + res = AllTrueNEON(src.data(), src.size()); + EXPECT_EQ(res, false); + + src.clear(); + for (int i = 0; i < 8192; ++i) { + src.push_back(true); + } + res = AllTrueNEON(src.data(), src.size()); + EXPECT_EQ(res, true); +} + +TEST(AllBooleanNeon, performance) { + FixedVector src; + + for (int i = 0; i < 8192; ++i) { + src.push_back(i % 2 == 0 ? true : false); + } + std::cout << "NEON" << std::endl; + for (int j = 0; j < 10; j++) { + auto start = std::chrono::system_clock::now(); + auto res = AllFalseNEON(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + start = std::chrono::system_clock::now(); + res = AllTrueNEON(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } + + std::cout << "ref" << std::endl; + for (int j = 0; j < 10; j++) { + auto start = std::chrono::system_clock::now(); + auto res = AllFalseRef(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + start = std::chrono::system_clock::now(); + res = AllTrueRef(src.data(), src.size()); + std::cout << std::chrono::duration_cast( + std::chrono::system_clock::now() - start) + .count() + << std::endl; + } +} + #endif int diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index 32aa555d5f..842afcae52 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -324,12 +324,13 @@ TEST(StringExpr, Term) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (const auto& [_, term] : terms) { auto plan_proto = GenTermPlan(fvec_meta, str_meta, term); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -437,12 +438,13 @@ TEST(StringExpr, Compare) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (const auto& [op, ref_func] : testcases) { auto plan_proto = gen_compare_plan(op); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -533,12 +535,13 @@ TEST(StringExpr, UnaryRange) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (const auto& [op, value, ref_func] : testcases) { auto plan_proto = gen_unary_range_plan(op, value); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { @@ -645,14 +648,15 @@ TEST(StringExpr, BinaryRange) { } auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor( - *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); for (const auto& [lb_inclusive, ub_inclusive, lb, ub, ref_func] : testcases) { auto plan_proto = gen_binary_range_plan(lb_inclusive, ub_inclusive, lb, ub); auto plan = ProtoParser(*schema).CreatePlan(*plan_proto); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode( + plan->plan_node_->filter_plannode_.value(), seg_promote, final); EXPECT_EQ(final.size(), N * num_iters); for (int i = 0; i < N * num_iters; ++i) { diff --git a/internal/core/unittest/test_utils.cpp b/internal/core/unittest/test_utils.cpp index 60a31a20e1..a92facde32 100644 --- a/internal/core/unittest/test_utils.cpp +++ b/internal/core/unittest/test_utils.cpp @@ -16,10 +16,11 @@ #include #include "common/EasyAssert.h" +#include "common/Types.h" #include "common/Utils.h" +#include "common/Exception.h" #include "query/Utils.h" #include "test_utils/DataGen.h" -#include "common/Types.h" TEST(Util, StringMatch) { using namespace milvus; @@ -144,13 +145,16 @@ struct TmpFileWrapper { std::string filename; TmpFileWrapper(const std::string& _filename) : filename{_filename} { - fd = open( - filename.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR); + fd = open(filename.c_str(), + O_RDWR | O_CREAT | O_EXCL, + S_IRUSR | S_IWUSR | S_IXUSR); } TmpFileWrapper(const TmpFileWrapper&) = delete; TmpFileWrapper(TmpFileWrapper&&) = delete; - TmpFileWrapper& operator =(const TmpFileWrapper&) = delete; - TmpFileWrapper& operator =(TmpFileWrapper&&) = delete; + TmpFileWrapper& + operator=(const TmpFileWrapper&) = delete; + TmpFileWrapper& + operator=(TmpFileWrapper&&) = delete; ~TmpFileWrapper() { if (fd != -1) { close(fd); @@ -181,8 +185,8 @@ TEST(Util, read_from_fd) { tmp_file.fd, read_buf.get(), data_size * max_loop)); // On Linux, read() (and similar system calls) will transfer at most 0x7ffff000 (2,147,479,552) bytes once - EXPECT_THROW(milvus::index::ReadDataFromFD( - tmp_file.fd, read_buf.get(), data_size * max_loop, INT_MAX), - milvus::SegcoreError); + EXPECT_THROW( + milvus::index::ReadDataFromFD( + tmp_file.fd, read_buf.get(), data_size * max_loop, INT_MAX), + milvus::SegcoreError); } - diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index c6c540e62b..a2b363a989 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -730,12 +730,12 @@ SearchResultToJson(const SearchResult& sr) { return nlohmann::json{results}; }; -inline storage::FieldDataPtr +inline FieldDataPtr CreateFieldDataFromDataArray(ssize_t raw_count, const DataArray* data, const FieldMeta& field_meta) { int64_t dim = 1; - storage::FieldDataPtr field_data = nullptr; + FieldDataPtr field_data = nullptr; auto createFieldData = [&field_data, &raw_count](const void* raw_data, DataType data_type, @@ -846,23 +846,23 @@ SealedLoadFieldData(const GeneratedData& dataset, bool with_mmap = false) { auto row_count = dataset.row_ids_.size(); { - auto field_data = std::make_shared>( - DataType::INT64); + auto field_data = + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.row_ids_.data(), row_count); - auto field_data_info = FieldDataInfo( - RowFieldID.get(), - row_count, - std::vector{field_data}); + auto field_data_info = + FieldDataInfo(RowFieldID.get(), + row_count, + std::vector{field_data}); seg.LoadFieldData(RowFieldID, field_data_info); } { - auto field_data = std::make_shared>( - DataType::INT64); + auto field_data = + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.timestamps_.data(), row_count); - auto field_data_info = FieldDataInfo( - TimestampFieldID.get(), - row_count, - std::vector{field_data}); + auto field_data_info = + FieldDataInfo(TimestampFieldID.get(), + row_count, + std::vector{field_data}); seg.LoadFieldData(TimestampFieldID, field_data_info); } for (auto& iter : dataset.schema_->get_fields()) { diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h index fd712ced5e..31e3b06d6b 100644 --- a/internal/core/unittest/test_utils/storage_test_utils.h +++ b/internal/core/unittest/test_utils/storage_test_utils.h @@ -25,11 +25,11 @@ #include "storage/ThreadPools.h" using milvus::DataType; +using milvus::storage::FieldDataMeta; +using milvus::FieldDataPtr; using milvus::FieldId; using milvus::segcore::GeneratedData; using milvus::storage::ChunkManagerPtr; -using milvus::storage::FieldDataMeta; -using milvus::storage::FieldDataPtr; using milvus::storage::InsertData; using milvus::storage::StorageConfig; @@ -75,15 +75,15 @@ PrepareInsertBinlog(int64_t collection_id, }; { - auto field_data = std::make_shared>( - DataType::INT64); + auto field_data = + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.row_ids_.data(), row_count); auto path = prefix + "/" + std::to_string(RowFieldID.get()); SaveFieldData(field_data, path, RowFieldID.get()); } { - auto field_data = std::make_shared>( - DataType::INT64); + auto field_data = + std::make_shared>(DataType::INT64); field_data->FillFieldData(dataset.timestamps_.data(), row_count); auto path = prefix + "/" + std::to_string(TimestampFieldID.get()); SaveFieldData(field_data, path, TimestampFieldID.get()); diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 76f705dbe3..fb12bdfdf9 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -221,6 +221,9 @@ func (node *QueryNode) InitSegcore() error { cCPUNum := C.int(hardware.GetCPUNum()) C.InitCpuNum(cCPUNum) + cExprBatchSize := C.int64_t(paramtable.Get().QueryNodeCfg.ExprEvalBatchSize.GetAsInt64()) + C.InitDefaultExprEvalBatchSize(cExprBatchSize) + localDataRootPath := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole) initcore.InitLocalChunkManager(localDataRootPath) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index da4b7f5999..bb31b340db 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1717,6 +1717,8 @@ type queryNodeConfig struct { CGOPoolSizeRatio ParamItem `refreshable:"true"` EnableWorkerSQCostMetrics ParamItem `refreshable:"true"` + + ExprEvalBatchSize ParamItem `refreshable:"false"` } func (p *queryNodeConfig) init(base *BaseTable) { @@ -2104,6 +2106,15 @@ Max read concurrency must greater than or equal to 1, and less than or equal to Doc: "whether use worker's cost to measure delegator's workload", } p.EnableWorkerSQCostMetrics.Init(base.mgr) + + p.ExprEvalBatchSize = ParamItem{ + Key: "queryNode.segcore.exprEvalBatchSize", + Version: "2.3.4", + DefaultValue: "8192", + Doc: "expr eval batch size for getnext interface", + } + + p.ExprEvalBatchSize.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// diff --git a/scripts/core_build.sh b/scripts/core_build.sh index b65816b65d..73ac74bf5e 100755 --- a/scripts/core_build.sh +++ b/scripts/core_build.sh @@ -98,7 +98,7 @@ CUDA_ARCH="DEFAULT" EMBEDDED_MILVUS="OFF" BUILD_DISK_ANN="OFF" USE_ASAN="OFF" -USE_DYNAMIC_SIMD="OFF" +USE_DYNAMIC_SIMD="ON" INDEX_ENGINE="KNOWHERE" while getopts "p:d:t:s:f:n:i:y:a:x:ulrcghzmebZ" arg; do