diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index 886e5bb9b2..9372e4c031 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -205,4 +205,12 @@ Join(const std::vector& items, const std::string& delimiter) { return ss.str(); } +inline std::string /* Returns the longest leading substring shared by str1 and str2 (empty when they differ at index 0 or when either string is empty). */ +GetCommonPrefix(const std::string& str1, const std::string& str2) { + size_t len = std::min(str1.length(), str2.length()); + size_t i = 0; + while (i < len && str1[i] == str2[i]) ++i; + return str1.substr(0, i); +} + } // namespace milvus diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index 9360940e70..67e09cf115 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -389,31 +389,64 @@ const TargetBitmap StringIndexMarisa::Range(std::string value, OpType op) { auto count = Count(); TargetBitmap bitset(count); + std::vector ids; /* ids of trie keys that satisfy the comparison; mapped to bitset offsets through str_ids_to_offsets_ further down */ marisa::Agent agent; - for (size_t offset = 0; offset < count; ++offset) { - agent.set_query(str_ids_[offset]); - trie_.reverse_lookup(agent); - std::string raw_data(agent.key().ptr(), agent.key().length()); - bool set = false; - switch (op) { - case OpType::LessThan: - set = raw_data.compare(value) < 0; - break; - case OpType::LessEqual: - set = raw_data.compare(value) <= 0; - break; - case OpType::GreaterThan: - set = raw_data.compare(value) > 0; - break; - case OpType::GreaterEqual: - set = raw_data.compare(value) >= 0; - break; - default: - throw SegcoreError(OpTypeInvalid, - fmt::format("Invalid OperatorType: {}", - static_cast(op))); + switch (op) { + case OpType::GreaterThan: { /* NOTE(review): the first loop skips keys until one compares greater than value, the second loop accepts every remaining key unchecked. That is only correct if predictive_search enumerates keys in ascending lexicographic order; presumably this requires the trie to be built with MARISA_LABEL_ORDER, so confirm against the build call. Also, agent gets no set_query in this function; presumably the default empty query matches all keys, verify. */ + while (trie_.predictive_search(agent)) { + auto key = std::string(agent.key().ptr(), agent.key().length()); + if (key > value) { + ids.push_back(agent.key().id()); + break; + } + }; + while (trie_.predictive_search(agent)) { + ids.push_back(agent.key().id()); + } + break; } - if (set) { + case OpType::GreaterEqual: { + while 
(trie_.predictive_search(agent)) { + auto key = std::string(agent.key().ptr(), agent.key().length()); + if (key >= value) { + ids.push_back(agent.key().id()); + break; + } + } + while (trie_.predictive_search(agent)) { + ids.push_back(agent.key().id()); + } + break; + } + case OpType::LessThan: { /* collects keys until the first one that is not less than value, then stops; relies on the same ascending enumeration order as the cases above */ + while (trie_.predictive_search(agent)) { + auto key = std::string(agent.key().ptr(), agent.key().length()); + if (key >= value) { + break; + } + ids.push_back(agent.key().id()); + } + break; + } + case OpType::LessEqual: { + while (trie_.predictive_search(agent)) { + auto key = std::string(agent.key().ptr(), agent.key().length()); + if (key > value) { + break; + } + ids.push_back(agent.key().id()); + } + break; + } + default: + throw SegcoreError( + OpTypeInvalid, + fmt::format("Invalid OperatorType: {}", static_cast(op))); + } + + for (const auto str_id : ids) { + auto offsets = str_ids_to_offsets_[str_id]; + for (auto offset : offsets) { bitset[offset] = true; } } @@ -432,26 +465,38 @@ StringIndexMarisa::Range(std::string lower_bound_value, !(lb_inclusive && ub_inclusive))) { return bitset; } + + auto common_prefix = GetCommonPrefix(lower_bound_value, upper_bound_value); /* any string lying between the two bounds must start with their common prefix; presumably predictive_search then visits only keys with that prefix */ marisa::Agent agent; - for (size_t offset = 0; offset < count; ++offset) { - agent.set_query(str_ids_[offset]); - trie_.reverse_lookup(agent); - std::string raw_data(agent.key().ptr(), agent.key().length()); - bool set = true; - if (lb_inclusive) { - set &= raw_data.compare(lower_bound_value) >= 0; - } else { - set &= raw_data.compare(lower_bound_value) > 0; + agent.set_query(common_prefix.c_str()); + std::vector ids; + while (trie_.predictive_search(agent)) { + std::string_view val = + std::string_view(agent.key().ptr(), agent.key().length()); + if (val > upper_bound_value || + (!ub_inclusive && val == upper_bound_value)) { /* NOTE(review): stopping at the first key past the upper bound assumes ascending enumeration order; confirm the trie is built in label order */ + break; } - if (ub_inclusive) { - set &= raw_data.compare(upper_bound_value) <= 0; - } else { - set &= raw_data.compare(upper_bound_value) < 0; + + if (val < lower_bound_value || 
+ (!lb_inclusive && val == lower_bound_value)) { + continue; } - if (set) { + + if (((lb_inclusive && lower_bound_value <= val) || + (!lb_inclusive && lower_bound_value < val)) && + ((ub_inclusive && val <= upper_bound_value) || + (!ub_inclusive && val < upper_bound_value))) { /* always true at this point: the break and continue above already rejected every key outside the bounds */ + ids.push_back(agent.key().id()); + } + } + for (const auto str_id : ids) { + auto offsets = str_ids_to_offsets_[str_id]; + for (auto offset : offsets) { bitset[offset] = true; } } + return bitset; } diff --git a/internal/core/unittest/test_utils.cpp b/internal/core/unittest/test_utils.cpp index a92facde32..b0859c8d5b 100644 --- a/internal/core/unittest/test_utils.cpp +++ b/internal/core/unittest/test_utils.cpp @@ -190,3 +190,20 @@ TEST(Util, read_from_fd) { tmp_file.fd, read_buf.get(), data_size * max_loop, INT_MAX), milvus::SegcoreError); } + +TEST(Util, get_common_prefix) { /* cases: empty first string, first string is a prefix of the second, empty second string */ + std::string str1 = ""; + std::string str2 = "milvus"; + auto common_prefix = milvus::GetCommonPrefix(str1, str2); + EXPECT_STREQ(common_prefix.c_str(), ""); + + str1 = "milvus"; + str2 = "milvus is great"; + common_prefix = milvus::GetCommonPrefix(str1, str2); + EXPECT_STREQ(common_prefix.c_str(), "milvus"); + + str1 = "milvus"; + str2 = ""; + common_prefix = milvus::GetCommonPrefix(str1, str2); + EXPECT_STREQ(common_prefix.c_str(), ""); +}