diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp
index 173fa94a24..f4809fbc29 100644
--- a/internal/core/src/index/IndexFactory.cpp
+++ b/internal/core/src/index/IndexFactory.cpp
@@ -33,6 +33,7 @@
 #include "index/VectorDiskIndex.h"
 #include "index/ScalarIndexSort.h"
+#include "index/StringIndexSort.h"
 #include "index/StringIndexMarisa.h"
 #include "index/BoolIndex.h"
 #include "index/InvertedIndexTantivy.h"
@@ -90,7 +91,7 @@ IndexFactory::CreatePrimitiveScalarIndex(
             return std::make_unique<InvertedIndexTantivy<std::string>>(
                 create_index_info.tantivy_index_version, file_manager_context);
         }
-        return CreateStringIndexMarisa(file_manager_context);
+        return CreateStringIndexSort(file_manager_context);
 #else
         ThrowInfo(Unsupported, "unsupported platform");
 #endif
diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp
index 67d103ec6b..594f9d48ad 100644
--- a/internal/core/src/index/StringIndexMarisa.cpp
+++ b/internal/core/src/index/StringIndexMarisa.cpp
@@ -54,7 +54,28 @@ StringIndexMarisa::StringIndexMarisa(
 
 int64_t
 StringIndexMarisa::Size() {
-    return trie_.size();
+    return total_size_;
+}
+
+int64_t
+StringIndexMarisa::CalculateTotalSize() const {
+    int64_t size = 0;
+
+    // Size of the trie structure
+    // marisa trie uses io_size() to get the serialized size
+    // which approximates the memory usage
+    size += trie_.io_size();
+
+    // Size of str_ids_ vector (main data structure)
+    size += str_ids_.size() * sizeof(int64_t);
+
+    // Size of str_ids_to_offsets_ map data
+    for (const auto& [key, vec] : str_ids_to_offsets_) {
+        size += sizeof(size_t);               // key
+        size += vec.size() * sizeof(size_t);  // vector data
+    }
+
+    return size;
 }
 
 bool
@@ -113,6 +134,7 @@ StringIndexMarisa::BuildWithFieldData(
     fill_offsets();
 
     built_ = true;
+    total_size_ = CalculateTotalSize();
 }
 
 void
@@ -138,6 +160,7 @@ StringIndexMarisa::Build(size_t n,
     fill_offsets();
 
     built_ = true;
+    total_size_ = CalculateTotalSize();
 }
 
 BinarySet
@@ -222,6 +245,8 @@ StringIndexMarisa::LoadWithoutAssemble(const BinarySet& set,
     memcpy(str_ids_.data(), str_ids->data.get(), str_ids_len);
 
     fill_offsets();
+    built_ = true;
+    total_size_ = CalculateTotalSize();
 }
 
 void
diff --git a/internal/core/src/index/StringIndexMarisa.h b/internal/core/src/index/StringIndexMarisa.h
index 840e6cf0b2..18bf0545b5 100644
--- a/internal/core/src/index/StringIndexMarisa.h
+++ b/internal/core/src/index/StringIndexMarisa.h
@@ -127,6 +127,9 @@ class StringIndexMarisa : public StringIndex {
     LoadWithoutAssemble(const BinarySet& binary_set,
                         const Config& config) override;
 
+    int64_t
+    CalculateTotalSize() const;
+
  private:
     Config config_;
     marisa::Trie trie_;
@@ -134,6 +137,7 @@ class StringIndexMarisa : public StringIndex {
     std::map<size_t, std::vector<size_t>> str_ids_to_offsets_;
     bool built_ = false;
     std::shared_ptr<storage::MemFileManagerImpl> file_manager_;
+    int64_t total_size_ = 0;  // Cached total size to avoid runtime calculation
 };
 
 using StringIndexMarisaPtr = std::unique_ptr<StringIndexMarisa>;
diff --git a/internal/core/src/index/StringIndexSort.cpp b/internal/core/src/index/StringIndexSort.cpp
new file mode 100644
index 0000000000..dcf149157d
--- /dev/null
+++ b/internal/core/src/index/StringIndexSort.cpp
@@ -0,0 +1,1130 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "index/StringIndexSort.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "storage/FileWriter.h" +#include "common/CDataType.h" +#include "knowhere/log.h" +#include "index/Meta.h" +#include "common/Utils.h" +#include "common/Slice.h" +#include "common/Types.h" +#include "index/Utils.h" +#include "storage/ThreadPools.h" +#include "storage/Util.h" + +namespace milvus::index { + +StringIndexSortImpl::ParsedData +StringIndexSortImpl::ParseBinaryData(const uint8_t* data, size_t data_size) { + ParsedData result; + const uint8_t* ptr = data; + + // Verify magic code at the end first + uint64_t magic_at_end; + memcpy( + &magic_at_end, data + data_size - sizeof(uint64_t), sizeof(uint64_t)); + if (magic_at_end != StringIndexSort::MAGIC_CODE) { + ThrowInfo(DataFormatBroken, + fmt::format("Invalid magic code: expected 0x{:X}, got 0x{:X}", + StringIndexSort::MAGIC_CODE, + magic_at_end)); + } + + // Read unique count + memcpy(&result.unique_count, ptr, sizeof(uint32_t)); + ptr += sizeof(uint32_t); + + if (result.unique_count == 0) { + ThrowInfo(DataFormatBroken, "Unique count is 0"); + } + + // Read string offsets + result.string_offsets = reinterpret_cast(ptr); + ptr += result.unique_count * sizeof(uint32_t); + + result.string_data_start = ptr; + + // Calculate total string section size + auto total_string_size = 0; + const uint8_t* last_str_ptr = + result.string_data_start + + result.string_offsets[result.unique_count - 1]; + uint32_t last_str_len; + memcpy(&last_str_len, last_str_ptr, sizeof(uint32_t)); + total_string_size = result.string_offsets[result.unique_count - 1] + + sizeof(uint32_t) + last_str_len; + + // Skip past string section to posting list offsets + ptr = result.string_data_start + total_string_size; + result.post_list_offsets = reinterpret_cast(ptr); + ptr += result.unique_count * sizeof(uint32_t); + + result.post_list_data_start = ptr; + + return result; +} + +const std::string STRING_INDEX_SORT_FILE = "string_index_sort"; + +StringIndexSort::StringIndexSort( + const storage::FileManagerContext& file_manager_context) + : StringIndex(ASCENDING_SORT), is_built_(false) { + if (file_manager_context.Valid()) { + field_id_ = file_manager_context.fieldDataMeta.field_id; + file_manager_ = + std::make_shared(file_manager_context); + } +} + +StringIndexSort::~StringIndexSort() = default; + +int64_t +StringIndexSort::Count() { + return total_num_rows_; +} + +void +StringIndexSort::Build(size_t n, + const std::string* values, + const bool* valid_data) { + if (is_built_) + return; + if (n == 0) { + ThrowInfo(DataIsEmpty, "StringIndexSort cannot build null values!"); + } + + index_build_begin_ = std::chrono::system_clock::now(); + total_num_rows_ = n; + valid_bitset_ = TargetBitmap(total_num_rows_, false); + idx_to_offsets_.clear(); + + // Create MemoryImpl and delegate building to it + impl_ = std::make_unique(); + auto* memory_impl = static_cast(impl_.get()); + + // 
Let MemoryImpl handle the building process + memory_impl->BuildFromRawData( + n, values, valid_data, valid_bitset_, idx_to_offsets_); + + is_built_ = true; + total_size_ = CalculateTotalSize(); +} + +void +StringIndexSort::Build(const Config& config) { + if (is_built_) { + return; + } + config_ = config; + auto field_datas = + storage::CacheRawDataAndFillMissing(file_manager_, config); + BuildWithFieldData(field_datas); +} + +void +StringIndexSort::BuildWithFieldData( + const std::vector& field_datas) { + if (is_built_) + return; + + index_build_begin_ = std::chrono::system_clock::now(); + + // Calculate total number of rows + total_num_rows_ = 0; + for (const auto& data : field_datas) { + total_num_rows_ += data->get_num_rows(); + } + + if (total_num_rows_ == 0) { + ThrowInfo(DataIsEmpty, "StringIndexSort cannot build null values!"); + } + + // Initialize structures + valid_bitset_ = TargetBitmap(total_num_rows_, false); + idx_to_offsets_.clear(); + + // Create MemoryImpl and build directly from field data + impl_ = std::make_unique(); + static_cast(impl_.get()) + ->BuildFromFieldData( + field_datas, total_num_rows_, valid_bitset_, idx_to_offsets_); + + is_built_ = true; + total_size_ = CalculateTotalSize(); +} + +BinarySet +StringIndexSort::Serialize(const Config& config) { + AssertInfo(is_built_, "index has not been built"); + AssertInfo(impl_ != nullptr, "impl_ is null, cannot serialize"); + + BinarySet res_set; + + std::shared_ptr version_buf(new uint8_t[sizeof(uint32_t)]); + uint32_t version = SERIALIZATION_VERSION; + memcpy(version_buf.get(), &version, sizeof(uint32_t)); + res_set.Append("version", version_buf, sizeof(uint32_t)); + + // Use MemoryImpl to serialize + auto* memory_impl = static_cast(impl_.get()); + size_t total_size = memory_impl->GetSerializedSize(); + + std::shared_ptr data_buffer(new uint8_t[total_size]); + size_t offset = 0; + memory_impl->SerializeToBinary(data_buffer.get(), offset); + + res_set.Append("index_data", data_buffer, total_size); + + // Serialize total number of rows (use same key as ScalarIndexSort) + std::shared_ptr index_num_rows(new uint8_t[sizeof(size_t)]); + memcpy(index_num_rows.get(), &total_num_rows_, sizeof(size_t)); + res_set.Append("index_num_rows", index_num_rows, sizeof(size_t)); + + // Serialize valid_bitset + size_t valid_bitset_size = + (total_num_rows_ + 7) / 8; // Round up to byte boundary + std::shared_ptr valid_bitset_data( + new uint8_t[valid_bitset_size]); + memset(valid_bitset_data.get(), 0, valid_bitset_size); + for (size_t i = 0; i < total_num_rows_; ++i) { + if (valid_bitset_[i]) { + valid_bitset_data[i / 8] |= (1 << (i % 8)); + } + } + res_set.Append("valid_bitset", valid_bitset_data, valid_bitset_size); + + milvus::Disassemble(res_set); + return res_set; +} + +IndexStatsPtr +StringIndexSort::Upload(const Config& config) { + auto index_build_duration = + std::chrono::duration_cast( + std::chrono::system_clock::now() - index_build_begin_) + .count(); + LOG_INFO( + "index build done for StringIndexSort, field_id: {}, duration: {}ms", + field_id_, + index_build_duration); + + auto binary_set = Serialize(config); + file_manager_->AddFile(binary_set); + + auto remote_paths_to_size = file_manager_->GetRemotePathsToFileSize(); + return IndexStats::NewFromSizeMap(file_manager_->GetAddedTotalMemSize(), + remote_paths_to_size); +} + +void +StringIndexSort::Load(const BinarySet& index_binary, const Config& config) { + milvus::Assemble(const_cast(index_binary)); + LoadWithoutAssemble(index_binary, config); +} + +void 
+StringIndexSort::Load(milvus::tracer::TraceContext ctx, const Config& config) { + auto index_files = + GetValueFromConfig>(config, "index_files"); + AssertInfo(index_files.has_value() && !index_files.value().empty(), + "index_files not found in config"); + + auto load_priority = + GetValueFromConfig( + config, milvus::LOAD_PRIORITY) + .value_or(milvus::proto::common::LoadPriority::HIGH); + + auto index_datas = + file_manager_->LoadIndexToMemory(index_files.value(), load_priority); + + BinarySet binary_set; + AssembleIndexDatas(index_datas, binary_set); + + index_datas.clear(); + + LoadWithoutAssemble(binary_set, config); +} + +void +StringIndexSort::LoadWithoutAssemble(const BinarySet& binary_set, + const Config& config) { + config_ = config; + + auto index_num_rows = binary_set.GetByName("index_num_rows"); + AssertInfo(index_num_rows != nullptr, + "Failed to find 'index_num_rows' in binary_set"); + memcpy(&total_num_rows_, index_num_rows->data.get(), sizeof(size_t)); + + // Initialize idx_to_offsets - it will be rebuilt by LoadFromBinary + idx_to_offsets_.resize(total_num_rows_); + + auto valid_bitset_data = binary_set.GetByName("valid_bitset"); + AssertInfo(valid_bitset_data != nullptr, + "Failed to find 'valid_bitset' in binary_set"); + valid_bitset_ = TargetBitmap(total_num_rows_, false); + for (size_t i = 0; i < total_num_rows_; ++i) { + uint8_t byte = valid_bitset_data->data[i / 8]; + if (byte & (1 << (i % 8))) { + valid_bitset_.set(i); + } + } + + auto version_data = binary_set.GetByName("version"); + AssertInfo(version_data != nullptr, + "Failed to find 'version' in binary_set"); + + uint32_t version; + memcpy(&version, version_data->data.get(), sizeof(uint32_t)); + + if (version != SERIALIZATION_VERSION) { + ThrowInfo(milvus::ErrorCode::Unsupported, + fmt::format("Unsupported StringIndexSort serialization " + "version: {}, expected: {}", + version, + SERIALIZATION_VERSION)); + } + + // Check if mmap is enabled + if (config.contains(MMAP_FILE_PATH)) { + LOG_INFO("StringIndexSort: loading with mmap strategy"); + auto mmap_impl = std::make_unique(); + + auto mmap_path = + GetValueFromConfig(config, MMAP_FILE_PATH).value(); + mmap_impl->SetMmapFilePath(mmap_path); + mmap_impl->LoadFromBinary( + binary_set, total_num_rows_, valid_bitset_, idx_to_offsets_); + impl_ = std::move(mmap_impl); + } else { + LOG_INFO("StringIndexSort: loading with memory strategy"); + impl_ = std::make_unique(); + impl_->LoadFromBinary( + binary_set, total_num_rows_, valid_bitset_, idx_to_offsets_); + } + + is_built_ = true; + total_size_ = CalculateTotalSize(); +} + +const TargetBitmap +StringIndexSort::In(size_t n, const std::string* values) { + assert(impl_ != nullptr); + return impl_->In(n, values, total_num_rows_); +} + +const TargetBitmap +StringIndexSort::NotIn(size_t n, const std::string* values) { + assert(impl_ != nullptr); + return impl_->NotIn(n, values, total_num_rows_, valid_bitset_); +} + +const TargetBitmap +StringIndexSort::IsNull() { + assert(impl_ != nullptr); + return impl_->IsNull(total_num_rows_, valid_bitset_); +} + +TargetBitmap +StringIndexSort::IsNotNull() { + assert(impl_ != nullptr); + return impl_->IsNotNull(valid_bitset_); +} + +const TargetBitmap +StringIndexSort::Range(std::string value, OpType op) { + assert(impl_ != nullptr); + return impl_->Range(value, op, total_num_rows_); +} + +const TargetBitmap +StringIndexSort::Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, + bool ub_inclusive) { + assert(impl_ != nullptr); + return 
impl_->Range(lower_bound_value, + lb_inclusive, + upper_bound_value, + ub_inclusive, + total_num_rows_); +} + +const TargetBitmap +StringIndexSort::PrefixMatch(const std::string_view prefix) { + assert(impl_ != nullptr); + return impl_->PrefixMatch(prefix, total_num_rows_); +} + +std::optional +StringIndexSort::Reverse_Lookup(size_t offset) const { + assert(impl_ != nullptr); + return impl_->Reverse_Lookup( + offset, total_num_rows_, valid_bitset_, idx_to_offsets_); +} + +int64_t +StringIndexSort::Size() { + return total_size_; +} + +int64_t +StringIndexSort::CalculateTotalSize() const { + int64_t size = 0; + + size += impl_->Size(); + // Add common structures (always present) + size += idx_to_offsets_.size() * sizeof(int32_t); + size += valid_bitset_.size() / 8; + + // Add object overhead + size += sizeof(*this); + + return size; +} + +void +StringIndexSortMemoryImpl::BuildFromMap( + std::map&& map, + size_t total_num_rows, + std::vector& idx_to_offsets) { + unique_values_.clear(); + posting_lists_.clear(); + unique_values_.reserve(map.size()); + posting_lists_.reserve(map.size()); + + // Initialize idx_to_offsets + idx_to_offsets.resize(total_num_rows); + std::fill(idx_to_offsets.begin(), idx_to_offsets.end(), -1); + + // Convert map to vectors (map is already sorted) + size_t unique_idx = 0; + for (auto& [value, posting_list] : map) { + // Map each row_id to its unique value index + for (uint32_t row_id : posting_list) { + idx_to_offsets[row_id] = unique_idx; + } + unique_values_.push_back(std::move(value)); + posting_lists_.push_back(std::move(posting_list)); + unique_idx++; + } +} + +void +StringIndexSortMemoryImpl::BuildFromRawData( + size_t n, + const std::string* values, + const bool* valid_data, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) { + // Use map to collect unique values and their posting lists + std::map map; + + for (size_t i = 0; i < n; ++i) { + if (!valid_data || valid_data[i]) { + map[values[i]].push_back(static_cast(i)); + valid_bitset.set(i); + } + } + + BuildFromMap(std::move(map), n, idx_to_offsets); +} + +void +StringIndexSortMemoryImpl::BuildFromFieldData( + const std::vector& field_datas, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) { + // Use map to collect unique values and their posting lists + // std::map is sorted + std::map map; + + size_t row_id = 0; + for (const auto& field_data : field_datas) { + auto slice_num = field_data->get_num_rows(); + for (size_t i = 0; i < slice_num; ++i) { + if (field_data->is_valid(i)) { + auto value = reinterpret_cast( + field_data->RawValue(i)); + map[*value].push_back(static_cast(row_id)); + valid_bitset.set(row_id); + } + row_id++; + } + } + + BuildFromMap(std::move(map), total_num_rows, idx_to_offsets); +} + +size_t +StringIndexSortMemoryImpl::GetSerializedSize() const { + size_t total_size = sizeof(uint32_t); // unique_count + + // String offsets array + total_size += unique_values_.size() * sizeof(uint32_t); + + // String data section + for (size_t i = 0; i < unique_values_.size(); ++i) { + total_size += sizeof(uint32_t); // str_len + total_size += unique_values_[i].size(); // string content + } + + // Posting list offsets array + total_size += posting_lists_.size() * sizeof(uint32_t); + + // Posting list data section + for (size_t i = 0; i < posting_lists_.size(); ++i) { + total_size += sizeof(uint32_t); // post_list_len + total_size += posting_lists_[i].size() * sizeof(uint32_t); // row_ids + } + + // Magic code at end + total_size += sizeof(uint64_t); + + 
return total_size; +} + +void +StringIndexSortMemoryImpl::SerializeToBinary(uint8_t* ptr, + size_t& offset) const { + size_t start_offset = offset; + + // Write unique count as uint32_t + uint32_t unique_count = static_cast(unique_values_.size()); + memcpy(ptr + offset, &unique_count, sizeof(uint32_t)); + offset += sizeof(uint32_t); + + // Calculate and write string offsets + size_t string_offsets_start = offset; + offset += unique_count * sizeof(uint32_t); // Reserve space for offsets + + size_t string_data_start = offset; + std::vector string_offsets; + string_offsets.reserve(unique_count); + + // Write string data section + for (size_t i = 0; i < unique_values_.size(); ++i) { + string_offsets.push_back( + static_cast(offset - string_data_start)); + + // Write string length and content + uint32_t str_len = static_cast(unique_values_[i].size()); + memcpy(ptr + offset, &str_len, sizeof(uint32_t)); + offset += sizeof(uint32_t); + memcpy(ptr + offset, unique_values_[i].data(), str_len); + offset += str_len; + } + + // Write string offsets back + memcpy(ptr + string_offsets_start, + string_offsets.data(), + string_offsets.size() * sizeof(uint32_t)); + + // Calculate and write posting list offsets + size_t post_list_offsets_start = offset; + offset += unique_count * sizeof(uint32_t); // Reserve space for offsets + + size_t post_list_data_start = offset; + std::vector post_list_offsets; + post_list_offsets.reserve(unique_count); + + // Write posting list data section + for (size_t i = 0; i < posting_lists_.size(); ++i) { + post_list_offsets.push_back( + static_cast(offset - post_list_data_start)); + + // Write posting list length and content + uint32_t post_list_len = + static_cast(posting_lists_[i].size()); + memcpy(ptr + offset, &post_list_len, sizeof(uint32_t)); + offset += sizeof(uint32_t); + + for (uint32_t row_id : posting_lists_[i]) { + memcpy(ptr + offset, &row_id, sizeof(uint32_t)); + offset += sizeof(uint32_t); + } + } + + // Write posting list offsets back + memcpy(ptr + post_list_offsets_start, + post_list_offsets.data(), + post_list_offsets.size() * sizeof(uint32_t)); + + // Write magic code at the very end + uint64_t magic = StringIndexSort::MAGIC_CODE; + memcpy(ptr + offset, &magic, sizeof(uint64_t)); + offset += sizeof(uint64_t); +} + +void +StringIndexSortMemoryImpl::LoadFromBinary( + const BinarySet& binary_set, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) { + auto index_data = binary_set.GetByName("index_data"); + AssertInfo(index_data != nullptr, + "Failed to find 'index_data' in binary_set"); + + auto parsed = ParseBinaryData(index_data->data.get(), index_data->size); + unique_values_.clear(); + posting_lists_.clear(); + unique_values_.reserve(parsed.unique_count); + posting_lists_.reserve(parsed.unique_count); + + std::fill(idx_to_offsets.begin(), idx_to_offsets.end(), -1); + + // Read strings and posting lists + for (uint32_t unique_idx = 0; unique_idx < parsed.unique_count; + ++unique_idx) { + // Read string + const uint8_t* str_ptr = + parsed.string_data_start + parsed.string_offsets[unique_idx]; + uint32_t str_len; + memcpy(&str_len, str_ptr, sizeof(uint32_t)); + str_ptr += sizeof(uint32_t); + std::string value(reinterpret_cast(str_ptr), str_len); + + // Read posting list + const uint8_t* post_list_ptr = + parsed.post_list_data_start + parsed.post_list_offsets[unique_idx]; + uint32_t post_list_len; + memcpy(&post_list_len, post_list_ptr, sizeof(uint32_t)); + post_list_ptr += sizeof(uint32_t); + + PostingList posting_list; + 
posting_list.reserve(post_list_len); + const uint32_t* row_ids = + reinterpret_cast(post_list_ptr); + + for (uint32_t j = 0; j < post_list_len; ++j) { + uint32_t row_id = row_ids[j]; + posting_list.push_back(row_id); + + // Map each row_id to its unique value index + if (static_cast(row_id) >= idx_to_offsets.size()) { + ThrowInfo( + milvus::ErrorCode::UnexpectedError, + fmt::format("row_id {} exceeds idx_to_offsets size {}", + row_id, + idx_to_offsets.size())); + } + idx_to_offsets[row_id] = unique_idx; + } + + unique_values_.push_back(std::move(value)); + posting_lists_.push_back(std::move(posting_list)); + } +} + +size_t +StringIndexSortMemoryImpl::FindValueIndex(const std::string& value) const { + auto it = + std::lower_bound(unique_values_.begin(), unique_values_.end(), value); + if (it != unique_values_.end() && *it == value) { + return std::distance(unique_values_.begin(), it); + } + return std::numeric_limits::max(); +} + +const TargetBitmap +StringIndexSortMemoryImpl::In(size_t n, + const std::string* values, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + for (size_t i = 0; i < n; ++i) { + size_t idx = FindValueIndex(values[i]); + if (idx != std::numeric_limits::max()) { + const auto& posting_list = posting_lists_[idx]; + for (uint32_t row_id : posting_list) { + bitset[row_id] = true; + } + } + } + + return bitset; +} + +const TargetBitmap +StringIndexSortMemoryImpl::NotIn(size_t n, + const std::string* values, + size_t total_num_rows, + const TargetBitmap& valid_bitset) { + auto in_bitset = In(n, values, total_num_rows); + in_bitset.flip(); + + // Reset null values + for (size_t i = 0; i < total_num_rows; ++i) { + if (!valid_bitset[i]) { + in_bitset.reset(i); + } + } + + return in_bitset; +} + +const TargetBitmap +StringIndexSortMemoryImpl::IsNull(size_t total_num_rows, + const TargetBitmap& valid_bitset) { + auto result = valid_bitset.clone(); + result.flip(); + return result; +} + +TargetBitmap +StringIndexSortMemoryImpl::IsNotNull(const TargetBitmap& valid_bitset) { + return valid_bitset.clone(); +} + +const TargetBitmap +StringIndexSortMemoryImpl::Range(std::string value, + OpType op, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + size_t start_idx = 0; + size_t end_idx = unique_values_.size(); + + switch (op) { + case OpType::GreaterThan: { + auto it = std::upper_bound( + unique_values_.begin(), unique_values_.end(), value); + start_idx = std::distance(unique_values_.begin(), it); + break; + } + case OpType::GreaterEqual: { + auto it = std::lower_bound( + unique_values_.begin(), unique_values_.end(), value); + start_idx = std::distance(unique_values_.begin(), it); + break; + } + case OpType::LessThan: { + auto it = std::lower_bound( + unique_values_.begin(), unique_values_.end(), value); + end_idx = std::distance(unique_values_.begin(), it); + break; + } + case OpType::LessEqual: { + auto it = std::upper_bound( + unique_values_.begin(), unique_values_.end(), value); + end_idx = std::distance(unique_values_.begin(), it); + break; + } + default: + ThrowInfo( + milvus::OpTypeInvalid, + fmt::format("Invalid OperatorType: {}", static_cast(op))); + } + + // Set bits for all posting lists in range + for (size_t i = start_idx; i < end_idx; ++i) { + const auto& posting_list = posting_lists_[i]; + for (uint32_t row_id : posting_list) { + bitset[row_id] = true; + } + } + + return bitset; +} + +const TargetBitmap +StringIndexSortMemoryImpl::Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, 
+ bool ub_inclusive, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + auto start_it = lb_inclusive ? std::lower_bound(unique_values_.begin(), + unique_values_.end(), + lower_bound_value) + : std::upper_bound(unique_values_.begin(), + unique_values_.end(), + lower_bound_value); + + auto end_it = ub_inclusive ? std::upper_bound(unique_values_.begin(), + unique_values_.end(), + upper_bound_value) + : std::lower_bound(unique_values_.begin(), + unique_values_.end(), + upper_bound_value); + + size_t start_idx = std::distance(unique_values_.begin(), start_it); + size_t end_idx = std::distance(unique_values_.begin(), end_it); + + for (size_t i = start_idx; i < end_idx; ++i) { + const auto& posting_list = posting_lists_[i]; + for (uint32_t row_id : posting_list) { + bitset[row_id] = true; + } + } + + return bitset; +} + +const TargetBitmap +StringIndexSortMemoryImpl::PrefixMatch(const std::string_view prefix, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + auto it = std::lower_bound( + unique_values_.begin(), unique_values_.end(), std::string(prefix)); + + size_t idx = std::distance(unique_values_.begin(), it); + + while (idx < unique_values_.size()) { + if (!milvus::PrefixMatch(unique_values_[idx], prefix)) { + break; + } + const auto& posting_list = posting_lists_[idx]; + for (uint32_t row_id : posting_list) { + bitset[row_id] = true; + } + ++idx; + } + + return bitset; +} + +std::optional +StringIndexSortMemoryImpl::Reverse_Lookup( + size_t offset, + size_t total_num_rows, + const TargetBitmap& valid_bitset, + const std::vector& idx_to_offsets) const { + if (offset >= total_num_rows || !valid_bitset[offset]) { + return std::nullopt; + } + + if (offset < idx_to_offsets.size()) { + size_t unique_idx = idx_to_offsets[offset]; + if (unique_idx < unique_values_.size()) { + return unique_values_[unique_idx]; + } + } + + return std::nullopt; +} + +int64_t +StringIndexSortMemoryImpl::Size() { + size_t size = 0; + + // Size of unique values (actual string data) + for (const auto& str : unique_values_) { + size += str.size(); + } + + // Size of posting lists (actual ID data) + for (const auto& list : posting_lists_) { + size += list.size() * sizeof(uint32_t); + } + + return size; +} + +StringIndexSortMmapImpl::~StringIndexSortMmapImpl() { + if (mmap_data_ != nullptr && mmap_data_ != MAP_FAILED) { + munmap(mmap_data_, mmap_size_); + if (!mmap_filepath_.empty()) { + unlink(mmap_filepath_.c_str()); + } + } +} + +void +StringIndexSortMmapImpl::LoadFromBinary(const BinarySet& binary_set, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) { + auto index_data = binary_set.GetByName("index_data"); + + AssertInfo(!mmap_filepath_.empty(), "mmap filepath is not set"); + + std::filesystem::create_directories( + std::filesystem::path(mmap_filepath_).parent_path()); + + { + auto file_writer = storage::FileWriter(mmap_filepath_); + file_writer.Write(index_data->data.get(), index_data->size); + file_writer.Finish(); + } + + auto fd = open(mmap_filepath_.c_str(), O_RDONLY); + if (fd == -1) { + ThrowInfo(DataFormatBroken, "Failed to open mmap file"); + } + + mmap_size_ = index_data->size; + mmap_data_ = static_cast( + mmap(nullptr, mmap_size_, PROT_READ, MAP_PRIVATE, fd, 0)); + close(fd); + + if (mmap_data_ == MAP_FAILED) { + ThrowInfo(DataFormatBroken, "Failed to mmap file"); + } + + const uint8_t* data_start = reinterpret_cast(mmap_data_); + + auto parsed = ParseBinaryData(data_start, mmap_size_); + unique_count_ = 
parsed.unique_count; + string_offsets_ = parsed.string_offsets; + string_data_start_ = parsed.string_data_start; + post_list_offsets_ = parsed.post_list_offsets; + post_list_data_start_ = parsed.post_list_data_start; + + // Initialize idx_to_offsets + std::fill(idx_to_offsets.begin(), idx_to_offsets.end(), -1); + // Rebuild idx_to_offsets by iterating through entries + for (uint32_t unique_idx = 0; unique_idx < unique_count_; ++unique_idx) { + MmapEntry entry = GetEntry(unique_idx); + + // Map each row_id in posting list to this unique index + entry.for_each_row_id([&idx_to_offsets, unique_idx](uint32_t row_id) { + idx_to_offsets[row_id] = unique_idx; + }); + } +} + +size_t +StringIndexSortMmapImpl::FindValueIndex(const std::string& value) const { + std::string_view search_value(value); + size_t left = 0; + size_t right = unique_count_; + + while (left < right) { + size_t mid = left + (right - left) / 2; + MmapEntry entry = GetEntry(mid); + std::string_view entry_sv = entry.get_string_view(); + + int cmp = entry_sv.compare(search_value); + if (cmp < 0) { + left = mid + 1; + } else if (cmp > 0) { + right = mid; + } else { + return mid; + } + } + + return unique_count_; +} + +size_t +StringIndexSortMmapImpl::LowerBound(const std::string_view& value) const { + size_t left = 0, right = unique_count_; + while (left < right) { + size_t mid = left + (right - left) / 2; + if (GetEntry(mid).get_string_view() < value) { + left = mid + 1; + } else { + right = mid; + } + } + return left; +} + +size_t +StringIndexSortMmapImpl::UpperBound(const std::string_view& value) const { + size_t left = 0, right = unique_count_; + while (left < right) { + size_t mid = left + (right - left) / 2; + if (GetEntry(mid).get_string_view() <= value) { + left = mid + 1; + } else { + right = mid; + } + } + return left; +} + +const TargetBitmap +StringIndexSortMmapImpl::In(size_t n, + const std::string* values, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + for (size_t i = 0; i < n; ++i) { + size_t idx = FindValueIndex(values[i]); + if (idx < unique_count_) { + MmapEntry entry = GetEntry(idx); + // Set bits for all row_ids in posting list + entry.for_each_row_id( + [&bitset](uint32_t row_id) { bitset.set(row_id); }); + } + } + + return bitset; +} + +const TargetBitmap +StringIndexSortMmapImpl::NotIn(size_t n, + const std::string* values, + size_t total_num_rows, + const TargetBitmap& valid_bitset) { + auto in_bitset = In(n, values, total_num_rows); + in_bitset.flip(); + + for (size_t i = 0; i < total_num_rows; ++i) { + if (!valid_bitset[i]) { + in_bitset.reset(i); + } + } + + return in_bitset; +} + +const TargetBitmap +StringIndexSortMmapImpl::IsNull(size_t total_num_rows, + const TargetBitmap& valid_bitset) { + auto null_bitset = valid_bitset.clone(); + null_bitset.flip(); + return null_bitset; +} + +TargetBitmap +StringIndexSortMmapImpl::IsNotNull(const TargetBitmap& valid_bitset) { + return valid_bitset.clone(); +} + +const TargetBitmap +StringIndexSortMmapImpl::Range(std::string value, + OpType op, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + size_t start_idx = 0; + size_t end_idx = unique_count_; + + switch (op) { + case OpType::GreaterThan: + start_idx = UpperBound(value); + break; + case OpType::GreaterEqual: + start_idx = LowerBound(value); + break; + case OpType::LessThan: + end_idx = LowerBound(value); + break; + case OpType::LessEqual: + end_idx = UpperBound(value); + break; + default: + ThrowInfo( + OpTypeInvalid, + fmt::format("Invalid OperatorType: 
{}", static_cast(op))); + } + + // Set bits for all posting lists in range + for (size_t i = start_idx; i < end_idx; ++i) { + MmapEntry entry = GetEntry(i); + entry.for_each_row_id( + [&bitset](uint32_t row_id) { bitset.set(row_id); }); + } + + return bitset; +} + +const TargetBitmap +StringIndexSortMmapImpl::Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, + bool ub_inclusive, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + size_t start_idx = lb_inclusive ? LowerBound(lower_bound_value) + : UpperBound(lower_bound_value); + size_t end_idx = ub_inclusive ? UpperBound(upper_bound_value) + : LowerBound(upper_bound_value); + + // Set bits for all posting lists in range + for (size_t i = start_idx; i < end_idx; ++i) { + MmapEntry entry = GetEntry(i); + entry.for_each_row_id( + [&bitset](uint32_t row_id) { bitset.set(row_id); }); + } + + return bitset; +} + +const TargetBitmap +StringIndexSortMmapImpl::PrefixMatch(const std::string_view prefix, + size_t total_num_rows) { + TargetBitmap bitset(total_num_rows, false); + + // Find the first string that is >= prefix + size_t idx = LowerBound(prefix); + + while (idx < unique_count_) { + MmapEntry entry = GetEntry(idx); + std::string_view entry_sv = entry.get_string_view(); + + if (entry_sv.size() < prefix.size() || + entry_sv.substr(0, prefix.size()) != prefix) { + break; + } + + // Add all row_ids for this matching string + entry.for_each_row_id( + [&bitset](uint32_t row_id) { bitset.set(row_id); }); + + ++idx; + } + + return bitset; +} + +std::optional +StringIndexSortMmapImpl::Reverse_Lookup( + size_t offset, + size_t total_num_rows, + const TargetBitmap& valid_bitset, + const std::vector& idx_to_offsets) const { + if (offset >= total_num_rows || !valid_bitset[offset]) { + return std::nullopt; + } + + if (offset < idx_to_offsets.size()) { + int32_t unique_idx = idx_to_offsets[offset]; + if (unique_idx >= 0 && + static_cast(unique_idx) < unique_count_) { + MmapEntry entry = GetEntry(unique_idx); + // Convert string_view to string for return + std::string_view sv = entry.get_string_view(); + return std::string(sv); + } + } + + return std::nullopt; +} + +int64_t +StringIndexSortMmapImpl::Size() { + return mmap_size_; +} + +} // namespace milvus::index diff --git a/internal/core/src/index/StringIndexSort.h b/internal/core/src/index/StringIndexSort.h new file mode 100644 index 0000000000..ddb09c61be --- /dev/null +++ b/internal/core/src/index/StringIndexSort.h @@ -0,0 +1,431 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "index/StringIndex.h" +#include "storage/MemFileManagerImpl.h" +#include "storage/DiskFileManagerImpl.h" +#include "storage/FileWriter.h" +#include "common/File.h" + +namespace milvus::index { + +// Forward declaration +class StringIndexSortImpl; + +// Main StringIndexSort class using pImpl pattern +class StringIndexSort : public StringIndex { + public: + static constexpr uint32_t SERIALIZATION_VERSION = 1; + static constexpr uint64_t MAGIC_CODE = + 0x5354524E47534F52; // "STRNGSOR" in hex + + explicit StringIndexSort( + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + + virtual ~StringIndexSort(); + + int64_t + Count() override; + + ScalarIndexType + GetIndexType() const override { + return ScalarIndexType::STLSORT; + } + + const bool + HasRawData() const override { + return true; + } + + void + Build(size_t n, + const std::string* values, + const bool* valid_data = nullptr) override; + + void + Build(const Config& config = {}) override; + + void + BuildWithFieldData(const std::vector& datas) override; + + // See detailed format in StringIndexSortMemoryImpl::SerializeToBinary + BinarySet + Serialize(const Config& config) override; + + IndexStatsPtr + Upload(const Config& config = {}) override; + + void + Load(const BinarySet& index_binary, const Config& config = {}) override; + + void + Load(milvus::tracer::TraceContext ctx, const Config& config = {}) override; + + void + LoadWithoutAssemble(const BinarySet& binary_set, + const Config& config) override; + + // Query methods - delegated to impl + const TargetBitmap + In(size_t n, const std::string* values) override; + + const TargetBitmap + NotIn(size_t n, const std::string* values) override; + + const TargetBitmap + IsNull() override; + + TargetBitmap + IsNotNull() override; + + const TargetBitmap + Range(std::string value, OpType op) override; + + const TargetBitmap + Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, + bool ub_inclusive) override; + + const TargetBitmap + PrefixMatch(const std::string_view prefix) override; + + std::optional + Reverse_Lookup(size_t offset) const override; + + int64_t + Size() override; + + protected: + int64_t + CalculateTotalSize() const; + + // Common fields + int64_t field_id_ = 0; + bool is_built_ = false; + Config config_; + std::shared_ptr file_manager_; + size_t total_num_rows_{0}; + TargetBitmap valid_bitset_; + std::vector idx_to_offsets_; + std::chrono::time_point index_build_begin_; + + int64_t total_size_{0}; + std::unique_ptr impl_; +}; + +// Abstract interface for implementations +class StringIndexSortImpl { + public: + virtual ~StringIndexSortImpl() = default; + + virtual void + LoadFromBinary(const BinarySet& binary_set, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) = 0; + + struct ParsedData { + uint32_t unique_count; + const uint32_t* string_offsets; + const uint8_t* string_data_start; + const uint32_t* post_list_offsets; + const uint8_t* post_list_data_start; + }; + + static ParsedData + ParseBinaryData(const uint8_t* data, size_t data_size); + + virtual const TargetBitmap + In(size_t n, const std::string* values, size_t total_num_rows) = 0; + + virtual const TargetBitmap + NotIn(size_t n, + const std::string* values, + size_t total_num_rows, + const TargetBitmap& valid_bitset) = 0; + + virtual const TargetBitmap + IsNull(size_t 
total_num_rows, const TargetBitmap& valid_bitset) = 0; + + virtual TargetBitmap + IsNotNull(const TargetBitmap& valid_bitset) = 0; + + virtual const TargetBitmap + Range(std::string value, OpType op, size_t total_num_rows) = 0; + + virtual const TargetBitmap + Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, + bool ub_inclusive, + size_t total_num_rows) = 0; + + virtual const TargetBitmap + PrefixMatch(const std::string_view prefix, size_t total_num_rows) = 0; + + virtual std::optional + Reverse_Lookup(size_t offset, + size_t total_num_rows, + const TargetBitmap& valid_bitset, + const std::vector& idx_to_offsets) const = 0; + + virtual int64_t + Size() = 0; +}; + +class StringIndexSortMemoryImpl : public StringIndexSortImpl { + public: + using PostingList = folly::small_vector; + + void + BuildFromRawData(size_t n, + const std::string* values, + const bool* valid_data, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets); + + void + BuildFromFieldData(const std::vector& field_datas, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets); + + // Serialize to binary format + // The binary format is : [unique_count][string_offsets][string_data][post_list_offsets][post_list_data][magic_code] + // string_offsets: array of offsets into string_data section + // string_data: str_len1, str1, str_len2, str2, ... + // post_list_offsets: array of offsets into post_list_data section + // post_list_data: post_list_len1, row_id1, row_id2, ..., post_list_len2, row_id1, row_id2, ... + void + SerializeToBinary(uint8_t* ptr, size_t& offset) const; + + size_t + GetSerializedSize() const; + + void + LoadFromBinary(const BinarySet& binary_set, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) override; + + const TargetBitmap + In(size_t n, const std::string* values, size_t total_num_rows) override; + + const TargetBitmap + NotIn(size_t n, + const std::string* values, + size_t total_num_rows, + const TargetBitmap& valid_bitset) override; + + const TargetBitmap + IsNull(size_t total_num_rows, const TargetBitmap& valid_bitset) override; + + TargetBitmap + IsNotNull(const TargetBitmap& valid_bitset) override; + + const TargetBitmap + Range(std::string value, OpType op, size_t total_num_rows) override; + + const TargetBitmap + Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, + bool ub_inclusive, + size_t total_num_rows) override; + + const TargetBitmap + PrefixMatch(const std::string_view prefix, size_t total_num_rows) override; + + std::optional + Reverse_Lookup(size_t offset, + size_t total_num_rows, + const TargetBitmap& valid_bitset, + const std::vector& idx_to_offsets) const override; + + int64_t + Size() override; + + private: + // Helper method for binary search + size_t + FindValueIndex(const std::string& value) const; + + void + BuildFromMap(std::map&& unique_map, + size_t total_num_rows, + std::vector& idx_to_offsets); + + // Keep unique_values_ and posting_lists_ separated for cache efficiency + // Sorted unique values + std::vector unique_values_; + // Corresponding posting lists + std::vector posting_lists_; +}; + +class StringIndexSortMmapImpl : public StringIndexSortImpl { + public: + ~StringIndexSortMmapImpl(); + + // Helper struct to access separated string and posting list data + struct MmapEntry { + const char* str_data_ptr; // Pointer to string data + const uint32_t* post_list_data_ptr; // Pointer to posting list data + uint32_t str_len; // 
String length + uint32_t post_list_len; // Posting list length + + MmapEntry() = default; + + MmapEntry(const uint8_t* str_ptr, const uint8_t* post_list_ptr) { + // Read string length and data pointer + str_len = *reinterpret_cast(str_ptr); + str_data_ptr = + reinterpret_cast(str_ptr + sizeof(uint32_t)); + + // Read posting list length and data pointer + post_list_len = *reinterpret_cast(post_list_ptr); + post_list_data_ptr = reinterpret_cast( + post_list_ptr + sizeof(uint32_t)); + } + + std::string_view + get_string_view() const { + return std::string_view(str_data_ptr, str_len); + } + + size_t + get_posting_list_len() const { + return post_list_len; + } + + uint32_t + get_row_id(size_t idx) const { + return post_list_data_ptr[idx]; + } + + template + void + for_each_row_id(Func func) const { + for (uint32_t i = 0; i < post_list_len; ++i) { + func(post_list_data_ptr[i]); + } + } + }; + + void + LoadFromBinary(const BinarySet& binary_set, + size_t total_num_rows, + TargetBitmap& valid_bitset, + std::vector& idx_to_offsets) override; + + void + SetMmapFilePath(const std::string& filepath) { + mmap_filepath_ = filepath; + } + + const TargetBitmap + In(size_t n, const std::string* values, size_t total_num_rows) override; + + const TargetBitmap + NotIn(size_t n, + const std::string* values, + size_t total_num_rows, + const TargetBitmap& valid_bitset) override; + + const TargetBitmap + IsNull(size_t total_num_rows, const TargetBitmap& valid_bitset) override; + + TargetBitmap + IsNotNull(const TargetBitmap& valid_bitset) override; + + const TargetBitmap + Range(std::string value, OpType op, size_t total_num_rows) override; + + const TargetBitmap + Range(std::string lower_bound_value, + bool lb_inclusive, + std::string upper_bound_value, + bool ub_inclusive, + size_t total_num_rows) override; + + const TargetBitmap + PrefixMatch(const std::string_view prefix, size_t total_num_rows) override; + + std::optional + Reverse_Lookup(size_t offset, + size_t total_num_rows, + const TargetBitmap& valid_bitset, + const std::vector& idx_to_offsets) const override; + + int64_t + Size() override; + + private: + // Binary search for a value + size_t + FindValueIndex(const std::string& value) const; + + // Binary search helpers + size_t + LowerBound(const std::string_view& value) const; + + size_t + UpperBound(const std::string_view& value) const; + + MmapEntry + GetEntry(size_t idx) const { + const uint8_t* str_ptr = string_data_start_ + string_offsets_[idx]; + const uint8_t* post_list_ptr = + post_list_data_start_ + post_list_offsets_[idx]; + return MmapEntry(str_ptr, post_list_ptr); + } + + private: + char* mmap_data_ = nullptr; + size_t mmap_size_ = 0; + std::string mmap_filepath_; + size_t unique_count_ = 0; + + // Pointers to different sections in mmap'd data + const uint32_t* string_offsets_ = nullptr; + const uint8_t* string_data_start_ = nullptr; + const uint32_t* post_list_offsets_ = nullptr; + const uint8_t* post_list_data_start_ = nullptr; +}; + +using StringIndexSortPtr = std::unique_ptr; + +inline StringIndexSortPtr +CreateStringIndexSort(const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()) { + return std::make_unique(file_manager_context); +} + +} // namespace milvus::index \ No newline at end of file diff --git a/internal/core/src/index/StringIndexSortTest.cpp b/internal/core/src/index/StringIndexSortTest.cpp new file mode 100644 index 0000000000..d280536666 --- /dev/null +++ b/internal/core/src/index/StringIndexSortTest.cpp @@ -0,0 +1,607 @@ +#include 
+#include + +#include "index/StringIndexSort.h" +#include "index/IndexFactory.h" +#include "test_utils/indexbuilder_test_utils.h" + +constexpr int64_t nb = 100; + +namespace milvus { +namespace index { +class StringIndexBaseTest : public ::testing::Test { + protected: + void + SetUp() override { + strs = GenStrArr(nb); + *str_arr.mutable_data() = {strs.begin(), strs.end()}; + } + + protected: + std::vector strs; + schemapb::StringArray str_arr; +}; + +class StringIndexSortTest : public StringIndexBaseTest {}; + +TEST_F(StringIndexSortTest, ConstructorMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + ASSERT_NE(index, nullptr); +} + +TEST_F(StringIndexSortTest, ConstructorMmap) { + Config config; + config["mmap_file_path"] = "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + ASSERT_NE(index, nullptr); +} + +TEST_F(StringIndexSortTest, BuildMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(strs.size(), strs.data()); + ASSERT_EQ(index->Count(), nb); +} + +TEST_F(StringIndexSortTest, BuildMmap) { + Config config; + config["mmap_file_path"] = "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(strs.size(), strs.data()); + ASSERT_EQ(index->Count(), nb); +} + +TEST_F(StringIndexSortTest, InMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(nb, strs.data()); + + // Test with all strings + auto bitset = index->In(strs.size(), strs.data()); + ASSERT_EQ(bitset.size(), strs.size()); + ASSERT_EQ(bitset.count(), strs.size()); + + // Test with subset + std::vector subset = {strs[0], strs[10], strs[20]}; + auto subset_bitset = index->In(subset.size(), subset.data()); + ASSERT_EQ(subset_bitset.size(), strs.size()); + ASSERT_EQ(subset_bitset.count(), 3); + ASSERT_TRUE(subset_bitset[0]); + ASSERT_TRUE(subset_bitset[10]); + ASSERT_TRUE(subset_bitset[20]); +} + +TEST_F(StringIndexSortTest, InMmap) { + Config config; + config["mmap_file_path"] = "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(nb, strs.data()); + + auto bitset = index->In(strs.size(), strs.data()); + ASSERT_EQ(bitset.size(), strs.size()); + ASSERT_EQ(bitset.count(), strs.size()); +} + +TEST_F(StringIndexSortTest, NotInMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(nb, strs.data()); + + auto bitset = index->NotIn(strs.size(), strs.data()); + ASSERT_EQ(bitset.size(), strs.size()); + ASSERT_EQ(bitset.count(), 0); + + // Test with non-existing strings + std::vector non_existing = {"non_existing_1", + "non_existing_2"}; + auto non_existing_bitset = + index->NotIn(non_existing.size(), non_existing.data()); + ASSERT_EQ(non_existing_bitset.size(), strs.size()); + ASSERT_EQ(non_existing_bitset.count(), strs.size()); +} + +TEST_F(StringIndexSortTest, RangeMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + + // Build with sorted strings for predictable range tests + std::vector sorted_strs = { + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}; + index->Build(sorted_strs.size(), sorted_strs.data()); + + // Test LessThan + auto bitset = index->Range("d", OpType::LessThan); + ASSERT_EQ(bitset.count(), 3); // a, b, c + + // Test LessEqual + auto bitset2 = index->Range("d", OpType::LessEqual); + ASSERT_EQ(bitset2.count(), 4); // a, b, c, d + + // Test GreaterThan + auto bitset3 = index->Range("g", OpType::GreaterThan); + 
ASSERT_EQ(bitset3.count(), 3); // h, i, j + + // Test GreaterEqual + auto bitset4 = index->Range("g", OpType::GreaterEqual); + ASSERT_EQ(bitset4.count(), 4); // g, h, i, j +} + +TEST_F(StringIndexSortTest, RangeBetweenMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + + std::vector sorted_strs = { + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}; + index->Build(sorted_strs.size(), sorted_strs.data()); + + // Test inclusive range + auto bitset = index->Range("c", true, "g", true); + ASSERT_EQ(bitset.count(), 5); // c, d, e, f, g + + // Test exclusive range + auto bitset2 = index->Range("c", false, "g", false); + ASSERT_EQ(bitset2.count(), 3); // d, e, f + + // Test mixed + auto bitset3 = index->Range("c", true, "g", false); + ASSERT_EQ(bitset3.count(), 4); // c, d, e, f +} + +TEST_F(StringIndexSortTest, PrefixMatchMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + + std::vector test_strs = { + "apple", "application", "apply", "banana", "band", "cat"}; + index->Build(test_strs.size(), test_strs.data()); + + auto bitset = index->PrefixMatch("app"); + ASSERT_EQ(bitset.count(), 3); // apple, application, apply + + auto bitset2 = index->PrefixMatch("ban"); + ASSERT_EQ(bitset2.count(), 2); // banana, band + + auto bitset3 = index->PrefixMatch("cat"); + ASSERT_EQ(bitset3.count(), 1); // cat + + auto bitset4 = index->PrefixMatch("dog"); + ASSERT_EQ(bitset4.count(), 0); // none +} + +TEST_F(StringIndexSortTest, PrefixMatchMmap) { + Config config; + config["mmap_file_path"] = "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + + std::vector test_strs = { + "apple", "application", "apply", "banana", "band", "cat"}; + index->Build(test_strs.size(), test_strs.data()); + + auto bitset = index->PrefixMatch("app"); + ASSERT_EQ(bitset.count(), 3); // apple, application, apply +} + +TEST_F(StringIndexSortTest, ReverseLookupMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(strs.size(), strs.data()); + + for (size_t i = 0; i < strs.size(); ++i) { + auto result = index->Reverse_Lookup(i); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), strs[i]); + } + + // Test invalid offset + auto result = index->Reverse_Lookup(strs.size() + 100); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(StringIndexSortTest, ReverseLookupMmap) { + Config config; + config["mmap_file_path"] = "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(strs.size(), strs.data()); + + for (size_t i = 0; i < strs.size(); ++i) { + auto result = index->Reverse_Lookup(i); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), strs[i]); + } +} + +TEST_F(StringIndexSortTest, SerializeDeserializeMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(strs.size(), strs.data()); + + // Serialize + auto binary_set = index->Serialize(config); + + // Create new index and load + auto new_index = milvus::index::CreateStringIndexSort({}); + new_index->Load(binary_set); + + // Verify data integrity + ASSERT_EQ(new_index->Count(), strs.size()); + + auto bitset = new_index->In(strs.size(), strs.data()); + ASSERT_EQ(bitset.count(), strs.size()); + + for (size_t i = 0; i < strs.size(); ++i) { + auto result = new_index->Reverse_Lookup(i); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), strs[i]); + } +} + +TEST_F(StringIndexSortTest, SerializeDeserializeMmap) { + Config config; + config["mmap_file_path"] 
= "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(strs.size(), strs.data()); + + // Serialize + auto binary_set = index->Serialize(config); + + // Create new index and load + auto new_index = milvus::index::CreateStringIndexSort({}); + new_index->Load(binary_set); + + // Verify data integrity + ASSERT_EQ(new_index->Count(), strs.size()); + + auto bitset = new_index->In(strs.size(), strs.data()); + ASSERT_EQ(bitset.count(), strs.size()); +} + +TEST_F(StringIndexSortTest, NullHandlingMemory) { + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + + std::unique_ptr valid(new bool[nb]); + for (int i = 0; i < nb; i++) { + valid[i] = (i % 2 == 0); // Half are valid + } + + index->Build(nb, strs.data(), valid.get()); + + // Test IsNull + auto null_bitset = index->IsNull(); + ASSERT_EQ(null_bitset.count(), nb / 2); + + // Test IsNotNull + auto not_null_bitset = index->IsNotNull(); + ASSERT_EQ(not_null_bitset.count(), nb / 2); + + // Verify they are complementary + for (size_t i = 0; i < nb; ++i) { + ASSERT_NE(null_bitset[i], not_null_bitset[i]); + } +} + +TEST_F(StringIndexSortTest, NullHandlingMmap) { + Config config; + config["mmap_file_path"] = "/tmp/milvus_test"; + auto index = milvus::index::CreateStringIndexSort({}); + + std::unique_ptr valid(new bool[nb]); + for (int i = 0; i < nb; i++) { + valid[i] = (i % 2 == 0); + } + + index->Build(nb, strs.data(), valid.get()); + + auto null_bitset = index->IsNull(); + ASSERT_EQ(null_bitset.count(), nb / 2); + + auto not_null_bitset = index->IsNotNull(); + ASSERT_EQ(not_null_bitset.count(), nb / 2); +} + +TEST_F(StringIndexSortTest, MmapLoadAfterSerialize) { + // Step 1: Build index in memory and serialize + Config build_config; + auto index = milvus::index::CreateStringIndexSort({}); + + std::vector test_strs = { + "apple", + "banana", + "cherry", + "date", + "elderberry", + "fig", + "grape", + "honeydew", + "kiwi", + "lemon", + "apple", + "banana", + "apple" // Include duplicates + }; + index->Build(test_strs.size(), test_strs.data()); + + // Serialize the index + auto binary_set = index->Serialize(build_config); + + // Step 2: Load with mmap configuration + Config mmap_config; + mmap_config[MMAP_FILE_PATH] = "/tmp/test_string_index_sort_mmap.idx"; + + auto mmap_index = milvus::index::CreateStringIndexSort({}); + mmap_index->Load(binary_set, mmap_config); + + // Step 3: Verify functionality with mmap loaded index + // Test Count + ASSERT_EQ(mmap_index->Count(), test_strs.size()); + + // Test In operation + std::vector search_vals = {"apple", "grape", "lemon"}; + auto bitset = mmap_index->In(search_vals.size(), search_vals.data()); + ASSERT_EQ(bitset.count(), + 5); // apple appears 3 times, grape once, lemon once + ASSERT_TRUE(bitset[0]); // apple + ASSERT_TRUE(bitset[6]); // grape + ASSERT_TRUE(bitset[9]); // lemon + ASSERT_TRUE(bitset[10]); // apple (duplicate) + ASSERT_TRUE(bitset[12]); // apple (duplicate) + + // Test NotIn operation + std::vector not_in_vals = {"orange", "pear"}; + auto not_bitset = mmap_index->NotIn(not_in_vals.size(), not_in_vals.data()); + ASSERT_EQ(not_bitset.count(), + test_strs.size()); // All strings should be in result + + // Test Range operation + auto range_bitset = + mmap_index->Range("cherry", milvus::OpType::GreaterEqual); + ASSERT_EQ( + range_bitset.count(), + 8); // cherry, date, elderberry, fig, grape, honeydew, kiwi, lemon + + // Test Range between + auto range_between = mmap_index->Range("banana", true, "grape", true); + 
ASSERT_EQ(range_between.count(), + 7); // banana(2), cherry, date, elderberry, fig, grape + + // Test PrefixMatch + std::vector prefix_test_strs = { + "app", "apple", "application", "banana", "band"}; + auto prefix_index = milvus::index::CreateStringIndexSort({}); + prefix_index->Build(prefix_test_strs.size(), prefix_test_strs.data()); + auto prefix_binary = prefix_index->Serialize(build_config); + + Config prefix_mmap_config; + prefix_mmap_config[MMAP_FILE_PATH] = "/tmp/test_prefix_mmap.idx"; + auto prefix_mmap_index = milvus::index::CreateStringIndexSort({}); + prefix_mmap_index->Load(prefix_binary, prefix_mmap_config); + + auto prefix_bitset = prefix_mmap_index->PrefixMatch("app"); + ASSERT_EQ(prefix_bitset.count(), 3); // app, apple, application + + // Test Reverse_Lookup + for (size_t i = 0; i < test_strs.size(); ++i) { + auto result = mmap_index->Reverse_Lookup(i); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), test_strs[i]); + } + + // Clean up temp files + std::remove("/tmp/test_string_index_sort_mmap.idx"); + std::remove("/tmp/test_prefix_mmap.idx"); +} + +TEST_F(StringIndexSortTest, LoadWithoutAssembleMmap) { + // Build and serialize index + Config config; + auto index = milvus::index::CreateStringIndexSort({}); + + std::vector test_strs = { + "zebra", "apple", "monkey", "dog", "cat"}; + index->Build(test_strs.size(), test_strs.data()); + + auto binary_set = index->Serialize(config); + + // Load without assemble using mmap + Config mmap_config; + mmap_config[MMAP_FILE_PATH] = "/tmp/test_load_without_assemble.idx"; + + auto mmap_index = milvus::index::CreateStringIndexSort({}); + mmap_index->LoadWithoutAssemble(binary_set, mmap_config); + + // Verify the index works correctly + auto bitset = mmap_index->In(test_strs.size(), test_strs.data()); + ASSERT_EQ(bitset.count(), test_strs.size()); + + // Test that all operations work + auto range_bitset = mmap_index->Range("dog", milvus::OpType::LessEqual); + ASSERT_EQ(range_bitset.count(), 3); // apple, cat, dog + + // Clean up + std::remove("/tmp/test_load_without_assemble.idx"); +} +} // namespace index +} // namespace milvus + +TEST(StringIndexSortStandaloneTest, StringIndexSortBuildAndSearch) { + // Test data + std::vector test_data = {"apple", + "banana", + "cherry", + "date", + "elderberry", + "fig", + "grape", + "honeydew", + "kiwi", + "lemon"}; + auto n = test_data.size(); + + // Test Memory mode + { + milvus::Config config; + auto index = milvus::index::CreateStringIndexSort({}); + index->Build(n, test_data.data()); + + // Test In operation + std::vector search_vals = {"apple", "grape", "lemon"}; + auto bitset = index->In(search_vals.size(), search_vals.data()); + ASSERT_EQ(bitset.count(), 3); + ASSERT_TRUE(bitset[0]); // apple + ASSERT_TRUE(bitset[6]); // grape + ASSERT_TRUE(bitset[9]); // lemon + + // Test Range operation + auto range_bitset = + index->Range("cherry", milvus::OpType::GreaterEqual); + ASSERT_EQ( + range_bitset.count(), + 8); // cherry, date, elderberry, fig, grape, honeydew, kiwi, lemon + + // Test PrefixMatch + std::vector test_data_prefix = { + "app", "apple", "application", "banana", "band"}; + auto prefix_index = milvus::index::CreateStringIndexSort({}); + prefix_index->Build(test_data_prefix.size(), test_data_prefix.data()); + auto prefix_bitset = prefix_index->PrefixMatch("app"); + ASSERT_EQ(prefix_bitset.count(), 3); // app, apple, application + } + + // Test Mmap mode + { + milvus::Config config; + config["mmap_file_path"] = "/tmp/milvus_scalar_test"; + auto index = 
+        auto index = milvus::index::CreateStringIndexSort({});
+        index->Build(n, test_data.data());
+
+        // Test In operation
+        std::vector<std::string> search_vals = {"banana", "fig"};
+        auto bitset = index->In(search_vals.size(), search_vals.data());
+        ASSERT_EQ(bitset.count(), 2);
+        ASSERT_TRUE(bitset[1]); // banana
+        ASSERT_TRUE(bitset[5]); // fig
+
+        // Test NotIn operation
+        auto not_bitset = index->NotIn(search_vals.size(), search_vals.data());
+        ASSERT_EQ(not_bitset.count(), n - 2);
+        ASSERT_FALSE(not_bitset[1]); // banana should not be in NotIn result
+        ASSERT_FALSE(not_bitset[5]); // fig should not be in NotIn result
+    }
+}
+
+TEST(StringIndexSortStandaloneTest, StringIndexSortWithNulls) {
+    std::vector<std::string> test_data = {
+        "alpha", "beta", "gamma", "delta", "epsilon"};
+
+    std::unique_ptr<bool[]> valid_data(new bool[test_data.size()]);
+    valid_data[0] = true;
+    valid_data[1] = false;
+    valid_data[2] = true;
+    valid_data[3] = false;
+    valid_data[4] = true;
+    auto n = test_data.size();
+
+    // Memory mode with nulls
+    {
+        milvus::Config config;
+        auto index = milvus::index::CreateStringIndexSort({});
+        index->Build(n, test_data.data(), valid_data.get());
+
+        // Test IsNull
+        auto null_bitset = index->IsNull();
+        ASSERT_EQ(null_bitset.count(), 2);
+        ASSERT_TRUE(null_bitset[1]); // beta is null
+        ASSERT_TRUE(null_bitset[3]); // delta is null
+
+        // Test IsNotNull
+        auto not_null_bitset = index->IsNotNull();
+        ASSERT_EQ(not_null_bitset.count(), 3);
+        ASSERT_TRUE(not_null_bitset[0]); // alpha is not null
+        ASSERT_TRUE(not_null_bitset[2]); // gamma is not null
+        ASSERT_TRUE(not_null_bitset[4]); // epsilon is not null
+
+        // Test In with nulls
+        std::vector<std::string> search_vals = {"alpha", "beta", "gamma"};
+        auto bitset = index->In(search_vals.size(), search_vals.data());
+        ASSERT_EQ(bitset.count(), 2); // Only alpha and gamma (beta is null)
+        ASSERT_TRUE(bitset[0]);  // alpha
+        ASSERT_FALSE(bitset[1]); // beta is null
+        ASSERT_TRUE(bitset[2]);  // gamma
+    }
+
+    // Mmap mode with nulls
+    {
+        milvus::Config config;
+        config["mmap_file_path"] = "/tmp/milvus_scalar_test";
+        auto index = milvus::index::CreateStringIndexSort({});
+        index->Build(n, test_data.data(), valid_data.get());
+
+        auto null_bitset = index->IsNull();
+        ASSERT_EQ(null_bitset.count(), 2);
+
+        auto not_null_bitset = index->IsNotNull();
+        ASSERT_EQ(not_null_bitset.count(), 3);
+    }
+}
+
+TEST(StringIndexSortStandaloneTest, StringIndexSortSerialization) {
+    std::vector<std::string> test_data;
+    for (int i = 0; i < 100; ++i) {
+        test_data.push_back("str_" + std::to_string(i));
+    }
+    auto n = test_data.size();
+
+    // Test Memory mode serialization
+    {
+        milvus::Config config;
+        auto index = milvus::index::CreateStringIndexSort({});
+        index->Build(n, test_data.data());
+
+        // Serialize
+        auto binary_set = index->Serialize(config);
+
+        // Create new index and deserialize
+        auto new_index = milvus::index::CreateStringIndexSort({});
+        new_index->Load(binary_set);
+
+        // Verify the data
+        ASSERT_EQ(new_index->Count(), n);
+
+        // Test search on deserialized index
+        std::vector<std::string> search_vals = {"str_10", "str_50", "str_90"};
+        auto bitset = new_index->In(search_vals.size(), search_vals.data());
+        ASSERT_EQ(bitset.count(), 3);
+
+        // Test reverse lookup
+        for (size_t i = 0; i < n; ++i) {
+            auto result = new_index->Reverse_Lookup(i);
+            ASSERT_TRUE(result.has_value());
+            ASSERT_EQ(result.value(), test_data[i]);
+        }
+    }
+
+    // Test Mmap mode serialization
+    {
+        milvus::Config config;
+        config["mmap_file_path"] = "/tmp/milvus_scalar_test";
+        auto index = milvus::index::CreateStringIndexSort({});
+        index->Build(n, test_data.data());
+
+        // Serialize
+        auto binary_set = index->Serialize(config);
+
+        // Create new index and deserialize
+        auto new_index = milvus::index::CreateStringIndexSort({});
+        new_index->Load(binary_set);
+
+        // Verify the data
+        ASSERT_EQ(new_index->Count(), n);
+
+        // Test range query on deserialized index
+        auto bitset = new_index->Range("str_20", true, "str_30", true);
+        // In lexicographical order: str_20, str_21, ..., str_29, str_3, str_30
+        // So we expect more than 11 matches due to lexicographical ordering
+        ASSERT_GT(bitset.count(), 11);
+    }
+}
diff --git a/internal/util/indexparamcheck/stl_sort_checker.go b/internal/util/indexparamcheck/stl_sort_checker.go
index 224617f0f5..50a23ffd4f 100644
--- a/internal/util/indexparamcheck/stl_sort_checker.go
+++ b/internal/util/indexparamcheck/stl_sort_checker.go
@@ -1,6 +1,8 @@
 package indexparamcheck
 
 import (
+	"fmt"
+
 	"github.com/cockroachdb/errors"
@@ -17,8 +19,8 @@ func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, elementType sche
 }
 
 func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error {
-	if !typeutil.IsArithmetic(field.GetDataType()) {
-		return errors.New("STL_SORT are only supported on numeric field")
+	if !typeutil.IsArithmetic(field.GetDataType()) && !typeutil.IsStringType(field.GetDataType()) {
+		return errors.New(fmt.Sprintf("STL_SORT are only supported on numeric or varchar field, got %s", field.GetDataType()))
 	}
 	return nil
 }
diff --git a/internal/util/indexparamcheck/stl_sort_checker_test.go b/internal/util/indexparamcheck/stl_sort_checker_test.go
index 1fc527bf6c..007adaff08 100644
--- a/internal/util/indexparamcheck/stl_sort_checker_test.go
+++ b/internal/util/indexparamcheck/stl_sort_checker_test.go
@@ -15,8 +15,8 @@ func Test_STLSORTIndexChecker(t *testing.T) {
 	assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}))
 	assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float}))
+	assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
-	assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}))
 	assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}))
 	assert.Error(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_JSON}))
 }
diff --git a/tests/go_client/testcases/add_field_test.go b/tests/go_client/testcases/add_field_test.go
index 158ac4b98e..efd429b1cc 100644
--- a/tests/go_client/testcases/add_field_test.go
+++ b/tests/go_client/testcases/add_field_test.go
@@ -245,10 +245,9 @@ func TestIndexAddedField(t *testing.T) {
 			createIndex: index.NewInvertedIndex,
 		},
 		{
-			name:          "SortedIndex",
-			indexType:     "STL_SORT",
-			createIndex:   index.NewSortedIndex,
-			expectedError: "STL_SORT are only supported on numeric field",
+			name:        "SortedIndex",
+			indexType:   "STL_SORT",
+			createIndex: index.NewSortedIndex,
 		},
 		{
 			name: "TrieIndex",
diff --git a/tests/go_client/testcases/index_test.go b/tests/go_client/testcases/index_test.go
index e800e78ea2..52ed4a0cf7 100644
--- a/tests/go_client/testcases/index_test.go
+++ b/tests/go_client/testcases/index_test.go
@@ -468,10 +468,10 @@ func TestCreateSortedScalarIndex(t *testing.T) {
 	idx := index.NewSortedIndex()
 	for _, field := range schema.Fields {
 		if hp.SupportScalarIndexFieldType(field.DataType) {
-			if field.DataType == entity.FieldTypeVarChar || field.DataType == entity.FieldTypeBool ||
+			if field.DataType == entity.FieldTypeBool ||
 				field.DataType == entity.FieldTypeJSON || field.DataType == entity.FieldTypeArray {
 				_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx))
-				common.CheckErr(t, err, false, "STL_SORT are only supported on numeric field")
+				require.ErrorContains(t, err, "STL_SORT are only supported on numeric or varchar field")
 			} else {
 				idxTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx))
 				common.CheckErr(t, err, true)
@@ -623,12 +623,12 @@ func TestCreateIndexJsonField(t *testing.T) {
 		errMsg string
 	}
 	inxError := []scalarIndexError{
-		{index.NewSortedIndex(), "STL_SORT are only supported on numeric field"},
+		{index.NewSortedIndex(), "STL_SORT are only supported on numeric or varchar field"},
 		{index.NewTrieIndex(), "TRIE are only supported on varchar field"},
 	}
 
 	for _, idxErr := range inxError {
 		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultJSONFieldName, idxErr.idx).WithIndexName("json_index"))
-		common.CheckErr(t, err, false, idxErr.errMsg)
+		require.ErrorContains(t, err, idxErr.errMsg)
 	}
 }
@@ -649,7 +649,7 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) {
 		errMsg string
 	}
 	inxError := []scalarIndexError{
-		{index.NewSortedIndex(), "STL_SORT are only supported on numeric field"},
+		{index.NewSortedIndex(), "STL_SORT are only supported on numeric or varchar field"},
 		{index.NewTrieIndex(), "TRIE are only supported on varchar field"},
 	}
@@ -660,11 +660,11 @@ func TestCreateUnsupportedIndexArrayField(t *testing.T) {
 		if field.DataType == entity.FieldTypeArray {
 			// create vector index
 			_, err1 := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, vectorIdx).WithIndexName("vector_index"))
-			common.CheckErr(t, err1, false, "index SCANN only supports vector data type")
+			require.ErrorContains(t, err1, "index SCANN only supports vector data type")
 
 			// create scalar index
 			_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idxErr.idx))
-			common.CheckErr(t, err, false, idxErr.errMsg)
+			require.ErrorContains(t, err, idxErr.errMsg)
 		}
 	}
 }
diff --git a/tests/python_client/milvus_client/test_milvus_client_index.py b/tests/python_client/milvus_client/test_milvus_client_index.py
index 2deeba983c..0eeddb9ecc 100644
--- a/tests/python_client/milvus_client/test_milvus_client_index.py
+++ b/tests/python_client/milvus_client/test_milvus_client_index.py
@@ -849,13 +849,16 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
         # 3. create index
         if not_supported_varchar_scalar_index == "TRIE":
             supported_field_type = "varchar"
+            got_json_suffix = ""
         if not_supported_varchar_scalar_index == "STL_SORT":
-            supported_field_type = "numeric"
+            supported_field_type = "numeric or varchar"
+            got_json_suffix = ", got JSON"
         if not_supported_varchar_scalar_index == "BITMAP":
             supported_field_type = "bool, int, string and array"
             not_supported_varchar_scalar_index = "bitmap index"
+            got_json_suffix = ""
         error = {ct.err_code: 1100, ct.err_msg: f"{not_supported_varchar_scalar_index} are only supported on "
-                                                f"{supported_field_type} field: invalid parameter[expected=valid "
+                                                f"{supported_field_type} field{got_json_suffix}: invalid parameter[expected=valid "
                                                 f"index params][actual=invalid index params]"}
         self.create_index(client, collection_name, index_params,
                           check_task=CheckTasks.err_res, check_items=error)
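For reviewers who want the new index's lifecycle in one place, a minimal usage sketch follows. It is illustrative only and not part of the patch: it reuses calls that the unit tests above already exercise (CreateStringIndexSort, Build, In, Range, Serialize, Load, Reverse_Lookup) with an empty FileManagerContext; the wrapper function name and the mmap file path are made up for the example.

// Minimal sketch, assuming the same headers and build setup as the unit tests above.
#include <string>
#include <vector>
#include "index/StringIndexSort.h"

void StringIndexSortLifecycleSketch() {
    std::vector<std::string> data = {"apple", "banana", "banana", "cherry"};

    // Build an in-memory sorted string index; an empty FileManagerContext is
    // sufficient for a standalone build, as in the tests above.
    auto index = milvus::index::CreateStringIndexSort({});
    index->Build(data.size(), data.data());

    // Point lookups and range scans return bitmaps over the original row offsets.
    auto eq = index->In(1, &data[1]);  // rows equal to "banana"
    auto ge = index->Range("banana", milvus::OpType::GreaterEqual);

    // Serialize, then reload into a fresh index. Supplying a mmap file path in
    // the load config (as the mmap tests do) makes the reloaded index mmap-backed.
    milvus::Config config;
    auto binary_set = index->Serialize(config);

    milvus::Config mmap_config;
    mmap_config["mmap_file_path"] = "/tmp/string_index_sort_example.idx";  // example path
    auto reloaded = milvus::index::CreateStringIndexSort({});
    reloaded->Load(binary_set, mmap_config);

    // Reverse_Lookup maps a row offset back to its string value.
    auto value_at_0 = reloaded->Reverse_Lookup(0);
    (void)eq; (void)ge; (void)value_at_0;  // silence unused-variable warnings in this sketch
}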