From adf2be650620bc5541dc8fd99a09d5e8b789b825 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Fri, 28 Nov 2025 17:43:11 +0800 Subject: [PATCH] enhance: batch cp optimizations to 2.6 (#45869) issue: #44452 pr: #45829 pr: #45328 pr: #45307 pr: #45008 pr: #44634 --------- Signed-off-by: zhagnlu Signed-off-by: Buqian Zheng Co-authored-by: zhagnlu Co-authored-by: luzhang --- Makefile | 15 + cmd/tools/exprparser/main.go | 115 +++++ internal/core/src/common/FieldDataInterface.h | 3 - internal/core/src/common/FieldMeta.cpp | 59 +++ internal/core/src/common/FieldMeta.h | 4 + internal/core/src/common/Schema.cpp | 23 + internal/core/src/common/Schema.h | 28 ++ internal/core/src/common/Types.h | 61 +++ .../src/exec/expression/BinaryRangeExpr.cpp | 50 ++ .../src/exec/expression/BinaryRangeExpr.h | 4 + .../core/src/exec/expression/TermExpr.cpp | 6 +- internal/core/src/plan/PlanNode.h | 89 ---- internal/core/src/query/PlanProto.cpp | 7 +- .../src/segcore/ChunkedSegmentSealedImpl.cpp | 368 ++++---------- .../src/segcore/ChunkedSegmentSealedImpl.h | 449 +++++++++++++++++- internal/core/src/segcore/InsertRecord.h | 27 ++ internal/core/src/segcore/SegmentGrowing.h | 11 + .../core/src/segcore/SegmentGrowingImpl.cpp | 21 +- .../core/src/segcore/SegmentGrowingImpl.h | 14 +- internal/core/src/segcore/SegmentInterface.h | 26 +- internal/core/unittest/CMakeLists.txt | 27 ++ internal/core/unittest/bench/CMakeLists.txt | 2 +- .../core/unittest/bench/bench_search_pk.cpp | 27 -- internal/core/unittest/test_sealed.cpp | 25 - .../unittest/test_utils/storage_test_utils.h | 6 +- 25 files changed, 980 insertions(+), 487 deletions(-) create mode 100644 cmd/tools/exprparser/main.go delete mode 100644 internal/core/unittest/bench/bench_search_pk.cpp diff --git a/Makefile b/Makefile index ae13823f1c..1e7947555d 100644 --- a/Makefile +++ b/Makefile @@ -386,6 +386,21 @@ run-test-cpp: @echo $(PWD)/scripts/run_cpp_unittest.sh arg=${filter} @(env bash $(PWD)/scripts/run_cpp_unittest.sh arg=${filter}) +# tool for benchmark +exprparser-tool: + @echo "Building exprparser helper ..." + @source $(PWD)/scripts/setenv.sh && \ + mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && \ + GO111MODULE=on $(GO) build -pgo=$(PGO_PATH)/default.pgo -ldflags="-r $${RPATH}" -o $(INSTALL_PATH)/exprparser $(PWD)/cmd/tools/exprparser/main.go 1>/dev/null + +# Build unittest with external scalar-benchmark enabled +scalar-bench: generated-proto exprparser-tool + @echo "Building Milvus cpp unittest with scalar-benchmark ... " + @(export CMAKE_EXTRA_ARGS="-DENABLE_SCALAR_BENCH=ON"; env bash $(PWD)/scripts/core_build.sh -t ${mode} -a ${use_asan} -u -n ${use_disk_index} -y ${use_dynamic_simd} ${AZURE_OPTION} -x ${index_engine} -o ${use_opendal} -f $(tantivy_features)) + +scalar-bench-ui: + @echo "Starting scalar-benchmark ui ... " + @(cd cmake_build/unittest/scalar-benchmark-src/ui && ./serve_ui_dev.sh) # Run code coverage. 
codecov: codecov-go codecov-cpp diff --git a/cmd/tools/exprparser/main.go b/cmd/tools/exprparser/main.go new file mode 100644 index 0000000000..b79c0af053 --- /dev/null +++ b/cmd/tools/exprparser/main.go @@ -0,0 +1,115 @@ +package main + +import ( + "bufio" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "google.golang.org/protobuf/proto" + + schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + _ "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +type parseRequest struct { + ID string `json:"id"` + Op string `json:"op"` + SchemaB64 string `json:"schema_b64"` + Expr string `json:"expr"` + Options struct { + IsCount bool `json:"is_count"` + Limit int64 `json:"limit"` + } `json:"options"` +} + +type parseResponse struct { + ID string `json:"id"` + OK bool `json:"ok"` + PlanB64 string `json:"plan_b64,omitempty"` + Error string `json:"error,omitempty"` +} + +func handle(line string) parseResponse { + line = strings.TrimSpace(line) + if line == "" { + return parseResponse{ID: "", OK: false, Error: "empty line"} + } + + var req parseRequest + if err := json.Unmarshal([]byte(line), &req); err != nil { + return parseResponse{ID: req.ID, OK: false, Error: fmt.Sprintf("invalid json: %v", err)} + } + if req.Op != "parse_expr" { + return parseResponse{ID: req.ID, OK: false, Error: "unsupported op"} + } + if req.SchemaB64 == "" { + return parseResponse{ID: req.ID, OK: false, Error: "missing schema_b64"} + } + if req.Expr == "" { + return parseResponse{ID: req.ID, OK: false, Error: "missing expr"} + } + + schemaBytes, err := base64.StdEncoding.DecodeString(req.SchemaB64) + if err != nil { + return parseResponse{ID: req.ID, OK: false, Error: fmt.Sprintf("decode schema_b64 failed: %v", err)} + } + var schema schemapb.CollectionSchema + if err := proto.Unmarshal(schemaBytes, &schema); err != nil { + return parseResponse{ID: req.ID, OK: false, Error: fmt.Sprintf("unmarshal schema failed: %v", err)} + } + + helper, err := typeutil.CreateSchemaHelper(&schema) + if err != nil { + return parseResponse{ID: req.ID, OK: false, Error: fmt.Sprintf("schema helper error: %v", err)} + } + + planNode, err := planparserv2.CreateRetrievePlan(helper, req.Expr, nil) + if err != nil { + return parseResponse{ID: req.ID, OK: false, Error: fmt.Sprintf("parse error: %v", err)} + } + + // Apply options if provided + if q := planNode.GetQuery(); q != nil { + q.IsCount = req.Options.IsCount + if req.Options.Limit > 0 { + q.Limit = req.Options.Limit + } + } + + planBytes, err := proto.Marshal(planNode) + if err != nil { + return parseResponse{ID: req.ID, OK: false, Error: fmt.Sprintf("marshal plan failed: %v", err)} + } + return parseResponse{ID: req.ID, OK: true, PlanB64: base64.StdEncoding.EncodeToString(planBytes)} +} + +func writeResp(w *bufio.Writer, resp parseResponse) { + b, _ := json.Marshal(resp) + _, _ = w.Write(b) + _ = w.WriteByte('\n') + _ = w.Flush() +} + +func main() { + in := bufio.NewScanner(os.Stdin) + buf := make([]byte, 0, 1024*1024) + in.Buffer(buf, 16*1024*1024) + w := bufio.NewWriter(os.Stdout) + + for { + if !in.Scan() { + if err := in.Err(); err != nil && err != io.EOF { + writeResp(w, parseResponse{ID: "", OK: false, Error: fmt.Sprintf("scan error: %v", err)}) + } + break + } + resp := handle(in.Text()) + writeResp(w, resp) + } +} diff --git a/internal/core/src/common/FieldDataInterface.h b/internal/core/src/common/FieldDataInterface.h index 
cd6328333a..54886e7ab1 100644 --- a/internal/core/src/common/FieldDataInterface.h +++ b/internal/core/src/common/FieldDataInterface.h @@ -40,9 +40,6 @@ #include "common/TypeTraits.h" namespace milvus { - -using DataType = milvus::DataType; - class FieldDataBase { public: explicit FieldDataBase(DataType data_type, bool nullable) diff --git a/internal/core/src/common/FieldMeta.cpp b/internal/core/src/common/FieldMeta.cpp index 5c82aaa243..c1e96799a4 100644 --- a/internal/core/src/common/FieldMeta.cpp +++ b/internal/core/src/common/FieldMeta.cpp @@ -68,6 +68,65 @@ FieldMeta::get_analyzer_params() const { return ParseTokenizerParams(params); } +milvus::proto::schema::FieldSchema +FieldMeta::ToProto() const { + milvus::proto::schema::FieldSchema proto; + proto.set_fieldid(id_.get()); + proto.set_name(name_.get()); + proto.set_data_type(ToProtoDataType(type_)); + proto.set_nullable(nullable_); + + if (has_default_value()) { + *proto.mutable_default_value() = *default_value_; + } + + if (element_type_ != DataType::NONE) { + proto.set_element_type(ToProtoDataType(element_type_)); + } + + auto add_type_param = [&proto](const std::string& key, + const std::string& value) { + auto* param = proto.add_type_params(); + param->set_key(key); + param->set_value(value); + }; + auto add_index_param = [&proto](const std::string& key, + const std::string& value) { + auto* param = proto.add_index_params(); + param->set_key(key); + param->set_value(value); + }; + + if (type_ == DataType::VECTOR_ARRAY) { + add_type_param("dim", std::to_string(get_dim())); + if (auto metric = get_metric_type(); metric.has_value()) { + add_index_param("metric_type", metric.value()); + } + } else if (IsVectorDataType(type_)) { + if (!IsSparseFloatVectorDataType(type_)) { + add_type_param("dim", std::to_string(get_dim())); + } + if (auto metric = get_metric_type(); metric.has_value()) { + add_index_param("metric_type", metric.value()); + } + } else if (IsStringDataType(type_)) { + std::map params; + if (string_info_.has_value()) { + params = string_info_->params; + } + params[MAX_LENGTH] = std::to_string(get_max_len()); + params["enable_match"] = enable_match() ? "true" : "false"; + params["enable_analyzer"] = enable_analyzer() ? 
"true" : "false"; + for (const auto& [key, value] : params) { + add_type_param(key, value); + } + } else if (IsArrayDataType(type_)) { + // element_type already populated above + } + + return proto; +} + FieldMeta FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) { auto field_id = FieldId(schema_proto.fieldid()); diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h index 873356677e..62fa9fe7d5 100644 --- a/internal/core/src/common/FieldMeta.h +++ b/internal/core/src/common/FieldMeta.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -251,6 +252,9 @@ class FieldMeta { return default_value_; } + milvus::proto::schema::FieldSchema + ToProto() const; + size_t get_sizeof() const { AssertInfo(!IsSparseFloatVectorDataType(type_), diff --git a/internal/core/src/common/Schema.cpp b/internal/core/src/common/Schema.cpp index c221efb629..3b830f1370 100644 --- a/internal/core/src/common/Schema.cpp +++ b/internal/core/src/common/Schema.cpp @@ -111,6 +111,29 @@ Schema::ConvertToArrowSchema() const { return arrow::schema(arrow_fields); } +proto::schema::CollectionSchema +Schema::ToProto() const { + proto::schema::CollectionSchema schema_proto; + schema_proto.set_enable_dynamic_field(dynamic_field_id_opt_.has_value()); + + for (const auto& field_id : field_ids_) { + const auto& meta = fields_.at(field_id); + auto* field_proto = schema_proto.add_fields(); + *field_proto = meta.ToProto(); + + if (primary_field_id_opt_.has_value() && + field_id == primary_field_id_opt_.value()) { + field_proto->set_is_primary_key(true); + } + if (dynamic_field_id_opt_.has_value() && + field_id == dynamic_field_id_opt_.value()) { + field_proto->set_is_dynamic(true); + } + } + + return schema_proto; +} + std::unique_ptr> Schema::AbsentFields(Schema& old_schema) const { std::vector result; diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index dd999c1f40..dd7e2685e2 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -127,6 +127,31 @@ class Schema { return field_id; } + // string type + FieldId + AddDebugVarcharField(const FieldName& name, + DataType data_type, + int64_t max_length, + bool nullable, + bool enable_match, + bool enable_analyzer, + std::map& params, + std::optional default_value) { + auto field_id = FieldId(debug_id); + debug_id++; + auto field_meta = FieldMeta(name, + field_id, + data_type, + max_length, + nullable, + enable_match, + enable_analyzer, + params, + std::move(default_value)); + this->AddField(std::move(field_meta)); + return field_id; + } + // scalar type void AddField(const FieldName& name, @@ -294,6 +319,9 @@ class Schema { const ArrowSchemaPtr ConvertToArrowSchema() const; + proto::schema::CollectionSchema + ToProto() const; + void UpdateLoadFields(const std::vector& field_ids) { load_fields_.clear(); diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 8de4340953..cc67ba7037 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -156,6 +156,67 @@ GetDataTypeSize(DataType data_type, int dim = 1) { } } +// Convert internal DataType to proto schema DataType +inline proto::schema::DataType +ToProtoDataType(DataType data_type) { + switch (data_type) { + case DataType::NONE: + return proto::schema::DataType::None; + case DataType::BOOL: + return proto::schema::DataType::Bool; + case DataType::INT8: + return proto::schema::DataType::Int8; + case DataType::INT16: + return 
proto::schema::DataType::Int16; + case DataType::INT32: + return proto::schema::DataType::Int32; + case DataType::INT64: + return proto::schema::DataType::Int64; + + case DataType::FLOAT: + return proto::schema::DataType::Float; + case DataType::DOUBLE: + return proto::schema::DataType::Double; + + case DataType::STRING: + return proto::schema::DataType::String; + case DataType::VARCHAR: + return proto::schema::DataType::VarChar; + case DataType::ARRAY: + return proto::schema::DataType::Array; + case DataType::JSON: + return proto::schema::DataType::JSON; + case DataType::TEXT: + return proto::schema::DataType::Text; + case DataType::TIMESTAMPTZ: + return proto::schema::DataType::Timestamptz; + + case DataType::VECTOR_BINARY: + return proto::schema::DataType::BinaryVector; + case DataType::VECTOR_FLOAT: + return proto::schema::DataType::FloatVector; + case DataType::VECTOR_FLOAT16: + return proto::schema::DataType::Float16Vector; + case DataType::VECTOR_BFLOAT16: + return proto::schema::DataType::BFloat16Vector; + case DataType::VECTOR_SPARSE_U32_F32: + return proto::schema::DataType::SparseFloatVector; + case DataType::VECTOR_INT8: + return proto::schema::DataType::Int8Vector; + case DataType::VECTOR_ARRAY: + return proto::schema::DataType::ArrayOfVector; + + // Internal-only or unsupported mappings + case DataType::ROW: + default: + ThrowInfo( + DataTypeInvalid, + fmt::format( + "failed to convert to proto data type, invalid type {}", + data_type)); + } +} + inline std::shared_ptr GetArrowDataType(DataType data_type, int dim = 1) { switch (data_type) { diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index f1bb32265c..a100326c03 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -163,6 +163,15 @@ PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { template VectorPtr PhyBinaryRangeFilterExpr::ExecRangeVisitorImpl(EvalCtx& context) { + if (!has_offset_input_ && is_pk_field_ && + segment_->type() == SegmentType::Sealed) { + if (pk_type_ == DataType::VARCHAR) { + return ExecRangeVisitorImplForPk(context); + } else { + return ExecRangeVisitorImplForPk(context); + } + } + if (SegmentExpr::CanUseIndex() && !has_offset_input_) { return ExecRangeVisitorImplForIndex(); } else { @@ -860,5 +869,46 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray(EvalCtx& context) { return res_vec; } +template +VectorPtr +PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForPk(EvalCtx& context) { + typedef std:: + conditional_t, std::string, T> + PkInnerType; + + if (!arg_inited_) { + lower_arg_.SetValue(expr_->lower_val_); + upper_arg_.SetValue(expr_->upper_val_); + arg_inited_ = true; + } + + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + if (cached_index_chunk_id_ != 0) { + cached_index_chunk_id_ = 0; + cached_index_chunk_res_ = std::make_shared(active_count_); + auto cache_view = cached_index_chunk_res_->view(); + + PkType lower_pk = lower_arg_.GetValue(); + PkType upper_pk = upper_arg_.GetValue(); + segment_->pk_binary_range(op_ctx_, + lower_pk, + expr_->lower_inclusive_, + upper_pk, + expr_->upper_inclusive_, + cache_view); + } + + TargetBitmap result; + result.append( + *cached_index_chunk_res_, current_data_global_pos_, real_batch_size); + MoveCursor(); + return std::make_shared(std::move(result), + TargetBitmap(real_batch_size, true)); +} + } // namespace exec } // namespace milvus 
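The new sealed-segment fast path above hands the whole range predicate to pk_binary_range and materializes the answer once into cached_index_chunk_res_, instead of re-evaluating the predicate row by row for every batch. On a PK column that is sorted (is_sorted_by_pk_), such a range reduces to two binary searches plus one contiguous bitset fill. The standalone sketch below illustrates only that idea on a plain sorted vector; the helper name binary_range_bitset, the use of std::vector<bool>, and the single-chunk layout are illustrative assumptions, not the Milvus implementation, which works chunk by chunk via pk_lower_bound / find_last_pk_position later in this patch.

// Illustrative sketch only: answer an inclusive/exclusive binary range
// predicate over a sorted primary-key column by locating the two range
// endpoints with binary search, then setting one contiguous run of bits.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper, not a Milvus API.
std::vector<bool>
binary_range_bitset(const std::vector<int64_t>& sorted_pks,
                    int64_t lower, bool lower_inclusive,
                    int64_t upper, bool upper_inclusive) {
    std::vector<bool> bits(sorted_pks.size(), false);
    // First offset that can satisfy the lower bound.
    auto begin = lower_inclusive
                     ? std::lower_bound(sorted_pks.begin(), sorted_pks.end(), lower)
                     : std::upper_bound(sorted_pks.begin(), sorted_pks.end(), lower);
    // First offset past the upper bound.
    auto end = upper_inclusive
                   ? std::upper_bound(sorted_pks.begin(), sorted_pks.end(), upper)
                   : std::lower_bound(sorted_pks.begin(), sorted_pks.end(), upper);
    // On a sorted column the qualifying rows form one contiguous run.
    for (auto it = begin; it < end; ++it) {
        bits[it - sorted_pks.begin()] = true;
    }
    return bits;
}

int main() {
    std::vector<int64_t> pks = {1, 3, 3, 5, 7, 9};
    auto bits = binary_range_bitset(pks, 3, /*lower_inclusive=*/true,
                                    7, /*upper_inclusive=*/false);  // 3 <= pk < 7
    for (bool b : bits) {
        std::cout << b;
    }
    std::cout << '\n';  // prints 011100
    return 0;
}

Because the qualifying rows are contiguous on a sorted column, the real implementation can issue a single bitset.set(start_idx, count, true) per range rather than testing every row, which is where the speedup over the generic per-row scan comes from.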
diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index 0e5c6971f8..21daf95fb6 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -320,6 +320,10 @@ class PhyBinaryRangeFilterExpr : public SegmentExpr { VectorPtr ExecRangeVisitorImplForArray(EvalCtx& context); + template + VectorPtr + ExecRangeVisitorImplForPk(EvalCtx& context); + private: std::shared_ptr expr_; int64_t overflow_check_pos_{0}; diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp index 32cf37360b..cdd613a8b6 100644 --- a/internal/core/src/exec/expression/TermExpr.cpp +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -191,12 +191,8 @@ PhyTermFilterExpr::InitPkCacheOffset() { } } - auto seg_offsets = segment_->search_ids(*id_array, query_timestamp_); cached_bits_.resize(active_count_, false); - for (const auto& offset : seg_offsets) { - auto _offset = (int64_t)offset.get(); - cached_bits_[_offset] = true; - } + segment_->search_ids(cached_bits_, *id_array); cached_bits_inited_ = true; } diff --git a/internal/core/src/plan/PlanNode.h b/internal/core/src/plan/PlanNode.h index 50a00d5909..0328ca522d 100644 --- a/internal/core/src/plan/PlanNode.h +++ b/internal/core/src/plan/PlanNode.h @@ -86,95 +86,6 @@ class PlanNode { using PlanNodePtr = std::shared_ptr; -class SegmentNode : public PlanNode { - public: - SegmentNode( - const PlanNodeId& id, - const std::shared_ptr& - segment) - : PlanNode(id), segment_(segment) { - } - - DataType - output_type() const override { - return DataType::ROW; - } - - std::vector> - sources() const override { - return {}; - } - - std::string_view - name() const override { - return "SegmentNode"; - } - - std::string - ToString() const override { - return "SegmentNode"; - } - - private: - std::shared_ptr segment_; -}; - -class ValuesNode : public PlanNode { - public: - ValuesNode(const PlanNodeId& id, - const std::vector& values, - bool parallelizeable = false) - : PlanNode(id), - values_{std::move(values)}, - output_type_(values[0]->type()) { - AssertInfo(!values.empty(), "ValueNode must has value"); - } - - ValuesNode(const PlanNodeId& id, - std::vector&& values, - bool parallelizeable = false) - : PlanNode(id), - values_{std::move(values)}, - output_type_(values[0]->type()) { - AssertInfo(!values.empty(), "ValueNode must has value"); - } - - DataType - output_type() const override { - return output_type_; - } - - const std::vector& - values() const { - return values_; - } - - std::vector - sources() const override { - return {}; - } - - bool - parallelizable() { - return parallelizable_; - } - - std::string_view - name() const override { - return "Values"; - } - - std::string - ToString() const override { - return "Values"; - } - - private: - DataType output_type_; - const std::vector values_; - bool parallelizable_; -}; - class FilterNode : public PlanNode { public: FilterNode(const PlanNodeId& id, diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index e6065a5d4b..7b7fd3fce5 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -44,7 +44,6 @@ ProtoParser::PlanOptionsFromProto( std::unique_ptr ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { - // TODO: add more buffs Assert(plan_node_proto.has_vector_anns()); auto& anns_proto = plan_node_proto.vector_anns(); @@ -289,9 +288,7 @@ 
ProtoParser::RetrievePlanNodeFromProto( sources)); sources = std::vector{plannode}; } else { - auto expr_parser = - parse_expr_to_filter_node(query.predicates()); - plannode = std::move(expr_parser); + plannode = parse_expr_to_filter_node(query.predicates()); sources = std::vector{plannode}; } } @@ -325,8 +322,8 @@ ProtoParser::CreatePlan(const proto::plan::PlanNode& plan_node_proto) { auto plan = std::make_unique(schema); auto plan_node = PlanNodeFromProto(plan_node_proto); - plan->tag2field_["$0"] = plan_node->search_info_.field_id_; plan->plan_node_ = std::move(plan_node); + plan->tag2field_["$0"] = plan->plan_node_->search_info_.field_id_; ExtractedPlanInfo extra_info(schema->size()); extra_info.add_involved_field(plan->plan_node_->search_info_.field_id_); plan->extra_info_opt_ = std::move(extra_info); diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 0ab3f7ed54..f2bd11de1c 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -1002,16 +1002,42 @@ ChunkedSegmentSealedImpl::check_search(const query::Plan* plan) const { } } -std::vector -ChunkedSegmentSealedImpl::search_pk(milvus::OpContext* op_ctx, - const PkType& pk, - Timestamp timestamp) const { - if (!is_sorted_by_pk_) { - return insert_record_.search_pk(pk, timestamp); +void +ChunkedSegmentSealedImpl::search_pks(BitsetType& bitset, + const std::vector& pks) const { + if (pks.empty()) { + return; + } + BitsetTypeView bitset_view(bitset); + if (!is_sorted_by_pk_) { + for (auto& pk : pks) { + insert_record_.search_pk_range( + pk, proto::plan::OpType::Equal, bitset_view); + } + return; + } + + auto pk_field_id = schema_->get_primary_field_id().value_or(FieldId(-1)); + AssertInfo(pk_field_id.get() != -1, "Primary key is -1"); + auto pk_column = get_column(pk_field_id); + AssertInfo(pk_column != nullptr, "primary key column not loaded"); + + switch (schema_->get_fields().at(pk_field_id).get_data_type()) { + case DataType::INT64: + search_pks_with_two_pointers_impl( + bitset_view, pks, pk_column); + break; + case DataType::VARCHAR: + search_pks_with_two_pointers_impl( + bitset_view, pks, pk_column); + break; + default: + ThrowInfo( + DataTypeInvalid, + fmt::format( + "unsupported type {}", + schema_->get_fields().at(pk_field_id).get_data_type())); } - return search_sorted_pk(op_ctx, pk, [this, timestamp](int64_t offset) { - return insert_record_.timestamps_[offset] <= timestamp; - }); } void @@ -1117,74 +1143,6 @@ ChunkedSegmentSealedImpl::search_batch_pks( } } -template -std::vector -ChunkedSegmentSealedImpl::search_sorted_pk(milvus::OpContext* op_ctx, - const PkType& pk, - Condition condition) const { - auto pk_field_id = schema_->get_primary_field_id().value_or(FieldId(-1)); - AssertInfo(pk_field_id.get() != -1, "Primary key is -1"); - auto pk_column = get_column(pk_field_id); - AssertInfo(pk_column != nullptr, "primary key column not loaded"); - std::vector pk_offsets; - switch (schema_->get_fields().at(pk_field_id).get_data_type()) { - case DataType::INT64: { - auto target = std::get(pk); - // get int64 pks - auto num_chunk = pk_column->num_chunks(); - for (int i = 0; i < num_chunk; ++i) { - auto pw = pk_column->DataOfChunk(op_ctx, i); - auto src = reinterpret_cast(pw.get()); - auto chunk_row_num = pk_column->chunk_row_nums(i); - auto it = std::lower_bound( - src, - src + chunk_row_num, - target, - [](const int64_t& elem, const int64_t& value) { - return elem < value; - }); - 
auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i); - for (; it != src + chunk_row_num && *it == target; ++it) { - auto offset = it - src + num_rows_until_chunk; - if (condition(offset)) { - pk_offsets.emplace_back(offset); - } - } - } - break; - } - case DataType::VARCHAR: { - auto target = std::get(pk); - // get varchar pks - auto num_chunk = pk_column->num_chunks(); - for (int i = 0; i < num_chunk; ++i) { - // TODO @xiaocai2333, @sunby: chunk need to record the min/max. - auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i); - auto pw = pk_column->GetChunk(op_ctx, i); - auto string_chunk = static_cast(pw.get()); - auto offset = string_chunk->binary_search_string(target); - for (; offset != -1 && offset < string_chunk->RowNums() && - string_chunk->operator[](offset) == target; - ++offset) { - auto segment_offset = offset + num_rows_until_chunk; - if (condition(segment_offset)) { - pk_offsets.emplace_back(segment_offset); - } - } - } - break; - } - default: { - ThrowInfo( - DataTypeInvalid, - fmt::format( - "unsupported type {}", - schema_->get_fields().at(pk_field_id).get_data_type())); - } - } - return pk_offsets; -} - void ChunkedSegmentSealedImpl::pk_range(milvus::OpContext* op_ctx, proto::plan::OpType op, @@ -1195,199 +1153,82 @@ ChunkedSegmentSealedImpl::pk_range(milvus::OpContext* op_ctx, return; } - search_sorted_pk_range( - op_ctx, op, pk, bitset, [](int64_t offset) { return true; }); + search_sorted_pk_range(op_ctx, op, pk, bitset); } -template void ChunkedSegmentSealedImpl::search_sorted_pk_range(milvus::OpContext* op_ctx, proto::plan::OpType op, const PkType& pk, - BitsetTypeView& bitset, - Condition condition) const { + BitsetTypeView& bitset) const { auto pk_field_id = schema_->get_primary_field_id().value_or(FieldId(-1)); AssertInfo(pk_field_id.get() != -1, "Primary key is -1"); auto pk_column = get_column(pk_field_id); AssertInfo(pk_column != nullptr, "primary key column not loaded"); switch (schema_->get_fields().at(pk_field_id).get_data_type()) { - case DataType::INT64: { - // get int64 pks - auto target = std::get(pk); - - auto num_chunk = pk_column->num_chunks(); - for (int i = 0; i < num_chunk; ++i) { - auto pw = pk_column->DataOfChunk(op_ctx, i); - auto src = reinterpret_cast(pw.get()); - auto chunk_row_num = pk_column->chunk_row_nums(i); - if (op == proto::plan::OpType::GreaterEqual) { - auto it = std::lower_bound( - src, - src + chunk_row_num, - target, - [](const int64_t& elem, const int64_t& value) { - return elem < value; - }); - auto num_rows_until_chunk = - pk_column->GetNumRowsUntilChunk(i); - for (; it != src + chunk_row_num; ++it) { - auto offset = it - src + num_rows_until_chunk; - if (condition(offset)) { - bitset[offset] = true; - } - } - } else if (op == proto::plan::OpType::GreaterThan) { - auto it = std::upper_bound( - src, - src + chunk_row_num, - target, - [](const int64_t& elem, const int64_t& value) { - return elem < value; - }); - auto num_rows_until_chunk = - pk_column->GetNumRowsUntilChunk(i); - for (; it != src + chunk_row_num; ++it) { - auto offset = it - src + num_rows_until_chunk; - if (condition(offset)) { - bitset[offset] = true; - } - } - } else if (op == proto::plan::OpType::LessEqual) { - auto it = std::upper_bound( - src, - src + chunk_row_num, - target, - [](const int64_t& elem, const int64_t& value) { - return elem < value; - }); - if (it == src) { - break; - } - auto num_rows_until_chunk = - pk_column->GetNumRowsUntilChunk(i); - for (auto ptr = src; ptr < it; ++ptr) { - auto offset = ptr - src + 
num_rows_until_chunk; - if (condition(offset)) { - bitset[offset] = true; - } - } - } else if (op == proto::plan::OpType::LessThan) { - auto it = - std::lower_bound(src, src + chunk_row_num, target); - if (it == src) { - break; - } - auto num_rows_until_chunk = - pk_column->GetNumRowsUntilChunk(i); - for (auto ptr = src; ptr < it; ++ptr) { - auto offset = ptr - src + num_rows_until_chunk; - if (condition(offset)) { - bitset[offset] = true; - } - } - } else if (op == proto::plan::OpType::Equal) { - auto it = std::lower_bound( - src, - src + chunk_row_num, - target, - [](const int64_t& elem, const int64_t& value) { - return elem < value; - }); - auto num_rows_until_chunk = - pk_column->GetNumRowsUntilChunk(i); - for (; it != src + chunk_row_num && *it == target; ++it) { - auto offset = it - src + num_rows_until_chunk; - if (condition(offset)) { - bitset[offset] = true; - } - } - if (it != src + chunk_row_num && *it > target) { - break; - } - } else { - ThrowInfo(ErrorCode::Unsupported, - fmt::format("unsupported op type {}", op)); - } - } + case DataType::INT64: + search_sorted_pk_range_impl( + op, std::get(pk), pk_column, bitset); break; - } - case DataType::VARCHAR: { - // get varchar pks - auto target = std::get(pk); - - auto num_chunk = pk_column->num_chunks(); - for (int i = 0; i < num_chunk; ++i) { - auto num_rows_until_chunk = pk_column->GetNumRowsUntilChunk(i); - auto pw = pk_column->GetChunk(op_ctx, i); - auto string_chunk = static_cast(pw.get()); - - if (op == proto::plan::OpType::Equal) { - auto offset = string_chunk->lower_bound_string(target); - for (; offset < string_chunk->RowNums() && - string_chunk->operator[](offset) == target; - ++offset) { - auto segment_offset = offset + num_rows_until_chunk; - if (condition(segment_offset)) { - bitset[segment_offset] = true; - } - } - if (offset < string_chunk->RowNums() && - string_chunk->operator[](offset) > target) { - break; - } - } else if (op == proto::plan::OpType::GreaterEqual) { - auto offset = string_chunk->lower_bound_string(target); - for (; offset < string_chunk->RowNums(); ++offset) { - auto segment_offset = offset + num_rows_until_chunk; - if (condition(segment_offset)) { - bitset[segment_offset] = true; - } - } - } else if (op == proto::plan::OpType::GreaterThan) { - auto offset = string_chunk->upper_bound_string(target); - for (; offset < string_chunk->RowNums(); ++offset) { - auto segment_offset = offset + num_rows_until_chunk; - if (condition(segment_offset)) { - bitset[segment_offset] = true; - } - } - } else if (op == proto::plan::OpType::LessEqual) { - auto pos = string_chunk->upper_bound_string(target); - if (pos == 0) { - break; - } - for (auto offset = 0; offset < pos; ++offset) { - auto segment_offset = offset + num_rows_until_chunk; - if (condition(segment_offset)) { - bitset[segment_offset] = true; - } - } - } else if (op == proto::plan::OpType::LessThan) { - auto pos = string_chunk->lower_bound_string(target); - if (pos == 0) { - break; - } - for (auto offset = 0; offset < pos; ++offset) { - auto segment_offset = offset + num_rows_until_chunk; - if (condition(segment_offset)) { - bitset[segment_offset] = true; - } - } - } else { - ThrowInfo(ErrorCode::Unsupported, - fmt::format("unsupported op type {}", op)); - } - } + case DataType::VARCHAR: + search_sorted_pk_range_impl( + op, std::get(pk), pk_column, bitset); break; - } - default: { + default: + ThrowInfo( + DataTypeInvalid, + fmt::format( + "unsupported type {}", + schema_->get_fields().at(pk_field_id).get_data_type())); + } +} + +void 
+ChunkedSegmentSealedImpl::pk_binary_range(milvus::OpContext* op_ctx, + const PkType& lower_pk, + bool lower_inclusive, + const PkType& upper_pk, + bool upper_inclusive, + BitsetTypeView& bitset) const { + if (!is_sorted_by_pk_) { + // For unsorted segments, use the InsertRecord's binary range search + insert_record_.search_pk_binary_range( + lower_pk, lower_inclusive, upper_pk, upper_inclusive, bitset); + return; + } + + // For sorted segments, use binary search + auto pk_field_id = schema_->get_primary_field_id().value_or(FieldId(-1)); + AssertInfo(pk_field_id.get() != -1, "Primary key is -1"); + auto pk_column = get_column(pk_field_id); + AssertInfo(pk_column != nullptr, "primary key column not loaded"); + + switch (schema_->get_fields().at(pk_field_id).get_data_type()) { + case DataType::INT64: + search_sorted_pk_binary_range_impl( + std::get(lower_pk), + lower_inclusive, + std::get(upper_pk), + upper_inclusive, + pk_column, + bitset); + break; + case DataType::VARCHAR: + search_sorted_pk_binary_range_impl( + std::get(lower_pk), + lower_inclusive, + std::get(upper_pk), + upper_inclusive, + pk_column, + bitset); + break; + default: ThrowInfo( DataTypeInvalid, fmt::format( "unsupported type {}", schema_->get_fields().at(pk_field_id).get_data_type())); - } } } @@ -2197,9 +2038,9 @@ ChunkedSegmentSealedImpl::GetFieldDataType(milvus::FieldId field_id) const { return field_meta.get_data_type(); } -std::vector -ChunkedSegmentSealedImpl::search_ids(const IdArray& id_array, - Timestamp timestamp) const { +void +ChunkedSegmentSealedImpl::search_ids(BitsetType& bitset, + const IdArray& id_array) const { auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1)); AssertInfo(field_id.get() != -1, "Primary key is -1"); auto& field_meta = schema_->operator[](field_id); @@ -2208,16 +2049,7 @@ ChunkedSegmentSealedImpl::search_ids(const IdArray& id_array, std::vector pks(ids_size); ParsePksFromIDs(pks, data_type, id_array); - std::vector res_offsets; - res_offsets.reserve(pks.size()); - this->search_batch_pks( - pks, - [=](const size_t idx) { return timestamp; }, - true, - [&](const SegOffset offset, const Timestamp ts) { - res_offsets.push_back(offset); - }); - return std::move(res_offsets); + this->search_pks(bitset, pks); } SegcoreError @@ -2266,14 +2098,6 @@ ChunkedSegmentSealedImpl::Delete(int64_t size, return SegcoreError::success(); } -std::string -ChunkedSegmentSealedImpl::debug() const { - std::string log_str; - log_str += "Sealed\n"; - log_str += "\n"; - return log_str; -} - void ChunkedSegmentSealedImpl::LoadSegmentMeta( const proto::segcore::LoadSegmentMeta& segment_meta) { diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h index 22ef254b45..0ca72718c8 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h @@ -204,30 +204,25 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { const Schema& get_schema() const override; - std::vector - search_pk(milvus::OpContext* op_ctx, - const PkType& pk, - Timestamp timestamp) const override; - - template - std::vector - search_sorted_pk(milvus::OpContext* op_ctx, - const PkType& pk, - Condition condition) const; - void pk_range(milvus::OpContext* op_ctx, proto::plan::OpType op, const PkType& pk, BitsetTypeView& bitset) const override; - template void search_sorted_pk_range(milvus::OpContext* op_ctx, proto::plan::OpType op, const PkType& pk, - BitsetTypeView& bitset, - Condition condition) const; + 
BitsetTypeView& bitset) const; + + void + pk_binary_range(milvus::OpContext* op_ctx, + const PkType& lower_pk, + bool lower_inclusive, + const PkType& upper_pk, + bool upper_inclusive, + BitsetTypeView& bitset) const override; std::unique_ptr get_vector(milvus::OpContext* op_ctx, @@ -246,6 +241,9 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { return true; } + void + search_pks(BitsetType& bitset, const std::vector& pks) const; + void search_batch_pks( const std::vector& pks, @@ -275,9 +273,6 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { int64_t num_rows_until_chunk(FieldId field_id, int64_t chunk_id) const override; - std::string - debug() const override; - SegcoreError Delete(int64_t size, const IdArray* pks, @@ -390,6 +385,422 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { void load_system_field_internal(FieldId field_id, FieldDataInfo& data); + template + void + search_sorted_pk_range_impl( + proto::plan::OpType op, + const PK& target, + const std::shared_ptr& pk_column, + BitsetTypeView& bitset) const { + const auto num_chunk = pk_column->num_chunks(); + if (num_chunk == 0) { + return; + } + auto all_chunk_pins = pk_column->GetAllChunks(nullptr); + + if (op == proto::plan::OpType::Equal) { + // find first occurrence + auto [chunk_id, in_chunk_offset, exact_match] = + this->pk_lower_bound( + target, pk_column.get(), all_chunk_pins, 0); + + if (exact_match) { + // find last occurrence + auto [last_chunk_id, last_in_chunk_offset] = + this->find_last_pk_position(target, + pk_column.get(), + all_chunk_pins, + chunk_id, + in_chunk_offset); + + auto start_idx = + pk_column->GetNumRowsUntilChunk(chunk_id) + in_chunk_offset; + auto end_idx = pk_column->GetNumRowsUntilChunk(last_chunk_id) + + last_in_chunk_offset; + + bitset.set(start_idx, end_idx - start_idx + 1, true); + } + } else if (op == proto::plan::OpType::GreaterEqual || + op == proto::plan::OpType::GreaterThan) { + auto [chunk_id, in_chunk_offset, exact_match] = + this->pk_lower_bound( + target, pk_column.get(), all_chunk_pins, 0); + + if (chunk_id != -1) { + int64_t start_idx = + pk_column->GetNumRowsUntilChunk(chunk_id) + in_chunk_offset; + if (exact_match && op == proto::plan::OpType::GreaterThan) { + auto [last_chunk_id, last_in_chunk_offset] = + this->find_last_pk_position(target, + pk_column.get(), + all_chunk_pins, + chunk_id, + in_chunk_offset); + start_idx = pk_column->GetNumRowsUntilChunk(last_chunk_id) + + last_in_chunk_offset + 1; + } + if (start_idx < bitset.size()) { + bitset.set(start_idx, bitset.size() - start_idx, true); + } + } + } else if (op == proto::plan::OpType::LessEqual || + op == proto::plan::OpType::LessThan) { + auto [chunk_id, in_chunk_offset, exact_match] = + this->pk_lower_bound( + target, pk_column.get(), all_chunk_pins, 0); + + int64_t end_idx; + if (chunk_id == -1) { + end_idx = bitset.size(); + } else if (op == proto::plan::OpType::LessEqual && exact_match) { + auto [last_chunk_id, last_in_chunk_offset] = + this->find_last_pk_position(target, + pk_column.get(), + all_chunk_pins, + chunk_id, + in_chunk_offset); + end_idx = pk_column->GetNumRowsUntilChunk(last_chunk_id) + + last_in_chunk_offset + 1; + } else { + end_idx = + pk_column->GetNumRowsUntilChunk(chunk_id) + in_chunk_offset; + } + + if (end_idx > 0) { + bitset.set(0, end_idx, true); + } + } else { + ThrowInfo(ErrorCode::Unsupported, + fmt::format("unsupported op type {}", op)); + } + } + + template + void + search_sorted_pk_binary_range_impl( + const PK& lower_val, + bool lower_inclusive, + const PK& 
upper_val, + bool upper_inclusive, + const std::shared_ptr& pk_column, + BitsetTypeView& bitset) const { + const auto num_chunk = pk_column->num_chunks(); + if (num_chunk == 0) { + return; + } + auto all_chunk_pins = pk_column->GetAllChunks(nullptr); + + // Find the lower bound position (first value >= lower_val or > lower_val) + auto [lower_chunk_id, lower_in_chunk_offset, lower_exact_match] = + this->pk_lower_bound( + lower_val, pk_column.get(), all_chunk_pins, 0); + + int64_t start_idx = 0; + if (lower_chunk_id != -1) { + start_idx = pk_column->GetNumRowsUntilChunk(lower_chunk_id) + + lower_in_chunk_offset; + // If lower_inclusive is false and we found an exact match, skip all equal values + if (!lower_inclusive && lower_exact_match) { + auto [last_chunk_id, last_in_chunk_offset] = + this->find_last_pk_position(lower_val, + pk_column.get(), + all_chunk_pins, + lower_chunk_id, + lower_in_chunk_offset); + start_idx = pk_column->GetNumRowsUntilChunk(last_chunk_id) + + last_in_chunk_offset + 1; + } + } else { + // lower_val is greater than all values, no results + return; + } + + // Find the upper bound position (first value >= upper_val or > upper_val) + auto [upper_chunk_id, upper_in_chunk_offset, upper_exact_match] = + this->pk_lower_bound( + upper_val, pk_column.get(), all_chunk_pins, 0); + + int64_t end_idx = 0; + if (upper_chunk_id == -1) { + // upper_val is greater than all values, include all from start_idx to end + end_idx = bitset.size(); + } else { + // If upper_inclusive is true and we found an exact match, include all equal values + if (upper_inclusive && upper_exact_match) { + auto [last_chunk_id, last_in_chunk_offset] = + this->find_last_pk_position(upper_val, + pk_column.get(), + all_chunk_pins, + upper_chunk_id, + upper_in_chunk_offset); + end_idx = pk_column->GetNumRowsUntilChunk(last_chunk_id) + + last_in_chunk_offset + 1; + } else { + // upper_inclusive is false or no exact match + // In both cases, end at the position of first value >= upper_val + end_idx = pk_column->GetNumRowsUntilChunk(upper_chunk_id) + + upper_in_chunk_offset; + } + } + + // Set bits from start_idx to end_idx - 1 + if (start_idx < end_idx) { + bitset.set(start_idx, end_idx - start_idx, true); + } + } + + template + void + search_pks_with_two_pointers_impl( + BitsetTypeView& bitset, + const std::vector& pks, + const std::shared_ptr& pk_column) const { + // TODO: we should sort pks during plan generation + std::vector sorted_pks; + sorted_pks.reserve(pks.size()); + for (const auto& pk : pks) { + sorted_pks.push_back(std::get(pk)); + } + std::sort(sorted_pks.begin(), sorted_pks.end()); + + auto all_chunk_pins = pk_column->GetAllChunks(nullptr); + + size_t pk_idx = 0; + int last_chunk_id = 0; + + while (pk_idx < sorted_pks.size()) { + const auto& target_pk = sorted_pks[pk_idx]; + + // find the first occurrence of target_pk + auto [chunk_id, in_chunk_offset, exact_match] = + this->pk_lower_bound( + target_pk, pk_column.get(), all_chunk_pins, last_chunk_id); + + if (chunk_id == -1) { + // All remaining PKs are greater than all values in pk_column + break; + } + + if (exact_match) { + // Found exact match, find the last occurrence + auto [last_chunk_id_found, last_in_chunk_offset] = + this->find_last_pk_position(target_pk, + pk_column.get(), + all_chunk_pins, + chunk_id, + in_chunk_offset); + + // Mark all occurrences from first to last position using global indices + auto start_idx = + pk_column->GetNumRowsUntilChunk(chunk_id) + in_chunk_offset; + auto end_idx = + 
pk_column->GetNumRowsUntilChunk(last_chunk_id_found) + + last_in_chunk_offset; + + bitset.set(start_idx, end_idx - start_idx + 1, true); + last_chunk_id = last_chunk_id_found; + } + + while (pk_idx < sorted_pks.size() && + sorted_pks[pk_idx] == target_pk) { + pk_idx++; + } + } + } + + // Binary search to find lower_bound of pk in pk_column starting from from_chunk_id + // Returns: (chunk_id, in_chunk_offset, exists) + // - chunk_id: the chunk containing the first value >= pk + // - in_chunk_offset: offset of the first value >= pk in that chunk + // - exists: true if found an exact match (value == pk), false otherwise + // - If pk doesn't exist, returns the position of first value > pk with exists=false + // - If pk is greater than all values, returns {-1, -1, false} + template + std::tuple + pk_lower_bound(const PK& pk, + const ChunkedColumnInterface* pk_column, + const std::vector>& all_chunk_pins, + int from_chunk_id = 0) const { + const auto num_chunk = pk_column->num_chunks(); + + if (from_chunk_id >= num_chunk) { + return {-1, -1, false}; // Invalid starting chunk + } + + using PKViewType = std::conditional_t, + int64_t, + std::string_view>; + + auto get_val_view = [&](int chunk_id, + int in_chunk_offset) -> PKViewType { + auto pw = all_chunk_pins[chunk_id]; + if constexpr (std::is_same_v) { + auto src = + reinterpret_cast(pw.get()->RawData()); + return src[in_chunk_offset]; + } else { + auto string_chunk = static_cast(pw.get()); + return string_chunk->operator[](in_chunk_offset); + } + }; + + // Binary search at chunk level to find the first chunk that might contain pk + int left_chunk_id = from_chunk_id; + int right_chunk_id = num_chunk - 1; + int target_chunk_id = -1; + + while (left_chunk_id <= right_chunk_id) { + int mid_chunk_id = + left_chunk_id + (right_chunk_id - left_chunk_id) / 2; + auto chunk_row_num = pk_column->chunk_row_nums(mid_chunk_id); + + PKViewType min_val = get_val_view(mid_chunk_id, 0); + PKViewType max_val = get_val_view(mid_chunk_id, chunk_row_num - 1); + + if (pk >= min_val && pk <= max_val) { + // pk might be in this chunk + target_chunk_id = mid_chunk_id; + break; + } else if (pk < min_val) { + // pk is before this chunk, could be in an earlier chunk + target_chunk_id = mid_chunk_id; // This chunk has values >= pk + right_chunk_id = mid_chunk_id - 1; + } else { + // pk is after this chunk, search in later chunks + left_chunk_id = mid_chunk_id + 1; + } + } + + // If no suitable chunk found, check if we need the first position after all chunks + if (target_chunk_id == -1) { + if (left_chunk_id >= num_chunk) { + // pk is greater than all values + return {-1, -1, false}; + } + target_chunk_id = left_chunk_id; + } + + // Binary search within the target chunk to find lower_bound position + auto chunk_row_num = pk_column->chunk_row_nums(target_chunk_id); + int left_offset = 0; + int right_offset = chunk_row_num; + + while (left_offset < right_offset) { + int mid_offset = left_offset + (right_offset - left_offset) / 2; + PKViewType mid_val = get_val_view(target_chunk_id, mid_offset); + + if (mid_val < pk) { + left_offset = mid_offset + 1; + } else { + right_offset = mid_offset; + } + } + + // Check if we found a valid position + if (left_offset < chunk_row_num) { + // Found position within current chunk + PKViewType found_val = get_val_view(target_chunk_id, left_offset); + bool exact_match = (found_val == pk); + return {target_chunk_id, left_offset, exact_match}; + } else { + // Position is beyond current chunk, try next chunk + if (target_chunk_id + 1 < num_chunk) 
{ + // Next chunk exists, return its first position + // Check if the first value in next chunk equals pk + PKViewType next_val = get_val_view(target_chunk_id + 1, 0); + bool exact_match = (next_val == pk); + return {target_chunk_id + 1, 0, exact_match}; + } else { + // No more chunks, pk is greater than all values + return {-1, -1, false}; + } + } + } + + // Find the last occurrence position of pk starting from a known first occurrence + // Parameters: + // - pk: the primary key to search for + // - pk_column: the primary key column + // - first_chunk_id: chunk id of the first occurrence (from pk_lower_bound) + // - first_in_chunk_offset: offset in chunk of the first occurrence (from pk_lower_bound) + // Returns: (last_chunk_id, last_in_chunk_offset) + // - The position of the last occurrence of pk + // Note: This function assumes pk exists and linearly scans forward. + // It's efficient when pk has few duplicates. + template + std::tuple + find_last_pk_position(const PK& pk, + const ChunkedColumnInterface* pk_column, + const std::vector>& all_chunk_pins, + int first_chunk_id, + int first_in_chunk_offset) const { + const auto num_chunk = pk_column->num_chunks(); + + using PKViewType = std::conditional_t, + int64_t, + std::string_view>; + + auto get_val_view = [&](int chunk_id, + int in_chunk_offset) -> PKViewType { + auto pw = all_chunk_pins[chunk_id]; + if constexpr (std::is_same_v) { + auto src = + reinterpret_cast(pw.get()->RawData()); + return src[in_chunk_offset]; + } else { + auto string_chunk = static_cast(pw.get()); + return string_chunk->operator[](in_chunk_offset); + } + }; + + int last_chunk_id = first_chunk_id; + int last_offset = first_in_chunk_offset; + + // Linear scan forward in current chunk + auto chunk_row_num = pk_column->chunk_row_nums(first_chunk_id); + for (int offset = first_in_chunk_offset + 1; offset < chunk_row_num; + offset++) { + PKViewType curr_val = get_val_view(first_chunk_id, offset); + if (curr_val == pk) { + last_offset = offset; + } else { + // Found first value != pk, done + return {last_chunk_id, last_offset}; + } + } + + // Continue scanning in subsequent chunks + for (int chunk_id = first_chunk_id + 1; chunk_id < num_chunk; + chunk_id++) { + auto curr_chunk_row_num = pk_column->chunk_row_nums(chunk_id); + + // Check first value in this chunk + PKViewType first_val = get_val_view(chunk_id, 0); + if (first_val != pk) { + // This chunk doesn't contain pk anymore + return {last_chunk_id, last_offset}; + } + + // Update last position and scan this chunk + last_chunk_id = chunk_id; + last_offset = 0; + + for (int offset = 1; offset < curr_chunk_row_num; offset++) { + PKViewType curr_val = get_val_view(chunk_id, offset); + if (curr_val == pk) { + last_offset = offset; + } else { + // Found first value != pk + return {last_chunk_id, last_offset}; + } + } + // All values in this chunk equal pk, continue to next chunk + } + + // Scanned all chunks, return last found position + return {last_chunk_id, last_offset}; + } + template static void bulk_subscript_impl(milvus::OpContext* op_ctx, @@ -481,8 +892,8 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { return system_ready_count_ == 1; } - std::vector - search_ids(const IdArray& id_array, Timestamp timestamp) const override; + void + search_ids(BitsetType& bitset, const IdArray& id_array) const override; void LoadVecIndex(const LoadIndexInfo& info); diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index 80a34f0ba4..dbb7b6c127 100644 --- 
a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -514,6 +514,33 @@ class InsertRecordSealed { pk2offset_->find_range(pk, op, bitset, condition); } + void + search_pk_binary_range(const PkType& lower_pk, + bool lower_inclusive, + const PkType& upper_pk, + bool upper_inclusive, + BitsetTypeView& bitset) const { + auto lower_op = lower_inclusive ? proto::plan::OpType::GreaterEqual + : proto::plan::OpType::GreaterThan; + auto upper_op = upper_inclusive ? proto::plan::OpType::LessEqual + : proto::plan::OpType::LessThan; + + BitsetType upper_result(bitset.size()); + auto upper_view = upper_result.view(); + + // values >= lower_pk (or > lower_pk if not inclusive) + pk2offset_->find_range( + lower_pk, lower_op, bitset, [](int64_t offset) { return true; }); + + // values <= upper_pk (or < upper_pk if not inclusive) + pk2offset_->find_range( + upper_pk, upper_op, upper_view, [](int64_t offset) { + return true; + }); + + bitset &= upper_result; + } + void insert_pks(milvus::DataType data_type, ChunkedColumnInterface* data) { std::lock_guard lck(shared_mutex_); diff --git a/internal/core/src/segcore/SegmentGrowing.h b/internal/core/src/segcore/SegmentGrowing.h index 06e1813652..7292c53b44 100644 --- a/internal/core/src/segcore/SegmentGrowing.h +++ b/internal/core/src/segcore/SegmentGrowing.h @@ -39,6 +39,17 @@ class SegmentGrowing : public SegmentInternalInterface { return SegmentType::Growing; } + void + pk_binary_range(milvus::OpContext* op_ctx, + const PkType& lower_pk, + bool lower_inclusive, + const PkType& upper_pk, + bool upper_inclusive, + BitsetTypeView& bitset) const override { + ThrowInfo(ErrorCode::Unsupported, + "pk_binary_range is not supported for growing segment"); + } + // virtual int64_t // PreDelete(int64_t size) = 0; diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 1398f9f906..a6980e3944 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -1166,9 +1166,9 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, } } -std::vector -SegmentGrowingImpl::search_ids(const IdArray& id_array, - Timestamp timestamp) const { +void +SegmentGrowingImpl::search_ids(BitsetType& bitset, + const IdArray& id_array) const { auto field_id = schema_->get_primary_field_id().value_or(FieldId(-1)); AssertInfo(field_id.get() != -1, "Primary key is -1"); auto& field_meta = schema_->operator[](field_id); @@ -1177,20 +1177,11 @@ SegmentGrowingImpl::search_ids(const IdArray& id_array, std::vector pks(ids_size); ParsePksFromIDs(pks, data_type, id_array); - std::vector res_offsets; - res_offsets.reserve(pks.size()); + BitsetTypeView bitset_view(bitset); for (auto& pk : pks) { - auto segOffsets = insert_record_.search_pk(pk, timestamp); - for (auto offset : segOffsets) { - res_offsets.push_back(offset); - } + insert_record_.search_pk_range( + pk, proto::plan::OpType::Equal, bitset_view); } - return std::move(res_offsets); -} - -std::string -SegmentGrowingImpl::debug() const { - return "Growing\n"; } int64_t diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index d3a018c671..a39baf950c 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -68,9 +68,6 @@ class SegmentGrowingImpl : public SegmentGrowing { void LoadFieldData(const LoadFieldDataInfo& info) override; - std::string - debug() const override; - 
int64_t get_segment_id() const override { return id_; @@ -375,8 +372,8 @@ class SegmentGrowingImpl : public SegmentGrowing { int64_t ins_barrier, Timestamp timestamp) const override; - std::vector - search_ids(const IdArray& id_array, Timestamp timestamp) const override; + void + search_ids(BitsetType& bitset, const IdArray& id_array) const override; bool HasIndex(FieldId field_id) const { @@ -457,13 +454,6 @@ class SegmentGrowingImpl : public SegmentGrowing { return false; } - std::vector - search_pk(milvus::OpContext* op_ctx, - const PkType& pk, - Timestamp timestamp) const override { - return insert_record_.search_pk(pk, timestamp); - } - void pk_range(milvus::OpContext* op_ctx, proto::plan::OpType op, diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index f103255a70..ced32bedaa 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -337,9 +337,6 @@ class SegmentInternalInterface : public SegmentInterface { virtual bool HasIndex(FieldId field_id) const = 0; - virtual std::string - debug() const = 0; - int64_t get_real_count() const override; @@ -432,12 +429,14 @@ class SegmentInternalInterface : public SegmentInterface { /** * search offset by possible pk values and mvcc timestamp * + * @param bitset The final bitset after id-array filtering; + * a `false` bit means the entity is filtered out. * @param id_array possible pk values - * @param timestamp mvcc timestamp - * @return all the hit entries in vector of offsets + * this interface is used for internal expression evaluation, + * so it takes no timestamp parameter; the MVCC node guarantees the timestamp filter has already been applied. */ - virtual std::vector - search_ids(const IdArray& id_array, Timestamp timestamp) const = 0; + virtual void + search_ids(BitsetType& bitset, const IdArray& id_array) const = 0; /** * Apply timestamp filtering on bitset, the query can't see an entity whose @@ -575,17 +574,20 @@ class SegmentInternalInterface : public SegmentInterface { int64_t count, const std::vector& dynamic_field_names) const = 0; - virtual std::vector - search_pk(milvus::OpContext* op_ctx, - const PkType& pk, - Timestamp timestamp) const = 0; - virtual void pk_range(milvus::OpContext* op_ctx, proto::plan::OpType op, const PkType& pk, BitsetTypeView& bitset) const = 0; + virtual void + pk_binary_range(milvus::OpContext* op_ctx, + const PkType& lower_pk, + bool lower_inclusive, + const PkType& upper_pk, + bool upper_inclusive, + BitsetTypeView& bitset) const = 0; + virtual GEOSContextHandle_t get_ctx() const { return ctx_; diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index abec32b089..b7d5293c1b 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -117,6 +117,33 @@ install(TARGETS all_tests DESTINATION unittest) add_subdirectory(bench) add_subdirectory(test_json_stats) +# Optionally include external scalar-benchmark project +option(ENABLE_SCALAR_BENCH "Enable fetching and building scalar-benchmark" OFF) +set(SCALAR_BENCHMARK_GIT_URL "https://github.com/zilliztech/scalar-benchmark" CACHE STRING "Scalar benchmark git repo URL") +set(SCALAR_BENCHMARK_GIT_TAG "main" CACHE STRING "Scalar benchmark git tag/branch") + +if (ENABLE_SCALAR_BENCH) + include(FetchContent) + if (DEFINED SCALAR_BENCHMARK_SOURCE_DIR AND EXISTS ${SCALAR_BENCHMARK_SOURCE_DIR}/CMakeLists.txt) + message(STATUS "Using local scalar-benchmark from ${SCALAR_BENCHMARK_SOURCE_DIR}") +
add_subdirectory(${SCALAR_BENCHMARK_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/scalar-benchmark-build) + else() + message(STATUS "Fetching scalar-benchmark from ${SCALAR_BENCHMARK_GIT_URL} (${SCALAR_BENCHMARK_GIT_TAG})") + FetchContent_Declare( + scalar_benchmark + GIT_REPOSITORY ${SCALAR_BENCHMARK_GIT_URL} + GIT_TAG ${SCALAR_BENCHMARK_GIT_TAG} + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/scalar-benchmark-src + BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/scalar-benchmark-build + ) + FetchContent_GetProperties(scalar_benchmark) + if (NOT scalar_benchmark_POPULATED) + FetchContent_Populate(scalar_benchmark) + add_subdirectory(${scalar_benchmark_SOURCE_DIR} ${scalar_benchmark_BINARY_DIR}) + endif() + endif() +endif() + # bitset unit test include(CheckCXXCompilerFlag) include(CheckIncludeFileCXX) diff --git a/internal/core/unittest/bench/CMakeLists.txt b/internal/core/unittest/bench/CMakeLists.txt index 68bed1644c..960e62460a 100644 --- a/internal/core/unittest/bench/CMakeLists.txt +++ b/internal/core/unittest/bench/CMakeLists.txt @@ -37,4 +37,4 @@ target_link_libraries(indexbuilder_bench pthread ) -target_link_libraries(indexbuilder_bench benchmark_main) +target_link_libraries(indexbuilder_bench benchmark_main) \ No newline at end of file diff --git a/internal/core/unittest/bench/bench_search_pk.cpp b/internal/core/unittest/bench/bench_search_pk.cpp deleted file mode 100644 index f65244133f..0000000000 --- a/internal/core/unittest/bench/bench_search_pk.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. 
See the License for the specific language governing permissions and limitations under the License - -#include -#include -#include -#include "common/type_c.h" -#include "segcore/segment_c.h" -#include "segcore/SegmentGrowing.h" -#include "segcore/SegmentSealed.h" -#include "test_cachinglayer/cachinglayer_test_utils.h" -#include "test_utils/DataGen.h" -#include "test_utils/storage_test_utils.h" - -using namespace milvus; -using namespace milvus::query; -using namespace milvus::segcore; - -static int dim = 768; \ No newline at end of file diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index c8ac1d475a..f607093b16 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -2280,31 +2280,6 @@ TEST(Sealed, QueryAllNullableFields) { EXPECT_EQ(float_array_result->valid_data_size(), dataset_size); } -TEST(Sealed, SearchSortedPk) { - auto schema = std::make_shared(); - auto varchar_pk_field = schema->AddDebugField("pk", DataType::VARCHAR); - schema->set_primary_field_id(varchar_pk_field); - auto segment_sealed = CreateSealedSegment( - schema, nullptr, 999, SegcoreConfig::default_config(), true); - auto segment = - dynamic_cast(segment_sealed.get()); - - int64_t dataset_size = 1000; - auto dataset = DataGen(schema, dataset_size, 42, 0, 10); - LoadGeneratedDataIntoSegment(dataset, segment); - - auto pk_values = dataset.get_col(varchar_pk_field); - auto offsets = - segment->search_pk(nullptr, PkType(pk_values[100]), Timestamp(99999)); - EXPECT_EQ(10, offsets.size()); - EXPECT_EQ(100, offsets[0].get()); - - auto offsets2 = - segment->search_pk(nullptr, PkType(pk_values[100]), int64_t(105)); - EXPECT_EQ(6, offsets2.size()); - EXPECT_EQ(100, offsets2[0].get()); -} - using VectorArrayTestParam = std::tuple; diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h index b0f5b8ddba..99718c79ad 100644 --- a/internal/core/unittest/test_utils/storage_test_utils.h +++ b/internal/core/unittest/test_utils/storage_test_utils.h @@ -158,8 +158,10 @@ PrepareSingleFieldInsertBinlog(int64_t collection_id, for (auto i = 0; i < field_datas.size(); ++i) { auto& field_data = field_datas[i]; row_count += field_data->Length(); - auto file = - "./data/test" + std::to_string(field_id) + "/" + std::to_string(i); + auto file = "./data/test/" + std::to_string(collection_id) + "/" + + std::to_string(partition_id) + "/" + + std::to_string(segment_id) + "/" + + std::to_string(field_id) + "/" + std::to_string(i); files.push_back(file); row_counts.push_back(field_data->Length()); auto payload_reader =