mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-28 14:35:27 +08:00
feat: impl StructArray -- support embedding searches embeddings in embedding list with element level filter expression (#45830)
issue: https://github.com/milvus-io/milvus/issues/42148 For a vector field inside a STRUCT, since a STRUCT can only appear as the element type of an ARRAY field, the vector field in STRUCT is effectively an array of vectors, i.e. an embedding list. Milvus already supports searching embedding lists with metrics whose names start with the prefix MAX_SIM_. This PR allows Milvus to search embeddings inside an embedding list using the same metrics as normal embedding fields. Each embedding in the list is treated as an independent vector and participates in ANN search. Further, since STRUCT may contain scalar fields that are highly related to the embedding field, this PR introduces an element-level filter expression to refine search results. The grammar of the element-level filter is: element_filter(structFieldName, $[subFieldName] == 3) where $[subFieldName] refers to the value of subFieldName in each element of the STRUCT array structFieldName. It can be combined with existing filter expressions, for example: "varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3)" A full example: ``` struct_schema = milvus_client.create_struct_field_schema() struct_schema.add_field("struct_str", DataType.VARCHAR, max_length=65535) struct_schema.add_field("struct_int", DataType.INT32) struct_schema.add_field("struct_float_vec", DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM) schema.add_field( "struct_field", datatype=DataType.ARRAY, element_type=DataType.STRUCT, struct_schema=struct_schema, max_capacity=1000, ) ... filter = "varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3 && $[struct_str] == 'abc')" res = milvus_client.search( COLLECTION_NAME, data=query_embeddings, limit=10, anns_field="struct_field[struct_float_vec]", filter=filter, output_fields=["struct_field[struct_int]", "varcharField"], ) ``` TODO: 1. When an `element_filter` expression is used, a regular filter expression must also be present. Remove this restriction. 2. Implement `element_filter` expressions in the `query`. --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
This commit is contained in:
parent
ca2e27f576
commit
f6f716bcfd
292
internal/core/src/common/ArrayOffsets.cpp
Normal file
292
internal/core/src/common/ArrayOffsets.cpp
Normal file
@ -0,0 +1,292 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ArrayOffsets.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "log/Log.h"
|
||||
#include "common/EasyAssert.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ArrayOffsetsSealed::ElementIDToRowID(int32_t elem_id) const {
|
||||
assert(elem_id >= 0 && elem_id < GetTotalElementCount());
|
||||
|
||||
int32_t row_id = element_row_ids_[elem_id];
|
||||
// Compute elem_idx: elem_idx = elem_id - start_of_this_row
|
||||
int32_t elem_idx = elem_id - row_to_element_start_[row_id];
|
||||
return {row_id, elem_idx};
|
||||
}
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ArrayOffsetsSealed::ElementIDRangeOfRow(int32_t row_id) const {
|
||||
int32_t row_count = GetRowCount();
|
||||
assert(row_id >= 0 && row_id <= row_count);
|
||||
|
||||
if (row_id == row_count) {
|
||||
auto total = row_to_element_start_[row_count];
|
||||
return {total, total};
|
||||
}
|
||||
return {row_to_element_start_[row_id], row_to_element_start_[row_id + 1]};
|
||||
}
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
ArrayOffsetsSealed::RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const {
|
||||
int64_t row_count = GetRowCount();
|
||||
int64_t element_count = GetTotalElementCount();
|
||||
TargetBitmap element_bitset(element_count);
|
||||
TargetBitmap valid_element_bitset(element_count);
|
||||
|
||||
for (int64_t row_id = 0; row_id < row_count; ++row_id) {
|
||||
int64_t start = row_to_element_start_[row_id];
|
||||
int64_t end = row_to_element_start_[row_id + 1];
|
||||
if (start < end) {
|
||||
element_bitset.set(start, end - start, row_bitset[row_id]);
|
||||
valid_element_bitset.set(
|
||||
start, end - start, valid_row_bitset[row_id]);
|
||||
}
|
||||
}
|
||||
|
||||
return {std::move(element_bitset), std::move(valid_element_bitset)};
|
||||
}
|
||||
|
||||
std::shared_ptr<ArrayOffsetsSealed>
|
||||
ArrayOffsetsSealed::BuildFromSegment(const void* segment,
|
||||
const FieldMeta& field_meta) {
|
||||
auto seg = static_cast<const segcore::SegmentInternalInterface*>(segment);
|
||||
|
||||
int64_t row_count = seg->get_row_count();
|
||||
if (row_count == 0) {
|
||||
LOG_INFO(
|
||||
"ArrayOffsetsSealed::BuildFromSegment: empty segment for struct "
|
||||
"'{}'",
|
||||
field_meta.get_name().get());
|
||||
return std::make_shared<ArrayOffsetsSealed>(std::vector<int32_t>{},
|
||||
std::vector<int32_t>{0});
|
||||
}
|
||||
|
||||
FieldId field_id = field_meta.get_id();
|
||||
auto data_type = field_meta.get_data_type();
|
||||
|
||||
std::vector<int32_t> element_row_ids;
|
||||
// Size is row_count + 1, last element stores total_element_count
|
||||
std::vector<int32_t> row_to_element_start(row_count + 1);
|
||||
|
||||
auto temp_op_ctx = std::make_unique<OpContext>();
|
||||
auto op_ctx_ptr = temp_op_ctx.get();
|
||||
|
||||
int64_t num_chunks = seg->num_chunk(field_id);
|
||||
int32_t current_row_id = 0;
|
||||
|
||||
if (data_type == DataType::VECTOR_ARRAY) {
|
||||
for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
|
||||
auto pin_wrapper = seg->chunk_view<VectorArrayView>(
|
||||
op_ctx_ptr, field_id, chunk_id);
|
||||
const auto& [vector_array_views, valid_flags] = pin_wrapper.get();
|
||||
|
||||
for (size_t i = 0; i < vector_array_views.size(); ++i) {
|
||||
int32_t array_len = 0;
|
||||
if (valid_flags.empty() || valid_flags[i]) {
|
||||
array_len = vector_array_views[i].length();
|
||||
}
|
||||
|
||||
// Record the start position for this row
|
||||
row_to_element_start[current_row_id] = element_row_ids.size();
|
||||
|
||||
// Add row_id for each element (elem_idx computed on access)
|
||||
for (int32_t j = 0; j < array_len; ++j) {
|
||||
element_row_ids.emplace_back(current_row_id);
|
||||
}
|
||||
|
||||
current_row_id++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
|
||||
auto pin_wrapper =
|
||||
seg->chunk_view<ArrayView>(op_ctx_ptr, field_id, chunk_id);
|
||||
const auto& [array_views, valid_flags] = pin_wrapper.get();
|
||||
|
||||
for (size_t i = 0; i < array_views.size(); ++i) {
|
||||
int32_t array_len = 0;
|
||||
if (valid_flags.empty() || valid_flags[i]) {
|
||||
array_len = array_views[i].length();
|
||||
}
|
||||
|
||||
// Record the start position for this row
|
||||
row_to_element_start[current_row_id] = element_row_ids.size();
|
||||
|
||||
// Add row_id for each element (elem_idx computed on access)
|
||||
for (int32_t j = 0; j < array_len; ++j) {
|
||||
element_row_ids.emplace_back(current_row_id);
|
||||
}
|
||||
|
||||
current_row_id++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store total element count as the last entry
|
||||
row_to_element_start[row_count] = element_row_ids.size();
|
||||
|
||||
AssertInfo(current_row_id == row_count,
|
||||
"Row count mismatch: expected {}, got {}",
|
||||
row_count,
|
||||
current_row_id);
|
||||
|
||||
int64_t total_elements = element_row_ids.size();
|
||||
|
||||
LOG_INFO(
|
||||
"ArrayOffsetsSealed::BuildFromSegment: struct_name='{}', "
|
||||
"field_id={}, row_count={}, total_elements={}",
|
||||
field_meta.get_name().get(),
|
||||
field_meta.get_id().get(),
|
||||
row_count,
|
||||
total_elements);
|
||||
|
||||
auto result = std::make_shared<ArrayOffsetsSealed>(
|
||||
std::move(element_row_ids), std::move(row_to_element_start));
|
||||
result->resource_size_ = 4 * (row_count + 1) + 4 * total_elements;
|
||||
cachinglayer::Manager::GetInstance().ChargeLoadedResource(
|
||||
cachinglayer::ResourceUsage{result->resource_size_, 0});
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ArrayOffsetsGrowing::ElementIDToRowID(int32_t elem_id) const {
|
||||
std::shared_lock lock(mutex_);
|
||||
assert(elem_id >= 0 &&
|
||||
elem_id < static_cast<int32_t>(element_row_ids_.size()));
|
||||
int32_t row_id = element_row_ids_[elem_id];
|
||||
// Compute elem_idx: elem_idx = elem_id - start_of_this_row
|
||||
int32_t elem_idx = elem_id - row_to_element_start_[row_id];
|
||||
return {row_id, elem_idx};
|
||||
}
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ArrayOffsetsGrowing::ElementIDRangeOfRow(int32_t row_id) const {
|
||||
std::shared_lock lock(mutex_);
|
||||
assert(row_id >= 0 && row_id <= committed_row_count_);
|
||||
|
||||
if (row_id == committed_row_count_) {
|
||||
auto total = row_to_element_start_[committed_row_count_];
|
||||
return {total, total};
|
||||
}
|
||||
return {row_to_element_start_[row_id], row_to_element_start_[row_id + 1]};
|
||||
}
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
ArrayOffsetsGrowing::RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const {
|
||||
std::shared_lock lock(mutex_);
|
||||
|
||||
int64_t element_count = element_row_ids_.size();
|
||||
TargetBitmap element_bitset(element_count);
|
||||
TargetBitmap valid_element_bitset(element_count);
|
||||
|
||||
// Direct access to element_row_ids_, no virtual function calls
|
||||
for (size_t elem_id = 0; elem_id < element_row_ids_.size(); ++elem_id) {
|
||||
auto row_id = element_row_ids_[elem_id];
|
||||
element_bitset[elem_id] = row_bitset[row_id];
|
||||
valid_element_bitset[elem_id] = valid_row_bitset[row_id];
|
||||
}
|
||||
|
||||
return {std::move(element_bitset), std::move(valid_element_bitset)};
|
||||
}
|
||||
|
||||
void
|
||||
ArrayOffsetsGrowing::Insert(int64_t row_id_start,
|
||||
const int32_t* array_lengths,
|
||||
int64_t count) {
|
||||
std::unique_lock lock(mutex_);
|
||||
|
||||
row_to_element_start_.reserve(row_id_start + count + 1);
|
||||
|
||||
int32_t original_committed_count = committed_row_count_;
|
||||
|
||||
for (int64_t i = 0; i < count; ++i) {
|
||||
int32_t row_id = row_id_start + i;
|
||||
int32_t array_len = array_lengths[i];
|
||||
|
||||
if (row_id == committed_row_count_) {
|
||||
// Record the start position for this row
|
||||
// If sentinel exists at current position, overwrite it; otherwise push_back
|
||||
if (row_to_element_start_.size() >
|
||||
static_cast<size_t>(committed_row_count_)) {
|
||||
row_to_element_start_[committed_row_count_] =
|
||||
element_row_ids_.size();
|
||||
} else {
|
||||
row_to_element_start_.push_back(element_row_ids_.size());
|
||||
}
|
||||
|
||||
// Add row_id for each element (elem_idx computed on access)
|
||||
for (int32_t j = 0; j < array_len; ++j) {
|
||||
element_row_ids_.emplace_back(row_id);
|
||||
}
|
||||
|
||||
committed_row_count_++;
|
||||
} else {
|
||||
pending_rows_[row_id] = {row_id, array_len};
|
||||
}
|
||||
}
|
||||
|
||||
DrainPendingRows();
|
||||
|
||||
// Update the sentinel (total element count) only if we committed new rows
|
||||
if (committed_row_count_ > original_committed_count) {
|
||||
if (row_to_element_start_.size() ==
|
||||
static_cast<size_t>(committed_row_count_)) {
|
||||
row_to_element_start_.push_back(element_row_ids_.size());
|
||||
} else {
|
||||
row_to_element_start_[committed_row_count_] =
|
||||
element_row_ids_.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ArrayOffsetsGrowing::DrainPendingRows() {
|
||||
while (true) {
|
||||
auto it = pending_rows_.find(committed_row_count_);
|
||||
if (it == pending_rows_.end()) {
|
||||
break;
|
||||
}
|
||||
|
||||
const auto& pending = it->second;
|
||||
|
||||
// If sentinel exists at current position, overwrite it; otherwise push_back
|
||||
if (row_to_element_start_.size() >
|
||||
static_cast<size_t>(committed_row_count_)) {
|
||||
row_to_element_start_[committed_row_count_] =
|
||||
element_row_ids_.size();
|
||||
} else {
|
||||
row_to_element_start_.push_back(element_row_ids_.size());
|
||||
}
|
||||
|
||||
for (int32_t j = 0; j < pending.array_len; ++j) {
|
||||
element_row_ids_.emplace_back(static_cast<int32_t>(pending.row_id));
|
||||
}
|
||||
|
||||
committed_row_count_++;
|
||||
|
||||
pending_rows_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus
|
||||
164
internal/core/src/common/ArrayOffsets.h
Normal file
164
internal/core/src/common/ArrayOffsets.h
Normal file
@ -0,0 +1,164 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <shared_mutex>
|
||||
#include "cachinglayer/Manager.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/FieldMeta.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
class IArrayOffsets {
|
||||
public:
|
||||
virtual ~IArrayOffsets() = default;
|
||||
|
||||
virtual int64_t
|
||||
GetRowCount() const = 0;
|
||||
|
||||
virtual int64_t
|
||||
GetTotalElementCount() const = 0;
|
||||
|
||||
// Convert element ID to row ID
|
||||
// returns pair of <row_id, element_index>
|
||||
// element id is contiguous between rows
|
||||
virtual std::pair<int32_t, int32_t>
|
||||
ElementIDToRowID(int32_t elem_id) const = 0;
|
||||
|
||||
// Convert row ID to element ID range
|
||||
// elements with id in [ret.first, ret.last) belong to row_id
|
||||
virtual std::pair<int32_t, int32_t>
|
||||
ElementIDRangeOfRow(int32_t row_id) const = 0;
|
||||
|
||||
// Convert row-level bitsets to element-level bitsets
|
||||
virtual std::pair<TargetBitmap, TargetBitmap>
|
||||
RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const = 0;
|
||||
};
|
||||
|
||||
class ArrayOffsetsSealed : public IArrayOffsets {
|
||||
friend class ArrayOffsetsTest;
|
||||
|
||||
public:
|
||||
ArrayOffsetsSealed() : element_row_ids_(), row_to_element_start_({0}) {
|
||||
}
|
||||
|
||||
ArrayOffsetsSealed(std::vector<int32_t> element_row_ids,
|
||||
std::vector<int32_t> row_to_element_start)
|
||||
: element_row_ids_(std::move(element_row_ids)),
|
||||
row_to_element_start_(std::move(row_to_element_start)) {
|
||||
AssertInfo(!row_to_element_start_.empty(),
|
||||
"row_to_element_start must have at least one element");
|
||||
}
|
||||
|
||||
~ArrayOffsetsSealed() {
|
||||
cachinglayer::Manager::GetInstance().RefundLoadedResource(
|
||||
{resource_size_, 0});
|
||||
}
|
||||
|
||||
int64_t
|
||||
GetRowCount() const override {
|
||||
return static_cast<int64_t>(row_to_element_start_.size()) - 1;
|
||||
}
|
||||
|
||||
int64_t
|
||||
GetTotalElementCount() const override {
|
||||
return element_row_ids_.size();
|
||||
}
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ElementIDToRowID(int32_t elem_id) const override;
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ElementIDRangeOfRow(int32_t row_id) const override;
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const override;
|
||||
|
||||
static std::shared_ptr<ArrayOffsetsSealed>
|
||||
BuildFromSegment(const void* segment, const FieldMeta& field_meta);
|
||||
|
||||
private:
|
||||
const std::vector<int32_t> element_row_ids_;
|
||||
const std::vector<int32_t> row_to_element_start_;
|
||||
int64_t resource_size_{0};
|
||||
};
|
||||
|
||||
class ArrayOffsetsGrowing : public IArrayOffsets {
|
||||
public:
|
||||
ArrayOffsetsGrowing() = default;
|
||||
|
||||
void
|
||||
Insert(int64_t row_id_start, const int32_t* array_lengths, int64_t count);
|
||||
|
||||
int64_t
|
||||
GetRowCount() const override {
|
||||
std::shared_lock lock(mutex_);
|
||||
return committed_row_count_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
GetTotalElementCount() const override {
|
||||
std::shared_lock lock(mutex_);
|
||||
return element_row_ids_.size();
|
||||
}
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ElementIDToRowID(int32_t elem_id) const override;
|
||||
|
||||
std::pair<int32_t, int32_t>
|
||||
ElementIDRangeOfRow(int32_t row_id) const override;
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const override;
|
||||
|
||||
private:
|
||||
struct PendingRow {
|
||||
int64_t row_id;
|
||||
int32_t array_len;
|
||||
};
|
||||
|
||||
void
|
||||
DrainPendingRows();
|
||||
|
||||
private:
|
||||
std::vector<int32_t> element_row_ids_;
|
||||
|
||||
std::vector<int32_t> row_to_element_start_;
|
||||
|
||||
// Number of rows committed (contiguous from 0)
|
||||
int32_t committed_row_count_ = 0;
|
||||
|
||||
// Pending rows waiting for earlier rows to complete
|
||||
// Key: row_id, automatically sorted
|
||||
std::map<int64_t, PendingRow> pending_rows_;
|
||||
|
||||
// Protects all member variables
|
||||
mutable std::shared_mutex mutex_;
|
||||
};
|
||||
|
||||
} // namespace milvus
|
||||
427
internal/core/src/common/ArrayOffsetsTest.cpp
Normal file
427
internal/core/src/common/ArrayOffsetsTest.cpp
Normal file
@ -0,0 +1,427 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "common/ArrayOffsets.h"
|
||||
|
||||
using namespace milvus;
|
||||
|
||||
class ArrayOffsetsTest : public ::testing::Test {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ArrayOffsetsTest, SealedBasic) {
|
||||
// Create a simple ArrayOffsetsSealed manually
|
||||
// row 0: 2 elements (elem 0, 1)
|
||||
// row 1: 3 elements (elem 2, 3, 4)
|
||||
// row 2: 1 element (elem 5)
|
||||
ArrayOffsetsSealed offsets(
|
||||
{0, 0, 1, 1, 1, 2}, // element_row_ids
|
||||
{0, 2, 5, 6} // row_to_element_start (size = row_count + 1)
|
||||
);
|
||||
|
||||
// Test GetRowCount
|
||||
EXPECT_EQ(offsets.GetRowCount(), 3);
|
||||
|
||||
// Test GetTotalElementCount
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
|
||||
|
||||
// Test ElementIDToRowID
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(1);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, 1);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(2);
|
||||
EXPECT_EQ(row_id, 1);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(4);
|
||||
EXPECT_EQ(row_id, 1);
|
||||
EXPECT_EQ(elem_idx, 2);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5);
|
||||
EXPECT_EQ(row_id, 2);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
|
||||
// Test ElementIDRangeOfRow
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(0);
|
||||
EXPECT_EQ(start, 0);
|
||||
EXPECT_EQ(end, 2);
|
||||
}
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(1);
|
||||
EXPECT_EQ(start, 2);
|
||||
EXPECT_EQ(end, 5);
|
||||
}
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(2);
|
||||
EXPECT_EQ(start, 5);
|
||||
EXPECT_EQ(end, 6);
|
||||
}
|
||||
// When row_id == row_count, return (total_elements, total_elements)
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(3);
|
||||
EXPECT_EQ(start, 6);
|
||||
EXPECT_EQ(end, 6);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, SealedRowBitsetToElementBitset) {
|
||||
ArrayOffsetsSealed offsets({0, 0, 1, 1, 1, 2}, // element_row_ids
|
||||
{0, 2, 5, 6} // row_to_element_start
|
||||
);
|
||||
|
||||
// row_bitset: row 0 = true, row 1 = false, row 2 = true
|
||||
TargetBitmap row_bitset(3);
|
||||
row_bitset[0] = true;
|
||||
row_bitset[1] = false;
|
||||
row_bitset[2] = true;
|
||||
|
||||
TargetBitmap valid_row_bitset(3, true);
|
||||
|
||||
TargetBitmapView row_view(row_bitset.data(), row_bitset.size());
|
||||
TargetBitmapView valid_view(valid_row_bitset.data(),
|
||||
valid_row_bitset.size());
|
||||
|
||||
auto [elem_bitset, valid_elem_bitset] =
|
||||
offsets.RowBitsetToElementBitset(row_view, valid_view);
|
||||
|
||||
EXPECT_EQ(elem_bitset.size(), 6);
|
||||
// Elements of row 0 (elem 0, 1) should be true
|
||||
EXPECT_TRUE(elem_bitset[0]);
|
||||
EXPECT_TRUE(elem_bitset[1]);
|
||||
// Elements of row 1 (elem 2, 3, 4) should be false
|
||||
EXPECT_FALSE(elem_bitset[2]);
|
||||
EXPECT_FALSE(elem_bitset[3]);
|
||||
EXPECT_FALSE(elem_bitset[4]);
|
||||
// Elements of row 2 (elem 5) should be true
|
||||
EXPECT_TRUE(elem_bitset[5]);
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, SealedEmptyArrays) {
|
||||
// Test with some rows having empty arrays
|
||||
// row 1 and row 3 are empty
|
||||
ArrayOffsetsSealed offsets({0, 0, 2, 2, 2}, // element_row_ids
|
||||
{0, 2, 2, 5, 5} // row_to_element_start
|
||||
);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 4);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
|
||||
|
||||
// Row 1 has no elements
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(1);
|
||||
EXPECT_EQ(start, 2);
|
||||
EXPECT_EQ(end, 2); // empty range
|
||||
}
|
||||
// Row 3 has no elements
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(3);
|
||||
EXPECT_EQ(start, 5);
|
||||
EXPECT_EQ(end, 5); // empty range
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, GrowingBasicInsert) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert rows in order
|
||||
std::vector<int32_t> lens1 = {2}; // row 0: 2 elements
|
||||
offsets.Insert(0, lens1.data(), 1);
|
||||
|
||||
std::vector<int32_t> lens2 = {3}; // row 1: 3 elements
|
||||
offsets.Insert(1, lens2.data(), 1);
|
||||
|
||||
std::vector<int32_t> lens3 = {1}; // row 2: 1 element
|
||||
offsets.Insert(2, lens3.data(), 1);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 3);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
|
||||
|
||||
// Test ElementIDToRowID
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(2);
|
||||
EXPECT_EQ(row_id, 1);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5);
|
||||
EXPECT_EQ(row_id, 2);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
|
||||
// Test ElementIDRangeOfRow
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(0);
|
||||
EXPECT_EQ(start, 0);
|
||||
EXPECT_EQ(end, 2);
|
||||
}
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(1);
|
||||
EXPECT_EQ(start, 2);
|
||||
EXPECT_EQ(end, 5);
|
||||
}
|
||||
// When row_id == row_count, return (total_elements, total_elements)
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(3);
|
||||
EXPECT_EQ(start, 6);
|
||||
EXPECT_EQ(end, 6);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, GrowingBatchInsert) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert multiple rows at once
|
||||
std::vector<int32_t> lens = {2, 3, 1}; // row 0, 1, 2
|
||||
offsets.Insert(0, lens.data(), 3);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 3);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
|
||||
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(0);
|
||||
EXPECT_EQ(start, 0);
|
||||
EXPECT_EQ(end, 2);
|
||||
}
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(1);
|
||||
EXPECT_EQ(start, 2);
|
||||
EXPECT_EQ(end, 5);
|
||||
}
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(2);
|
||||
EXPECT_EQ(start, 5);
|
||||
EXPECT_EQ(end, 6);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, GrowingOutOfOrderInsert) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert out of order - row 2 arrives before row 1
|
||||
std::vector<int32_t> lens0 = {2};
|
||||
offsets.Insert(0, lens0.data(), 1); // row 0
|
||||
|
||||
std::vector<int32_t> lens2 = {1};
|
||||
offsets.Insert(2, lens2.data(), 1); // row 2 (pending)
|
||||
|
||||
// row 1 not inserted yet, so only row 0 should be committed
|
||||
EXPECT_EQ(offsets.GetRowCount(), 1);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 2);
|
||||
|
||||
// Now insert row 1, which should drain pending row 2
|
||||
std::vector<int32_t> lens1 = {3};
|
||||
offsets.Insert(1, lens1.data(), 1); // row 1
|
||||
|
||||
// Now all 3 rows should be committed
|
||||
EXPECT_EQ(offsets.GetRowCount(), 3);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
|
||||
|
||||
// Verify order is correct
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(2);
|
||||
EXPECT_EQ(row_id, 1);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5);
|
||||
EXPECT_EQ(row_id, 2);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, GrowingEmptyArrays) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert rows with some empty arrays
|
||||
std::vector<int32_t> lens = {2, 0, 3, 0}; // row 1 and row 3 are empty
|
||||
offsets.Insert(0, lens.data(), 4);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 4);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
|
||||
|
||||
// Row 1 has no elements
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(1);
|
||||
EXPECT_EQ(start, 2);
|
||||
EXPECT_EQ(end, 2);
|
||||
}
|
||||
// Row 3 has no elements
|
||||
{
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(3);
|
||||
EXPECT_EQ(start, 5);
|
||||
EXPECT_EQ(end, 5);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, GrowingRowBitsetToElementBitset) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
std::vector<int32_t> lens = {2, 3, 1};
|
||||
offsets.Insert(0, lens.data(), 3);
|
||||
|
||||
TargetBitmap row_bitset(3);
|
||||
row_bitset[0] = true;
|
||||
row_bitset[1] = false;
|
||||
row_bitset[2] = true;
|
||||
|
||||
TargetBitmap valid_row_bitset(3, true);
|
||||
|
||||
TargetBitmapView row_view(row_bitset.data(), row_bitset.size());
|
||||
TargetBitmapView valid_view(valid_row_bitset.data(),
|
||||
valid_row_bitset.size());
|
||||
|
||||
auto [elem_bitset, valid_elem_bitset] =
|
||||
offsets.RowBitsetToElementBitset(row_view, valid_view);
|
||||
|
||||
EXPECT_EQ(elem_bitset.size(), 6);
|
||||
EXPECT_TRUE(elem_bitset[0]);
|
||||
EXPECT_TRUE(elem_bitset[1]);
|
||||
EXPECT_FALSE(elem_bitset[2]);
|
||||
EXPECT_FALSE(elem_bitset[3]);
|
||||
EXPECT_FALSE(elem_bitset[4]);
|
||||
EXPECT_TRUE(elem_bitset[5]);
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, GrowingConcurrentRead) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert initial data
|
||||
std::vector<int32_t> lens = {2, 3, 1};
|
||||
offsets.Insert(0, lens.data(), 3);
|
||||
|
||||
// Concurrent reads should be safe
|
||||
std::vector<std::thread> threads;
|
||||
for (int t = 0; t < 4; ++t) {
|
||||
threads.emplace_back([&offsets]() {
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
auto row_count = offsets.GetRowCount();
|
||||
auto elem_count = offsets.GetTotalElementCount();
|
||||
EXPECT_GE(row_count, 0);
|
||||
EXPECT_GE(elem_count, 0);
|
||||
|
||||
if (row_count > 0) {
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(0);
|
||||
EXPECT_GE(start, 0);
|
||||
EXPECT_GE(end, start);
|
||||
}
|
||||
|
||||
if (elem_count > 0) {
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
|
||||
EXPECT_GE(row_id, 0);
|
||||
EXPECT_GE(elem_idx, 0);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : threads) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, SingleRow) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
std::vector<int32_t> lens = {5};
|
||||
offsets.Insert(0, lens.data(), 1);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 1);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(i);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, i);
|
||||
}
|
||||
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(0);
|
||||
EXPECT_EQ(start, 0);
|
||||
EXPECT_EQ(end, 5);
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, SingleElementPerRow) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
std::vector<int32_t> lens = {1, 1, 1, 1, 1};
|
||||
offsets.Insert(0, lens.data(), 5);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 5);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(i);
|
||||
EXPECT_EQ(row_id, i);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
|
||||
auto [start, end] = offsets.ElementIDRangeOfRow(i);
|
||||
EXPECT_EQ(start, i);
|
||||
EXPECT_EQ(end, i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ArrayOffsetsTest, LargeArrayLength) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Single row with many elements
|
||||
std::vector<int32_t> lens = {10000};
|
||||
offsets.Insert(0, lens.data(), 1);
|
||||
|
||||
EXPECT_EQ(offsets.GetRowCount(), 1);
|
||||
EXPECT_EQ(offsets.GetTotalElementCount(), 10000);
|
||||
|
||||
// Test first, middle, and last elements
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, 0);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5000);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, 5000);
|
||||
}
|
||||
{
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(9999);
|
||||
EXPECT_EQ(row_id, 0);
|
||||
EXPECT_EQ(elem_idx, 9999);
|
||||
}
|
||||
}
|
||||
115
internal/core/src/common/ElementFilterIterator.cpp
Normal file
115
internal/core/src/common/ElementFilterIterator.cpp
Normal file
@ -0,0 +1,115 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ElementFilterIterator.h"
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "exec/expression/EvalCtx.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
ElementFilterIterator::ElementFilterIterator(
|
||||
std::shared_ptr<VectorIterator> base_iterator,
|
||||
exec::ExecContext* exec_context,
|
||||
exec::ExprSet* expr_set)
|
||||
: base_iterator_(std::move(base_iterator)),
|
||||
exec_context_(exec_context),
|
||||
expr_set_(expr_set) {
|
||||
AssertInfo(base_iterator_ != nullptr, "Base iterator cannot be null");
|
||||
AssertInfo(exec_context_ != nullptr, "Exec context cannot be null");
|
||||
AssertInfo(expr_set_ != nullptr, "ExprSet cannot be null");
|
||||
}
|
||||
|
||||
bool
|
||||
ElementFilterIterator::HasNext() {
|
||||
// If cache is empty and base iterator has more, fetch more
|
||||
while (filtered_buffer_.empty() && base_iterator_->HasNext()) {
|
||||
FetchAndFilterBatch();
|
||||
}
|
||||
return !filtered_buffer_.empty();
|
||||
}
|
||||
|
||||
std::optional<std::pair<int64_t, float>>
|
||||
ElementFilterIterator::Next() {
|
||||
if (!HasNext()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto result = filtered_buffer_.front();
|
||||
filtered_buffer_.pop_front();
|
||||
return result;
|
||||
}
|
||||
|
||||
void
|
||||
ElementFilterIterator::FetchAndFilterBatch() {
|
||||
constexpr size_t kBatchSize = 1024;
|
||||
|
||||
// Step 1: Fetch a batch from base iterator (up to kBatchSize elements)
|
||||
element_ids_buffer_.clear();
|
||||
distances_buffer_.clear();
|
||||
element_ids_buffer_.reserve(kBatchSize);
|
||||
distances_buffer_.reserve(kBatchSize);
|
||||
|
||||
while (base_iterator_->HasNext() &&
|
||||
element_ids_buffer_.size() < kBatchSize) {
|
||||
auto pair = base_iterator_->Next();
|
||||
if (pair.has_value()) {
|
||||
element_ids_buffer_.push_back(static_cast<int32_t>(pair->first));
|
||||
distances_buffer_.push_back(pair->second);
|
||||
}
|
||||
}
|
||||
|
||||
// If no elements fetched, base iterator is exhausted
|
||||
if (element_ids_buffer_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 2: Batch evaluate element-level expression
|
||||
exec::EvalCtx eval_ctx(exec_context_, expr_set_, &element_ids_buffer_);
|
||||
std::vector<VectorPtr> results;
|
||||
|
||||
// Evaluate the expression set (should contain only element_expr)
|
||||
expr_set_->Eval(0, 1, true, eval_ctx, results);
|
||||
|
||||
AssertInfo(results.size() == 1 && results[0] != nullptr,
|
||||
"ElementFilterIterator: expression evaluation should return "
|
||||
"exactly one result");
|
||||
|
||||
// Step 3: Extract evaluation results as bitmap
|
||||
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
|
||||
AssertInfo(col_vec != nullptr,
|
||||
"ElementFilterIterator: result should be ColumnVector");
|
||||
AssertInfo(col_vec->IsBitmap(),
|
||||
"ElementFilterIterator: result should be bitmap");
|
||||
|
||||
auto col_vec_size = col_vec->size();
|
||||
AssertInfo(col_vec_size == element_ids_buffer_.size(),
|
||||
"ElementFilterIterator: evaluation result size mismatch");
|
||||
|
||||
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
|
||||
|
||||
// Step 4: Filter elements based on evaluation results and cache them
|
||||
for (size_t i = 0; i < element_ids_buffer_.size(); ++i) {
|
||||
if (bitsetview[i]) {
|
||||
// Element passes filter, cache it
|
||||
filtered_buffer_.emplace_back(element_ids_buffer_[i],
|
||||
distances_buffer_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus
|
||||
74
internal/core/src/common/ElementFilterIterator.h
Normal file
74
internal/core/src/common/ElementFilterIterator.h
Normal file
@ -0,0 +1,74 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include "common/QueryResult.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
// Forward declarations
|
||||
namespace exec {
|
||||
class ExecContext;
|
||||
class ExprSet;
|
||||
} // namespace exec
|
||||
|
||||
class ElementFilterIterator : public VectorIterator {
|
||||
public:
|
||||
ElementFilterIterator(std::shared_ptr<VectorIterator> base_iterator,
|
||||
exec::ExecContext* exec_context,
|
||||
exec::ExprSet* expr_set);
|
||||
|
||||
bool
|
||||
HasNext() override;
|
||||
|
||||
std::optional<std::pair<int64_t, float>>
|
||||
Next() override;
|
||||
|
||||
private:
|
||||
// Fetch a batch from base iterator, evaluate expression, and cache results
|
||||
// Steps:
|
||||
// 1. Fetch up to batch_size elements from base_iterator
|
||||
// 2. Batch evaluate element_expr on fetched elements
|
||||
// 3. Filter elements based on evaluation results
|
||||
// 4. Cache passing elements in filtered_buffer_
|
||||
void
|
||||
FetchAndFilterBatch();
|
||||
|
||||
// Base iterator to fetch elements from
|
||||
std::shared_ptr<VectorIterator> base_iterator_;
|
||||
|
||||
// Execution context for expression evaluation
|
||||
exec::ExecContext* exec_context_;
|
||||
|
||||
// Expression set containing element-level filter expression
|
||||
exec::ExprSet* expr_set_;
|
||||
|
||||
// Cache of filtered elements ready to be consumed
|
||||
std::deque<std::pair<int64_t, float>> filtered_buffer_;
|
||||
|
||||
// Reusable buffers for batch fetching (avoid repeated allocations)
|
||||
FixedVector<int32_t> element_ids_buffer_;
|
||||
FixedVector<float> distances_buffer_;
|
||||
};
|
||||
|
||||
} // namespace milvus
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ArrayOffsets.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "common/Types.h"
|
||||
#include "knowhere/config.h"
|
||||
@ -46,6 +47,13 @@ struct SearchInfo {
|
||||
std::optional<std::string> json_path_;
|
||||
std::optional<milvus::DataType> json_type_;
|
||||
bool strict_cast_{false};
|
||||
std::shared_ptr<const IArrayOffsets> array_offsets_{
|
||||
nullptr}; // For element-level search
|
||||
|
||||
bool
|
||||
element_level() const {
|
||||
return array_offsets_ != nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
using SearchInfoPtr = std::shared_ptr<SearchInfo>;
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
#include <NamedType/named_type.hpp>
|
||||
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "pb/schema.pb.h"
|
||||
#include "knowhere/index/index_node.h"
|
||||
|
||||
@ -122,16 +123,37 @@ struct OffsetDisPairComparator {
|
||||
return left->GetOffDis().first < right->GetOffDis().first;
|
||||
}
|
||||
};
|
||||
struct VectorIterator {
|
||||
|
||||
class VectorIterator {
|
||||
public:
|
||||
VectorIterator(int chunk_count,
|
||||
const std::vector<int64_t>& total_rows_until_chunk = {})
|
||||
virtual ~VectorIterator() = default;
|
||||
|
||||
virtual bool
|
||||
HasNext() = 0;
|
||||
|
||||
virtual std::optional<std::pair<int64_t, float>>
|
||||
Next() = 0;
|
||||
};
|
||||
|
||||
// Multi-way merge iterator for vector search results from multiple chunks
|
||||
//
|
||||
// Merges knowhere iterators from different chunks using a min-heap,
|
||||
// returning results in distance-sorted order.
|
||||
class ChunkMergeIterator : public VectorIterator {
|
||||
public:
|
||||
ChunkMergeIterator(int chunk_count,
|
||||
const std::vector<int64_t>& total_rows_until_chunk = {})
|
||||
: total_rows_until_chunk_(total_rows_until_chunk) {
|
||||
iterators_.reserve(chunk_count);
|
||||
}
|
||||
|
||||
bool
|
||||
HasNext() override {
|
||||
return !heap_.empty();
|
||||
}
|
||||
|
||||
std::optional<std::pair<int64_t, float>>
|
||||
Next() {
|
||||
Next() override {
|
||||
if (!heap_.empty()) {
|
||||
auto top = heap_.top();
|
||||
heap_.pop();
|
||||
@ -145,10 +167,7 @@ struct VectorIterator {
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
bool
|
||||
HasNext() {
|
||||
return !heap_.empty();
|
||||
}
|
||||
|
||||
bool
|
||||
AddIterator(knowhere::IndexNode::IteratorPtr iter) {
|
||||
if (!sealed && iter != nullptr) {
|
||||
@ -157,6 +176,7 @@ struct VectorIterator {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void
|
||||
seal() {
|
||||
sealed = true;
|
||||
@ -195,7 +215,7 @@ struct VectorIterator {
|
||||
heap_;
|
||||
bool sealed = false;
|
||||
std::vector<int64_t> total_rows_until_chunk_;
|
||||
//currently, VectorIterator is guaranteed to be used serially without concurrent problem, in the future
|
||||
//currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
|
||||
//we may need to add mutex to protect the variable sealed
|
||||
};
|
||||
|
||||
@ -230,15 +250,21 @@ struct SearchResult {
|
||||
for (int i = 0, vec_iter_idx = 0; i < kw_iterators.size(); i++) {
|
||||
vec_iter_idx = vec_iter_idx % nq;
|
||||
if (vector_iterators.size() < nq) {
|
||||
auto vector_iterator = std::make_shared<VectorIterator>(
|
||||
auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
|
||||
chunk_count, total_rows_until_chunk);
|
||||
vector_iterators.emplace_back(vector_iterator);
|
||||
vector_iterators.emplace_back(chunk_merge_iter);
|
||||
}
|
||||
const auto& kw_iterator = kw_iterators[i];
|
||||
vector_iterators[vec_iter_idx++]->AddIterator(kw_iterator);
|
||||
auto chunk_merge_iter =
|
||||
std::static_pointer_cast<ChunkMergeIterator>(
|
||||
vector_iterators[vec_iter_idx++]);
|
||||
chunk_merge_iter->AddIterator(kw_iterator);
|
||||
}
|
||||
for (const auto& vector_iter : vector_iterators) {
|
||||
vector_iter->seal();
|
||||
// Cast to ChunkMergeIterator to call seal
|
||||
auto chunk_merge_iter =
|
||||
std::static_pointer_cast<ChunkMergeIterator>(vector_iter);
|
||||
chunk_merge_iter->seal();
|
||||
}
|
||||
this->vector_iterators_ = vector_iterators;
|
||||
}
|
||||
@ -275,6 +301,28 @@ struct SearchResult {
|
||||
vector_iterators_;
|
||||
// record the storage usage in search
|
||||
StorageCost search_storage_cost_;
|
||||
|
||||
bool element_level_{false};
|
||||
std::vector<int32_t> element_indices_;
|
||||
std::optional<std::vector<std::shared_ptr<VectorIterator>>>
|
||||
element_iterators_;
|
||||
std::shared_ptr<const IArrayOffsets> array_offsets_{nullptr};
|
||||
std::vector<std::unique_ptr<uint8_t[]>> chunk_buffers_{};
|
||||
|
||||
bool
|
||||
HasIterators() const {
|
||||
return (element_level_ && element_iterators_.has_value()) ||
|
||||
(!element_level_ && vector_iterators_.has_value());
|
||||
}
|
||||
|
||||
std::optional<std::vector<std::shared_ptr<VectorIterator>>>
|
||||
GetIterators() {
|
||||
if (element_level_) {
|
||||
return element_iterators_;
|
||||
} else {
|
||||
return vector_iterators_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using SearchResultPtr = std::shared_ptr<SearchResult>;
|
||||
|
||||
@ -171,4 +171,15 @@ Schema::MmapEnabled(const FieldId& field_id) const {
|
||||
return {true, it->second};
|
||||
}
|
||||
|
||||
const FieldMeta&
|
||||
Schema::GetFirstArrayFieldInStruct(const std::string& struct_name) const {
|
||||
auto cache_it = struct_array_field_cache_.find(struct_name);
|
||||
if (cache_it != struct_array_field_cache_.end()) {
|
||||
return fields_.at(cache_it->second);
|
||||
}
|
||||
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"No array field found in struct: {}",
|
||||
struct_name);
|
||||
}
|
||||
} // namespace milvus
|
||||
|
||||
@ -367,7 +367,24 @@ class Schema {
|
||||
name_ids_.emplace(field_name, field_id);
|
||||
id_names_.emplace(field_id, field_name);
|
||||
|
||||
fields_.emplace(field_id, field_meta);
|
||||
// Build struct_array_field_cache_ for ARRAY/VECTOR_ARRAY fields
|
||||
// Field name format: "struct_name[0].field_name"
|
||||
auto data_type = field_meta.get_data_type();
|
||||
if (data_type == DataType::ARRAY ||
|
||||
data_type == DataType::VECTOR_ARRAY) {
|
||||
const std::string& name_str = field_name.get();
|
||||
auto bracket_pos = name_str.find('[');
|
||||
if (bracket_pos != std::string::npos && bracket_pos > 0) {
|
||||
std::string struct_name = name_str.substr(0, bracket_pos);
|
||||
// Only cache the first array field for each struct
|
||||
if (struct_array_field_cache_.find(struct_name) ==
|
||||
struct_array_field_cache_.end()) {
|
||||
struct_array_field_cache_[struct_name] = field_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fields_.emplace(field_id, std::move(field_meta));
|
||||
field_ids_.emplace_back(field_id);
|
||||
}
|
||||
|
||||
@ -392,6 +409,10 @@ class Schema {
|
||||
std::pair<bool, bool>
|
||||
MmapEnabled(const FieldId& field) const;
|
||||
|
||||
// Find the first array field belonging to a struct (cached)
|
||||
const FieldMeta&
|
||||
GetFirstArrayFieldInStruct(const std::string& struct_name) const;
|
||||
|
||||
private:
|
||||
int64_t debug_id = START_USER_FIELDID;
|
||||
std::vector<FieldId> field_ids_;
|
||||
@ -418,6 +439,9 @@ class Schema {
|
||||
bool has_mmap_setting_ = false;
|
||||
bool mmap_enabled_ = false;
|
||||
std::unordered_map<FieldId, bool> mmap_fields_;
|
||||
|
||||
// Cache for struct_name -> first array field mapping (built during AddField)
|
||||
std::unordered_map<std::string, FieldId> struct_array_field_cache_;
|
||||
};
|
||||
|
||||
using SchemaPtr = std::shared_ptr<Schema>;
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
#include "common/Consts.h"
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/LoadInfo.h"
|
||||
#include "common/Schema.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "knowhere/dataset.h"
|
||||
|
||||
@ -458,6 +458,11 @@ class VectorArrayView {
|
||||
}
|
||||
}
|
||||
|
||||
int
|
||||
length() const {
|
||||
return length_;
|
||||
}
|
||||
|
||||
private:
|
||||
char* data_{nullptr};
|
||||
int64_t dim_ = 0;
|
||||
|
||||
@ -32,6 +32,8 @@
|
||||
#include "exec/operator/VectorSearchNode.h"
|
||||
#include "exec/operator/RandomSampleNode.h"
|
||||
#include "exec/operator/GroupByNode.h"
|
||||
#include "exec/operator/ElementFilterNode.h"
|
||||
#include "exec/operator/ElementFilterBitsNode.h"
|
||||
#include "exec/Task.h"
|
||||
#include "plan/PlanNode.h"
|
||||
|
||||
@ -104,6 +106,19 @@ DriverFactory::CreateDriver(std::unique_ptr<DriverContext> ctx,
|
||||
tracer::AddEvent("create_operator: RescoresNode");
|
||||
operators.push_back(
|
||||
std::make_unique<PhyRescoresNode>(id, ctx.get(), rescoresnode));
|
||||
} else if (auto node =
|
||||
std::dynamic_pointer_cast<const plan::ElementFilterNode>(
|
||||
plannode)) {
|
||||
tracer::AddEvent("create_operator: ElementFilterNode");
|
||||
operators.push_back(
|
||||
std::make_unique<PhyElementFilterNode>(id, ctx.get(), node));
|
||||
} else if (auto node = std::dynamic_pointer_cast<
|
||||
const plan::ElementFilterBitsNode>(plannode)) {
|
||||
tracer::AddEvent("create_operator: ElementFilterBitsNode");
|
||||
operators.push_back(std::make_unique<PhyElementFilterBitsNode>(
|
||||
id, ctx.get(), node));
|
||||
} else {
|
||||
ThrowInfo(ErrorCode::UnexpectedError, "Unknown plan node type");
|
||||
}
|
||||
// TODO: add more operators
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
#include "common/Common.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Exception.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "common/OpContext.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
@ -303,6 +304,64 @@ class QueryContext : public Context {
|
||||
return plan_options_;
|
||||
}
|
||||
|
||||
void
|
||||
set_element_level_query(bool element_level) {
|
||||
element_level_query_ = element_level;
|
||||
}
|
||||
|
||||
bool
|
||||
element_level_query() const {
|
||||
return element_level_query_;
|
||||
}
|
||||
|
||||
void
|
||||
set_struct_name(const std::string& field_name) {
|
||||
struct_name_ = field_name;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
get_struct_name() const {
|
||||
return struct_name_;
|
||||
}
|
||||
|
||||
void
|
||||
set_array_offsets(std::shared_ptr<const IArrayOffsets> offsets) {
|
||||
array_offsets_ = std::move(offsets);
|
||||
}
|
||||
|
||||
std::shared_ptr<const IArrayOffsets>
|
||||
get_array_offsets() const {
|
||||
return array_offsets_;
|
||||
}
|
||||
|
||||
void
|
||||
set_active_element_count(int64_t count) {
|
||||
active_element_count_ = count;
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_active_element_count() const {
|
||||
return active_element_count_;
|
||||
}
|
||||
|
||||
void
|
||||
set_element_level_bitset(TargetBitmap&& bitset) {
|
||||
element_level_bitset_ = std::move(bitset);
|
||||
}
|
||||
|
||||
std::optional<TargetBitmap>
|
||||
get_element_level_bitset() {
|
||||
if (element_level_bitset_.has_value()) {
|
||||
return std::move(element_level_bitset_.value());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool
|
||||
has_element_level_bitset() const {
|
||||
return element_level_bitset_.has_value();
|
||||
}
|
||||
|
||||
private:
|
||||
folly::Executor* executor_;
|
||||
//folly::Executor::KeepAlive<> executor_keepalive_;
|
||||
@ -331,6 +390,12 @@ class QueryContext : public Context {
|
||||
int32_t consistency_level_ = 0;
|
||||
|
||||
query::PlanOptions plan_options_;
|
||||
|
||||
bool element_level_query_{false};
|
||||
std::string struct_name_;
|
||||
std::shared_ptr<const IArrayOffsets> array_offsets_{nullptr};
|
||||
int64_t active_element_count_{0}; // Total elements in active documents
|
||||
std::optional<TargetBitmap> element_level_bitset_;
|
||||
};
|
||||
|
||||
// Represent the state of one thread of query execution.
|
||||
|
||||
@ -29,7 +29,11 @@ PhyBinaryArithOpEvalRangeExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
|
||||
auto input = context.get_offset_input();
|
||||
SetHasOffsetInput((input != nullptr));
|
||||
switch (expr_->column_.data_type_) {
|
||||
auto data_type = expr_->column_.data_type_;
|
||||
if (expr_->column_.element_level_) {
|
||||
data_type = expr_->column_.element_type_;
|
||||
}
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
result = ExecRangeVisitorImpl<bool>(input);
|
||||
break;
|
||||
@ -1840,14 +1844,27 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData(
|
||||
|
||||
int64_t processed_size;
|
||||
if (has_offset_input_) {
|
||||
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
value,
|
||||
right_operand);
|
||||
if (expr_->column_.element_level_) {
|
||||
// For element-level filtering
|
||||
processed_size = ProcessElementLevelByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
value,
|
||||
right_operand);
|
||||
} else {
|
||||
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
value,
|
||||
right_operand);
|
||||
}
|
||||
} else {
|
||||
AssertInfo(!expr_->column_.element_level_,
|
||||
"Element-level filtering is not supported without offsets");
|
||||
processed_size = ProcessDataChunks<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
res,
|
||||
|
||||
@ -31,7 +31,12 @@ PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
|
||||
auto input = context.get_offset_input();
|
||||
SetHasOffsetInput((input != nullptr));
|
||||
switch (expr_->column_.data_type_) {
|
||||
|
||||
auto data_type = expr_->column_.data_type_;
|
||||
if (expr_->column_.element_level_) {
|
||||
data_type = expr_->column_.element_type_;
|
||||
}
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
result = ExecRangeVisitorImpl<bool>(context);
|
||||
break;
|
||||
@ -414,14 +419,28 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData(EvalCtx& context) {
|
||||
};
|
||||
int64_t processed_size;
|
||||
if (has_offset_input_) {
|
||||
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
val1,
|
||||
val2);
|
||||
if (expr_->column_.element_level_) {
|
||||
// For element-level filtering
|
||||
processed_size = ProcessElementLevelByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
val1,
|
||||
val2);
|
||||
} else {
|
||||
// For doc-level filtering
|
||||
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
val1,
|
||||
val2);
|
||||
}
|
||||
} else {
|
||||
AssertInfo(!expr_->column_.element_level_,
|
||||
"Element-level filtering is not supported without offsets");
|
||||
processed_size = ProcessDataChunks<T>(
|
||||
execute_sub_batch, skip_index_func, res, valid_res, val1, val2);
|
||||
}
|
||||
|
||||
@ -71,6 +71,9 @@ PhyColumnExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
template <typename T>
|
||||
VectorPtr
|
||||
PhyColumnExpr::DoEval(OffsetVector* input) {
|
||||
AssertInfo(!expr_->GetColumn().element_level_,
|
||||
"ColumnExpr of row-level access is not supported");
|
||||
|
||||
// similar to PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op)
|
||||
// take offsets as input
|
||||
if (has_offset_input_) {
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "common/Array.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "common/FieldDataInterface.h"
|
||||
#include "common/Json.h"
|
||||
#include "common/OpContext.h"
|
||||
@ -679,6 +681,173 @@ class SegmentExpr : public Expr {
|
||||
return input->size();
|
||||
}
|
||||
|
||||
// Process element-level data by element IDs
|
||||
// Handles the type mismatch between storage (ArrayView) and element type
|
||||
// Currently only implemented for sealed chunked segments
|
||||
template <typename ElementType, typename FUNC, typename... ValTypes>
|
||||
int64_t
|
||||
ProcessElementLevelByOffsets(
|
||||
FUNC func,
|
||||
std::function<bool(const milvus::SkipIndex&, FieldId, int)> skip_func,
|
||||
OffsetVector* element_ids,
|
||||
TargetBitmapView res,
|
||||
TargetBitmapView valid_res,
|
||||
const ValTypes&... values) {
|
||||
auto& skip_index = segment_->GetSkipIndex();
|
||||
if (segment_->type() == SegmentType::Sealed) {
|
||||
AssertInfo(
|
||||
segment_->is_chunked(),
|
||||
"Element-level filtering requires chunked segment for sealed");
|
||||
|
||||
auto array_offsets = segment_->GetArrayOffsets(field_id_);
|
||||
if (!array_offsets) {
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"IArrayOffsets not found for field {}",
|
||||
field_id_.get());
|
||||
}
|
||||
|
||||
// Batch process consecutive elements belonging to the same chunk
|
||||
size_t processed_size = 0;
|
||||
size_t i = 0;
|
||||
|
||||
// Reuse these vectors to avoid repeated heap allocations
|
||||
FixedVector<int32_t> offsets;
|
||||
FixedVector<int32_t> elem_indices;
|
||||
|
||||
while (i < element_ids->size()) {
|
||||
// Start of a new chunk batch
|
||||
int64_t element_id = (*element_ids)[i];
|
||||
auto [doc_id, elem_idx] =
|
||||
array_offsets->ElementIDToRowID(element_id);
|
||||
auto [chunk_id, chunk_offset] =
|
||||
segment_->get_chunk_by_offset(field_id_, doc_id);
|
||||
|
||||
// Collect consecutive elements belonging to the same chunk
|
||||
offsets.clear();
|
||||
elem_indices.clear();
|
||||
offsets.push_back(chunk_offset);
|
||||
elem_indices.push_back(elem_idx);
|
||||
|
||||
size_t batch_start = i;
|
||||
i++;
|
||||
|
||||
// Look ahead for more elements in the same chunk
|
||||
while (i < element_ids->size()) {
|
||||
int64_t next_element_id = (*element_ids)[i];
|
||||
auto [next_doc_id, next_elem_idx] =
|
||||
array_offsets->ElementIDToRowID(next_element_id);
|
||||
auto [next_chunk_id, next_chunk_offset] =
|
||||
segment_->get_chunk_by_offset(field_id_, next_doc_id);
|
||||
|
||||
if (next_chunk_id != chunk_id) {
|
||||
break; // Different chunk, process current batch
|
||||
}
|
||||
|
||||
offsets.push_back(next_chunk_offset);
|
||||
elem_indices.push_back(next_elem_idx);
|
||||
i++;
|
||||
}
|
||||
|
||||
// Batch fetch all ArrayViews for this chunk
|
||||
auto pw = segment_->get_views_by_offsets<ArrayView>(
|
||||
op_ctx_, field_id_, chunk_id, offsets);
|
||||
|
||||
auto [array_vec, valid_data] = pw.get();
|
||||
|
||||
// Process each element in this batch
|
||||
for (size_t j = 0; j < offsets.size(); j++) {
|
||||
size_t result_idx = batch_start + j;
|
||||
|
||||
if ((!skip_func ||
|
||||
!skip_func(skip_index, field_id_, chunk_id)) &&
|
||||
(!namespace_skip_func_.has_value() ||
|
||||
!namespace_skip_func_.value()(chunk_id))) {
|
||||
// Extract element from ArrayView
|
||||
auto value =
|
||||
array_vec[j].template get_data<ElementType>(
|
||||
elem_indices[j]);
|
||||
bool is_valid = !valid_data.data() || valid_data[j];
|
||||
|
||||
func.template operator()<FilterType::random>(
|
||||
&value,
|
||||
&is_valid,
|
||||
nullptr,
|
||||
1,
|
||||
res + result_idx,
|
||||
valid_res + result_idx,
|
||||
values...);
|
||||
} else {
|
||||
// Chunk is skipped - handle exactly like ProcessDataByOffsets
|
||||
if (valid_data.size() > j && !valid_data[j]) {
|
||||
res[result_idx] = valid_res[result_idx] = false;
|
||||
}
|
||||
}
|
||||
|
||||
processed_size++;
|
||||
}
|
||||
}
|
||||
return processed_size;
|
||||
} else {
|
||||
auto array_offsets = segment_->GetArrayOffsets(field_id_);
|
||||
if (!array_offsets) {
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"IArrayOffsets not found for field {}",
|
||||
field_id_.get());
|
||||
}
|
||||
|
||||
auto& skip_index = segment_->GetSkipIndex();
|
||||
size_t processed_size = 0;
|
||||
|
||||
for (size_t i = 0; i < element_ids->size(); i++) {
|
||||
int64_t element_id = (*element_ids)[i];
|
||||
|
||||
auto [doc_id, elem_idx] =
|
||||
array_offsets->ElementIDToRowID(element_id);
|
||||
|
||||
// Calculate chunk_id and chunk_offset for this doc
|
||||
auto chunk_id = doc_id / size_per_chunk_;
|
||||
auto chunk_offset = doc_id % size_per_chunk_;
|
||||
|
||||
// Get the Array chunk (Growing segment stores Array, not ArrayView)
|
||||
auto pw =
|
||||
segment_->chunk_data<Array>(op_ctx_, field_id_, chunk_id);
|
||||
auto chunk = pw.get();
|
||||
const Array* array_ptr = chunk.data() + chunk_offset;
|
||||
const bool* valid_data = chunk.valid_data();
|
||||
if (valid_data != nullptr) {
|
||||
valid_data += chunk_offset;
|
||||
}
|
||||
|
||||
if ((!skip_func ||
|
||||
!skip_func(skip_index, field_id_, chunk_id)) &&
|
||||
(!namespace_skip_func_.has_value() ||
|
||||
!namespace_skip_func_.value()(chunk_id))) {
|
||||
// Extract element from Array
|
||||
auto value = array_ptr->get_data<ElementType>(elem_idx);
|
||||
bool is_valid = !valid_data || valid_data[0];
|
||||
|
||||
func.template operator()<FilterType::random>(
|
||||
&value,
|
||||
&is_valid,
|
||||
nullptr,
|
||||
1,
|
||||
res + processed_size,
|
||||
valid_res + processed_size,
|
||||
values...);
|
||||
} else {
|
||||
// Chunk is skipped
|
||||
if (valid_data && !valid_data[0]) {
|
||||
res[processed_size] = valid_res[processed_size] = false;
|
||||
}
|
||||
}
|
||||
|
||||
processed_size++;
|
||||
}
|
||||
|
||||
return processed_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Template parameter to control whether segment offsets are needed (for GIS functions)
|
||||
template <typename T,
|
||||
bool NeedSegmentOffsets = false,
|
||||
|
||||
@ -35,7 +35,11 @@ PhyTermFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
result = ExecPkTermImpl();
|
||||
return;
|
||||
}
|
||||
switch (expr_->column_.data_type_) {
|
||||
auto data_type = expr_->column_.data_type_;
|
||||
if (expr_->column_.element_level_) {
|
||||
data_type = expr_->column_.element_type_;
|
||||
}
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
result = ExecVisitorImpl<bool>(context);
|
||||
break;
|
||||
@ -1001,13 +1005,25 @@ PhyTermFilterExpr::ExecVisitorImplForData(EvalCtx& context) {
|
||||
|
||||
int64_t processed_size;
|
||||
if (has_offset_input_) {
|
||||
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
arg_set_);
|
||||
if (expr_->column_.element_level_) {
|
||||
// For element-level filtering
|
||||
processed_size = ProcessElementLevelByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
arg_set_);
|
||||
} else {
|
||||
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
|
||||
skip_index_func,
|
||||
input,
|
||||
res,
|
||||
valid_res,
|
||||
arg_set_);
|
||||
}
|
||||
} else {
|
||||
AssertInfo(!expr_->column_.element_level_,
|
||||
"Element-level filtering is not supported without offsets");
|
||||
processed_size = ProcessDataChunks<T>(
|
||||
execute_sub_batch, skip_index_func, res, valid_res, arg_set_);
|
||||
}
|
||||
|
||||
@ -163,7 +163,11 @@ PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
|
||||
auto input = context.get_offset_input();
|
||||
SetHasOffsetInput((input != nullptr));
|
||||
switch (expr_->column_.data_type_) {
|
||||
auto data_type = expr_->column_.data_type_;
|
||||
if (expr_->column_.element_level_) {
|
||||
data_type = expr_->column_.element_type_;
|
||||
}
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
result = ExecRangeVisitorImpl<bool>(context);
|
||||
break;
|
||||
@ -1713,9 +1717,17 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData(EvalCtx& context) {
|
||||
|
||||
int64_t processed_size;
|
||||
if (has_offset_input_) {
|
||||
processed_size = ProcessDataByOffsets<T>(
|
||||
execute_sub_batch, skip_index_func, input, res, valid_res, val);
|
||||
if (expr_->column_.element_level_) {
|
||||
// For element-level filtering
|
||||
processed_size = ProcessElementLevelByOffsets<T>(
|
||||
execute_sub_batch, skip_index_func, input, res, valid_res, val);
|
||||
} else {
|
||||
processed_size = ProcessDataByOffsets<T>(
|
||||
execute_sub_batch, skip_index_func, input, res, valid_res, val);
|
||||
}
|
||||
} else {
|
||||
AssertInfo(!expr_->column_.element_level_,
|
||||
"Element-level filtering is not supported without offsets");
|
||||
processed_size = ProcessDataChunks<T>(
|
||||
execute_sub_batch, skip_index_func, res, valid_res, val);
|
||||
}
|
||||
|
||||
215
internal/core/src/exec/operator/ElementFilterBitsNode.cpp
Normal file
215
internal/core/src/exec/operator/ElementFilterBitsNode.cpp
Normal file
@ -0,0 +1,215 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ElementFilterBitsNode.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "fmt/format.h"
|
||||
|
||||
#include "monitor/Monitor.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
PhyElementFilterBitsNode::PhyElementFilterBitsNode(
|
||||
int32_t operator_id,
|
||||
DriverContext* driverctx,
|
||||
const std::shared_ptr<const plan::ElementFilterBitsNode>&
|
||||
element_filter_bits_node)
|
||||
: Operator(driverctx,
|
||||
DataType::NONE,
|
||||
operator_id,
|
||||
"element_filter_bits_plan_node",
|
||||
"PhyElementFilterBitsNode"),
|
||||
struct_name_(element_filter_bits_node->struct_name()) {
|
||||
ExecContext* exec_context = operator_context_->get_exec_context();
|
||||
query_context_ = exec_context->get_query_context();
|
||||
|
||||
// Build expression set from element-level expression
|
||||
std::vector<expr::TypedExprPtr> exprs;
|
||||
exprs.emplace_back(element_filter_bits_node->element_filter());
|
||||
element_exprs_ = std::make_unique<ExprSet>(exprs, exec_context);
|
||||
}
|
||||
|
||||
void
|
||||
PhyElementFilterBitsNode::AddInput(RowVectorPtr& input) {
|
||||
input_ = std::move(input);
|
||||
}
|
||||
|
||||
RowVectorPtr
|
||||
PhyElementFilterBitsNode::GetOutput() {
|
||||
if (is_finished_ || input_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DeferLambda([&]() { is_finished_ = true; });
|
||||
|
||||
tracer::AutoSpan span(
|
||||
"PhyElementFilterBitsNode::GetOutput", tracer::GetRootSpan(), true);
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time =
|
||||
std::chrono::high_resolution_clock::now();
|
||||
std::chrono::high_resolution_clock::time_point step_time;
|
||||
|
||||
// Step 1: Get array offsets
|
||||
auto segment = query_context_->get_segment();
|
||||
auto& field_meta =
|
||||
segment->get_schema().GetFirstArrayFieldInStruct(struct_name_);
|
||||
auto field_id = field_meta.get_id();
|
||||
auto array_offsets = segment->GetArrayOffsets(field_id);
|
||||
if (array_offsets == nullptr) {
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"IArrayOffsets not found for field {}",
|
||||
field_id.get());
|
||||
}
|
||||
query_context_->set_array_offsets(array_offsets);
|
||||
auto [first_elem, _] =
|
||||
array_offsets->ElementIDRangeOfRow(query_context_->get_active_count());
|
||||
query_context_->set_active_element_count(first_elem);
|
||||
|
||||
// Step 2: Prepare doc bitset
|
||||
auto col_input = GetColumnVector(input_);
|
||||
TargetBitmapView doc_bitset(col_input->GetRawData(), col_input->size());
|
||||
TargetBitmapView doc_bitset_valid(col_input->GetValidRawData(),
|
||||
col_input->size());
|
||||
doc_bitset.flip();
|
||||
|
||||
// Step 3: Convert doc bitset to element offsets
|
||||
FixedVector<int32_t> element_offsets =
|
||||
DocBitsetToElementOffsets(doc_bitset);
|
||||
|
||||
// Step 4: Evaluate element expression
|
||||
auto [expr_result, valid_expr_result] =
|
||||
EvaluateElementExpression(element_offsets);
|
||||
|
||||
// Step 5: Set query context
|
||||
query_context_->set_element_level_query(true);
|
||||
query_context_->set_struct_name(struct_name_);
|
||||
|
||||
std::chrono::high_resolution_clock::time_point end_time =
|
||||
std::chrono::high_resolution_clock::now();
|
||||
double total_cost =
|
||||
std::chrono::duration<double, std::micro>(end_time - start_time)
|
||||
.count();
|
||||
milvus::monitor::internal_core_search_latency_scalar.Observe(total_cost /
|
||||
1000);
|
||||
|
||||
auto filtered_count = expr_result.count();
|
||||
tracer::AddEvent(
|
||||
fmt::format("struct_name: {}, total_elements: {}, output_rows: {}, "
|
||||
"filtered: {}, cost_us: {}",
|
||||
struct_name_,
|
||||
array_offsets->GetTotalElementCount(),
|
||||
array_offsets->GetTotalElementCount() - filtered_count,
|
||||
filtered_count,
|
||||
total_cost));
|
||||
|
||||
std::vector<VectorPtr> col_res;
|
||||
col_res.push_back(std::make_shared<ColumnVector>(
|
||||
std::move(expr_result), std::move(valid_expr_result)));
|
||||
return std::make_shared<RowVector>(col_res);
|
||||
}
|
||||
|
||||
FixedVector<int32_t>
|
||||
PhyElementFilterBitsNode::DocBitsetToElementOffsets(
|
||||
const TargetBitmapView& doc_bitset) {
|
||||
auto array_offsets = query_context_->get_array_offsets();
|
||||
AssertInfo(array_offsets != nullptr, "Array offsets not available");
|
||||
|
||||
int64_t doc_count = array_offsets->GetRowCount();
|
||||
AssertInfo(doc_bitset.size() == doc_count,
|
||||
"Doc bitset size mismatch: {} vs {}",
|
||||
doc_bitset.size(),
|
||||
doc_count);
|
||||
|
||||
FixedVector<int32_t> element_offsets;
|
||||
element_offsets.reserve(array_offsets->GetTotalElementCount());
|
||||
|
||||
// For each document that passes the filter, get all its element offsets
|
||||
for (int64_t doc_id = 0; doc_id < doc_count; ++doc_id) {
|
||||
if (doc_bitset[doc_id]) {
|
||||
// Get element range for this document
|
||||
auto [first_elem, last_elem] =
|
||||
array_offsets->ElementIDRangeOfRow(doc_id);
|
||||
|
||||
// Add all element IDs for this document
|
||||
for (int64_t elem_id = first_elem; elem_id < last_elem; ++elem_id) {
|
||||
element_offsets.push_back(static_cast<int32_t>(elem_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return element_offsets;
|
||||
}
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
PhyElementFilterBitsNode::EvaluateElementExpression(
|
||||
FixedVector<int32_t>& element_offsets) {
|
||||
tracer::AutoSpan span("PhyElementFilterBitsNode::EvaluateElementExpression",
|
||||
tracer::GetRootSpan(),
|
||||
true);
|
||||
tracer::AddEvent(fmt::format("input_elements: {}", element_offsets.size()));
|
||||
|
||||
// Use offset interface by passing element_offsets as third parameter
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(),
|
||||
element_exprs_.get(),
|
||||
&element_offsets);
|
||||
|
||||
std::vector<VectorPtr> results;
|
||||
element_exprs_->Eval(0, 1, true, eval_ctx, results);
|
||||
|
||||
AssertInfo(results.size() == 1 && results[0] != nullptr,
|
||||
"ElementFilterBitsNode: expression evaluation should return "
|
||||
"exactly one result");
|
||||
|
||||
TargetBitmap bitset;
|
||||
TargetBitmap valid_bitset;
|
||||
int64_t total_elements = query_context_->get_active_element_count();
|
||||
bitset = TargetBitmap(total_elements, false);
|
||||
valid_bitset = TargetBitmap(total_elements, true);
|
||||
|
||||
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
|
||||
if (!col_vec) {
|
||||
ThrowInfo(ExprInvalid,
|
||||
"ElementFilterBitsNode result should be ColumnVector");
|
||||
}
|
||||
if (!col_vec->IsBitmap()) {
|
||||
ThrowInfo(ExprInvalid, "ElementFilterBitsNode result should be bitmap");
|
||||
}
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
|
||||
|
||||
AssertInfo(col_vec_size == element_offsets.size(),
|
||||
"ElementFilterBitsNode result size mismatch: {} vs {}",
|
||||
col_vec_size,
|
||||
element_offsets.size());
|
||||
|
||||
for (size_t i = 0; i < element_offsets.size(); ++i) {
|
||||
if (bitsetview[i]) {
|
||||
bitset[element_offsets[i]] = true;
|
||||
}
|
||||
}
|
||||
|
||||
bitset.flip();
|
||||
|
||||
tracer::AddEvent(fmt::format("evaluated_elements: {}, total_elements: {}",
|
||||
element_offsets.size(),
|
||||
total_elements));
|
||||
|
||||
return std::make_pair(std::move(bitset), std::move(valid_bitset));
|
||||
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
93
internal/core/src/exec/operator/ElementFilterBitsNode.h
Normal file
93
internal/core/src/exec/operator/ElementFilterBitsNode.h
Normal file
@ -0,0 +1,93 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "exec/Driver.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "exec/operator/Operator.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class PhyElementFilterBitsNode : public Operator {
|
||||
public:
|
||||
PhyElementFilterBitsNode(
|
||||
int32_t operator_id,
|
||||
DriverContext* ctx,
|
||||
const std::shared_ptr<const plan::ElementFilterBitsNode>&
|
||||
element_filter_bits_node);
|
||||
|
||||
bool
|
||||
IsFilter() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NeedInput() const override {
|
||||
return !input_;
|
||||
}
|
||||
|
||||
void
|
||||
AddInput(RowVectorPtr& input) override;
|
||||
|
||||
RowVectorPtr
|
||||
GetOutput() override;
|
||||
|
||||
bool
|
||||
IsFinished() override {
|
||||
return is_finished_;
|
||||
}
|
||||
|
||||
void
|
||||
Close() override {
|
||||
Operator::Close();
|
||||
if (element_exprs_) {
|
||||
element_exprs_->Clear();
|
||||
}
|
||||
}
|
||||
|
||||
BlockingReason
|
||||
IsBlocked(ContinueFuture* /* unused */) override {
|
||||
return BlockingReason::kNotBlocked;
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "PhyElementFilterBitsNode";
|
||||
}
|
||||
|
||||
private:
|
||||
FixedVector<int32_t>
|
||||
DocBitsetToElementOffsets(const TargetBitmapView& doc_bitset);
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
EvaluateElementExpression(FixedVector<int32_t>& element_offsets);
|
||||
|
||||
std::unique_ptr<ExprSet> element_exprs_;
|
||||
QueryContext* query_context_;
|
||||
std::string struct_name_;
|
||||
bool is_finished_{false};
|
||||
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
129
internal/core/src/exec/operator/ElementFilterNode.cpp
Normal file
129
internal/core/src/exec/operator/ElementFilterNode.cpp
Normal file
@ -0,0 +1,129 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ElementFilterNode.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "common/ElementFilterIterator.h"
|
||||
#include "monitor/Monitor.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
PhyElementFilterNode::PhyElementFilterNode(
|
||||
int32_t operator_id,
|
||||
DriverContext* driverctx,
|
||||
const std::shared_ptr<const plan::ElementFilterNode>& element_filter_node)
|
||||
: Operator(driverctx,
|
||||
element_filter_node->output_type(),
|
||||
operator_id,
|
||||
element_filter_node->id(),
|
||||
"PhyElementFilterNode"),
|
||||
struct_name_(element_filter_node->struct_name()) {
|
||||
ExecContext* exec_context = operator_context_->get_exec_context();
|
||||
query_context_ = exec_context->get_query_context();
|
||||
std::vector<expr::TypedExprPtr> exprs;
|
||||
exprs.emplace_back(element_filter_node->element_filter());
|
||||
element_exprs_ = std::make_unique<ExprSet>(exprs, exec_context);
|
||||
}
|
||||
|
||||
void
|
||||
PhyElementFilterNode::AddInput(RowVectorPtr& input) {
|
||||
input_ = std::move(input);
|
||||
}
|
||||
|
||||
RowVectorPtr
|
||||
PhyElementFilterNode::GetOutput() {
|
||||
if (is_finished_ || !no_more_input_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tracer::AutoSpan span(
|
||||
"PhyElementFilterNode::GetOutput", tracer::GetRootSpan(), true);
|
||||
|
||||
DeferLambda([&]() { is_finished_ = true; });
|
||||
|
||||
if (input_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time =
|
||||
std::chrono::high_resolution_clock::now();
|
||||
|
||||
// Step 1: Get search result with iterators
|
||||
milvus::SearchResult search_result = query_context_->get_search_result();
|
||||
|
||||
if (!search_result.element_level_) {
|
||||
ThrowInfo(ExprInvalid,
|
||||
"PhyElementFilterNode expects element-level search result");
|
||||
}
|
||||
|
||||
if (!search_result.vector_iterators_.has_value()) {
|
||||
ThrowInfo(
|
||||
ExprInvalid,
|
||||
"PhyElementFilterNode expects vector_iterators in search result");
|
||||
}
|
||||
|
||||
auto segment = query_context_->get_segment();
|
||||
auto& field_meta =
|
||||
segment->get_schema().GetFirstArrayFieldInStruct(struct_name_);
|
||||
auto field_id = field_meta.get_id();
|
||||
auto array_offsets = segment->GetArrayOffsets(field_id);
|
||||
if (array_offsets == nullptr) {
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"IArrayOffsets not found for field {}",
|
||||
field_id.get());
|
||||
}
|
||||
query_context_->set_array_offsets(array_offsets);
|
||||
|
||||
// Step 2: Wrap each iterator with ElementFilterIterator
|
||||
auto& base_iterators = search_result.vector_iterators_.value();
|
||||
std::vector<std::shared_ptr<VectorIterator>> wrapped_iterators;
|
||||
wrapped_iterators.reserve(base_iterators.size());
|
||||
|
||||
ExecContext* exec_context = operator_context_->get_exec_context();
|
||||
|
||||
for (auto& base_iter : base_iterators) {
|
||||
// Wrap each iterator with ElementFilterIterator
|
||||
auto wrapped_iter = std::make_shared<ElementFilterIterator>(
|
||||
base_iter, exec_context, element_exprs_.get());
|
||||
|
||||
wrapped_iterators.push_back(wrapped_iter);
|
||||
}
|
||||
|
||||
// Step 3: Update search result with wrapped iterators
|
||||
search_result.vector_iterators_ = std::move(wrapped_iterators);
|
||||
query_context_->set_search_result(std::move(search_result));
|
||||
|
||||
// Step 4: Record metrics
|
||||
std::chrono::high_resolution_clock::time_point end_time =
|
||||
std::chrono::high_resolution_clock::now();
|
||||
double cost =
|
||||
std::chrono::duration<double, std::micro>(end_time - start_time)
|
||||
.count();
|
||||
|
||||
tracer::AddEvent(
|
||||
fmt::format("PhyElementFilterNode: wrapped {} iterators, struct_name: "
|
||||
"{}, cost_us: {}",
|
||||
wrapped_iterators.size(),
|
||||
struct_name_,
|
||||
cost));
|
||||
|
||||
// Pass through input to downstream
|
||||
return input_;
|
||||
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
87
internal/core/src/exec/operator/ElementFilterNode.h
Normal file
87
internal/core/src/exec/operator/ElementFilterNode.h
Normal file
@ -0,0 +1,87 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "exec/Driver.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "exec/operator/Operator.h"
|
||||
#include "exec/QueryContext.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class PhyElementFilterNode : public Operator {
|
||||
public:
|
||||
PhyElementFilterNode(int32_t operator_id,
|
||||
DriverContext* ctx,
|
||||
const std::shared_ptr<const plan::ElementFilterNode>&
|
||||
element_filter_node);
|
||||
|
||||
bool
|
||||
IsFilter() override {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
NeedInput() const override {
|
||||
return !is_finished_;
|
||||
}
|
||||
|
||||
void
|
||||
AddInput(RowVectorPtr& input) override;
|
||||
|
||||
RowVectorPtr
|
||||
GetOutput() override;
|
||||
|
||||
bool
|
||||
IsFinished() override {
|
||||
return is_finished_;
|
||||
}
|
||||
|
||||
void
|
||||
Close() override {
|
||||
Operator::Close();
|
||||
if (element_exprs_) {
|
||||
element_exprs_->Clear();
|
||||
}
|
||||
}
|
||||
|
||||
BlockingReason
|
||||
IsBlocked(ContinueFuture* /* unused */) override {
|
||||
return BlockingReason::kNotBlocked;
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return "PhyElementFilterNode";
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<ExprSet> element_exprs_;
|
||||
QueryContext* query_context_;
|
||||
std::string struct_name_;
|
||||
bool is_finished_{false};
|
||||
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
@ -88,7 +88,8 @@ insert_helper(milvus::SearchResult& search_result,
|
||||
const FixedVector<int32_t>& offsets,
|
||||
const int64_t nq_index,
|
||||
const int64_t unity_topk,
|
||||
const int i) {
|
||||
const int i,
|
||||
const IArrayOffsets* array_offsets = nullptr) {
|
||||
auto pos = large_is_better
|
||||
? find_binsert_position<true>(search_result.distances_,
|
||||
nq_index * unity_topk,
|
||||
@ -98,6 +99,18 @@ insert_helper(milvus::SearchResult& search_result,
|
||||
nq_index * unity_topk,
|
||||
nq_index * unity_topk + topk,
|
||||
distances[i]);
|
||||
|
||||
// For element-level: convert element_id to (doc_id, element_index)
|
||||
int64_t doc_id;
|
||||
int32_t elem_idx = -1;
|
||||
if (array_offsets != nullptr) {
|
||||
auto [doc, idx] = array_offsets->ElementIDToRowID(offsets[i]);
|
||||
doc_id = doc;
|
||||
elem_idx = idx;
|
||||
} else {
|
||||
doc_id = offsets[i];
|
||||
}
|
||||
|
||||
if (topk > pos) {
|
||||
std::memmove(&search_result.distances_[pos + 1],
|
||||
&search_result.distances_[pos],
|
||||
@ -105,8 +118,16 @@ insert_helper(milvus::SearchResult& search_result,
|
||||
std::memmove(&search_result.seg_offsets_[pos + 1],
|
||||
&search_result.seg_offsets_[pos],
|
||||
(topk - pos) * sizeof(int64_t));
|
||||
if (array_offsets != nullptr) {
|
||||
std::memmove(&search_result.element_indices_[pos + 1],
|
||||
&search_result.element_indices_[pos],
|
||||
(topk - pos) * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
search_result.seg_offsets_[pos] = doc_id;
|
||||
if (array_offsets != nullptr) {
|
||||
search_result.element_indices_[pos] = elem_idx;
|
||||
}
|
||||
search_result.seg_offsets_[pos] = offsets[i];
|
||||
search_result.distances_[pos] = distances[i];
|
||||
++topk;
|
||||
}
|
||||
@ -178,10 +199,31 @@ PhyIterativeFilterNode::GetOutput() {
|
||||
search_result.total_nq_,
|
||||
"Vector Iterators' count must be equal to total_nq_, Check "
|
||||
"your code");
|
||||
|
||||
bool element_level = search_result.element_level_;
|
||||
auto array_offsets = query_context_->get_array_offsets();
|
||||
|
||||
// For element-level, we need array_offsets to convert element_id → doc_id
|
||||
if (element_level) {
|
||||
AssertInfo(
|
||||
array_offsets != nullptr,
|
||||
"Array offsets required for element-level iterative filter");
|
||||
}
|
||||
|
||||
int nq_index = 0;
|
||||
|
||||
search_result.seg_offsets_.resize(nq * unity_topk, INVALID_SEG_OFFSET);
|
||||
search_result.distances_.resize(nq * unity_topk);
|
||||
if (element_level) {
|
||||
search_result.element_indices_.resize(nq * unity_topk, -1);
|
||||
}
|
||||
|
||||
// Reuse memory allocation across batches and nqs
|
||||
FixedVector<int32_t> doc_offsets;
|
||||
std::vector<int64_t> element_to_doc_mapping;
|
||||
std::unordered_map<int64_t, bool> doc_eval_cache;
|
||||
std::unordered_set<int64_t> unique_doc_ids;
|
||||
|
||||
for (auto& iterator : search_result.vector_iterators_.value()) {
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(),
|
||||
exprs_.get());
|
||||
@ -208,8 +250,35 @@ PhyIterativeFilterNode::GetOutput() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear but retain capacity
|
||||
doc_offsets.clear();
|
||||
element_to_doc_mapping.clear();
|
||||
doc_eval_cache.clear();
|
||||
unique_doc_ids.clear();
|
||||
|
||||
if (element_level) {
|
||||
// 1. Convert element_ids to doc_ids and do filter on those doc_ids
|
||||
// 2. element_ids with doc_ids that pass the filter are what we interested in
|
||||
element_to_doc_mapping.reserve(offsets.size());
|
||||
|
||||
for (auto element_id : offsets) {
|
||||
auto [doc_id, elem_index] =
|
||||
array_offsets->ElementIDToRowID(element_id);
|
||||
element_to_doc_mapping.push_back(doc_id);
|
||||
unique_doc_ids.insert(doc_id);
|
||||
}
|
||||
|
||||
doc_offsets.reserve(unique_doc_ids.size());
|
||||
for (auto doc_id : unique_doc_ids) {
|
||||
doc_offsets.emplace_back(static_cast<int32_t>(doc_id));
|
||||
}
|
||||
} else {
|
||||
doc_offsets = offsets;
|
||||
}
|
||||
|
||||
if (is_native_supported_) {
|
||||
eval_ctx.set_offset_input(&offsets);
|
||||
eval_ctx.set_offset_input(&doc_offsets);
|
||||
std::vector<VectorPtr> results;
|
||||
exprs_->Eval(0, 1, true, eval_ctx, results);
|
||||
AssertInfo(
|
||||
@ -223,24 +292,52 @@ PhyIterativeFilterNode::GetOutput() {
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView bitsetview(col_vec->GetRawData(),
|
||||
col_vec_size);
|
||||
Assert(bitsetview.size() <= batch_size);
|
||||
Assert(bitsetview.size() == offsets.size());
|
||||
for (auto i = 0; i < offsets.size(); ++i) {
|
||||
if (bitsetview[i] > 0) {
|
||||
insert_helper(search_result,
|
||||
topk,
|
||||
large_is_better,
|
||||
distances,
|
||||
offsets,
|
||||
nq_index,
|
||||
unity_topk,
|
||||
i);
|
||||
if (topk == unity_topk) {
|
||||
break;
|
||||
|
||||
if (element_level) {
|
||||
Assert(bitsetview.size() == doc_offsets.size());
|
||||
for (size_t i = 0; i < doc_offsets.size(); ++i) {
|
||||
doc_eval_cache[doc_offsets[i]] =
|
||||
(bitsetview[i] > 0);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < offsets.size(); ++i) {
|
||||
int64_t doc_id = element_to_doc_mapping[i];
|
||||
if (doc_eval_cache[doc_id]) {
|
||||
insert_helper(search_result,
|
||||
topk,
|
||||
large_is_better,
|
||||
distances,
|
||||
offsets,
|
||||
nq_index,
|
||||
unity_topk,
|
||||
i,
|
||||
array_offsets.get());
|
||||
if (topk == unity_topk) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Assert(bitsetview.size() <= batch_size);
|
||||
Assert(bitsetview.size() == offsets.size());
|
||||
for (auto i = 0; i < offsets.size(); ++i) {
|
||||
if (bitsetview[i]) {
|
||||
insert_helper(search_result,
|
||||
topk,
|
||||
large_is_better,
|
||||
distances,
|
||||
offsets,
|
||||
nq_index,
|
||||
unity_topk,
|
||||
i);
|
||||
if (topk == unity_topk) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Assert(!element_level);
|
||||
for (auto i = 0; i < offsets.size(); ++i) {
|
||||
if (bitset[offsets[i]] > 0) {
|
||||
insert_helper(search_result,
|
||||
|
||||
@ -17,6 +17,8 @@
|
||||
#include "VectorSearchNode.h"
|
||||
#include "common/Tracer.h"
|
||||
#include "fmt/format.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "exec/operator/Utils.h"
|
||||
|
||||
#include "monitor/Monitor.h"
|
||||
namespace milvus {
|
||||
@ -80,10 +82,43 @@ PhyVectorSearchNode::GetOutput() {
|
||||
auto src_data = ph.get_blob();
|
||||
auto src_offsets = ph.get_offsets();
|
||||
auto num_queries = ph.num_of_queries_;
|
||||
std::shared_ptr<const IArrayOffsets> array_offsets = nullptr;
|
||||
if (ph.element_level_) {
|
||||
array_offsets = segment_->GetArrayOffsets(search_info_.field_id_);
|
||||
AssertInfo(array_offsets != nullptr, "Array offsets not available");
|
||||
query_context_->set_array_offsets(array_offsets);
|
||||
search_info_.array_offsets_ = array_offsets;
|
||||
}
|
||||
|
||||
// There are two types of execution: pre-filter and iterative filter
|
||||
// For **pre-filter**, we have execution path: FilterBitsNode -> MvccNode -> ElementFilterBitsNode -> VectorSearchNode -> ...
|
||||
// For **iterative filter**, we have execution path: MvccNode -> VectorSearchNode -> ElementFilterNode -> FilterNode -> ...
|
||||
//
|
||||
// When embedding search embedding on embedding list is used, which means element_level_ is true, we need to transform doc-level
|
||||
// bitset to element-level bitset. In pre-filter path, ElementFilterBitsNode already transforms the bitset. We need to transform it
|
||||
// in iterative filter path.
|
||||
if (milvus::exec::UseVectorIterator(search_info_) && ph.element_level_) {
|
||||
auto col_input = GetColumnVector(input_);
|
||||
TargetBitmapView view(col_input->GetRawData(), col_input->size());
|
||||
TargetBitmapView valid_view(col_input->GetValidRawData(),
|
||||
col_input->size());
|
||||
|
||||
auto [element_bitset, valid_element_bitset] =
|
||||
array_offsets->RowBitsetToElementBitset(view, valid_view);
|
||||
|
||||
query_context_->set_active_element_count(element_bitset.size());
|
||||
|
||||
std::vector<VectorPtr> col_res;
|
||||
col_res.push_back(std::make_shared<ColumnVector>(
|
||||
std::move(element_bitset), std::move(valid_element_bitset)));
|
||||
input_ = std::make_shared<RowVector>(col_res);
|
||||
}
|
||||
|
||||
milvus::SearchResult search_result;
|
||||
|
||||
auto col_input = GetColumnVector(input_);
|
||||
TargetBitmapView view(col_input->GetRawData(), col_input->size());
|
||||
|
||||
if (view.all()) {
|
||||
query_context_->set_search_result(
|
||||
std::move(empty_search_result(num_queries)));
|
||||
@ -94,6 +129,7 @@ PhyVectorSearchNode::GetOutput() {
|
||||
milvus::BitsetView final_view((uint8_t*)col_input->GetRawData(),
|
||||
col_input->size());
|
||||
auto op_context = query_context_->get_op_context();
|
||||
// todo(SpadeA): need to pass element_level to make check more rigorously?
|
||||
segment_->vector_search(search_info_,
|
||||
src_data,
|
||||
src_offsets,
|
||||
@ -104,6 +140,7 @@ PhyVectorSearchNode::GetOutput() {
|
||||
search_result);
|
||||
|
||||
search_result.total_data_cnt_ = final_view.size();
|
||||
search_result.element_level_ = ph.element_level_;
|
||||
|
||||
span.GetSpan()->SetAttribute(
|
||||
"result_count", static_cast<int>(search_result.seg_offsets_.size()));
|
||||
|
||||
@ -118,6 +118,7 @@ struct ColumnInfo {
|
||||
DataType element_type_;
|
||||
std::vector<std::string> nested_path_;
|
||||
bool nullable_;
|
||||
bool element_level_;
|
||||
|
||||
ColumnInfo(const proto::plan::ColumnInfo& column_info)
|
||||
: field_id_(column_info.field_id()),
|
||||
@ -125,7 +126,8 @@ struct ColumnInfo {
|
||||
element_type_(static_cast<DataType>(column_info.element_type())),
|
||||
nested_path_(column_info.nested_path().begin(),
|
||||
column_info.nested_path().end()),
|
||||
nullable_(column_info.nullable()) {
|
||||
nullable_(column_info.nullable()),
|
||||
element_level_(column_info.is_element_level()) {
|
||||
}
|
||||
|
||||
ColumnInfo(FieldId field_id,
|
||||
@ -136,7 +138,8 @@ struct ColumnInfo {
|
||||
data_type_(data_type),
|
||||
element_type_(DataType::NONE),
|
||||
nested_path_(std::move(nested_path)),
|
||||
nullable_(nullable) {
|
||||
nullable_(nullable),
|
||||
element_level_(false) {
|
||||
}
|
||||
|
||||
ColumnInfo(FieldId field_id,
|
||||
@ -148,7 +151,8 @@ struct ColumnInfo {
|
||||
data_type_(data_type),
|
||||
element_type_(element_type),
|
||||
nested_path_(std::move(nested_path)),
|
||||
nullable_(nullable) {
|
||||
nullable_(nullable),
|
||||
element_level_(false) {
|
||||
}
|
||||
|
||||
bool
|
||||
@ -165,6 +169,10 @@ struct ColumnInfo {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (element_level_ != other.element_level_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nested_path_.size(); ++i) {
|
||||
if (nested_path_[i] != other.nested_path_[i]) {
|
||||
return false;
|
||||
@ -180,21 +188,25 @@ struct ColumnInfo {
|
||||
data_type_,
|
||||
element_type_,
|
||||
nested_path_,
|
||||
nullable_) < std::tie(other.field_id_,
|
||||
other.data_type_,
|
||||
other.element_type_,
|
||||
other.nested_path_,
|
||||
other.nullable_);
|
||||
nullable_,
|
||||
element_level_) < std::tie(other.field_id_,
|
||||
other.data_type_,
|
||||
other.element_type_,
|
||||
other.nested_path_,
|
||||
other.nullable_,
|
||||
other.element_level_);
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const {
|
||||
return fmt::format(
|
||||
"[FieldId:{}, data_type:{}, element_type:{}, nested_path:{}]",
|
||||
"[FieldId:{}, data_type:{}, element_type:{}, nested_path:{}, "
|
||||
"element_level:{}]",
|
||||
std::to_string(field_id_.get()),
|
||||
data_type_,
|
||||
element_type_,
|
||||
milvus::Join<std::string>(nested_path_, ","));
|
||||
milvus::Join<std::string>(nested_path_, ","),
|
||||
element_level_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -183,6 +183,126 @@ class FilterBitsNode : public PlanNode {
|
||||
const expr::TypedExprPtr filter_;
|
||||
};
|
||||
|
||||
class ElementFilterNode : public PlanNode {
|
||||
public:
|
||||
ElementFilterNode(const PlanNodeId& id,
|
||||
expr::TypedExprPtr element_filter,
|
||||
std::string struct_name,
|
||||
std::vector<PlanNodePtr> sources)
|
||||
: PlanNode(id),
|
||||
sources_{std::move(sources)},
|
||||
element_filter_(std::move(element_filter)),
|
||||
struct_name_(std::move(struct_name)) {
|
||||
AssertInfo(
|
||||
element_filter_->type() == DataType::BOOL,
|
||||
fmt::format(
|
||||
"Element filter expression must be of type BOOLEAN, Got {}",
|
||||
element_filter_->type()));
|
||||
}
|
||||
|
||||
DataType
|
||||
output_type() const override {
|
||||
return DataType::NONE;
|
||||
}
|
||||
|
||||
std::vector<PlanNodePtr>
|
||||
sources() const override {
|
||||
return sources_;
|
||||
}
|
||||
|
||||
const expr::TypedExprPtr&
|
||||
element_filter() const {
|
||||
return element_filter_;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
struct_name() const {
|
||||
return struct_name_;
|
||||
}
|
||||
|
||||
std::string_view
|
||||
name() const override {
|
||||
return "ElementFilter";
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return fmt::format(
|
||||
"ElementFilterNode:\n\t[struct_name:{}, element_filter:{}]",
|
||||
struct_name_,
|
||||
element_filter_->ToString());
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<PlanNodePtr> sources_;
|
||||
const expr::TypedExprPtr element_filter_;
|
||||
const std::string struct_name_;
|
||||
};
|
||||
|
||||
class ElementFilterBitsNode : public PlanNode {
|
||||
public:
|
||||
ElementFilterBitsNode(
|
||||
const PlanNodeId& id,
|
||||
expr::TypedExprPtr element_filter,
|
||||
std::string struct_name,
|
||||
std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
|
||||
: PlanNode(id),
|
||||
sources_{std::move(sources)},
|
||||
element_filter_(std::move(element_filter)),
|
||||
struct_name_(std::move(struct_name)) {
|
||||
AssertInfo(
|
||||
element_filter_->type() == DataType::BOOL,
|
||||
fmt::format(
|
||||
"Element filter expression must be of type BOOLEAN, Got {}",
|
||||
element_filter_->type()));
|
||||
}
|
||||
|
||||
DataType
|
||||
output_type() const override {
|
||||
return DataType::BOOL;
|
||||
}
|
||||
|
||||
std::vector<PlanNodePtr>
|
||||
sources() const override {
|
||||
return sources_;
|
||||
}
|
||||
|
||||
const expr::TypedExprPtr&
|
||||
element_filter() const {
|
||||
return element_filter_;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
struct_name() const {
|
||||
return struct_name_;
|
||||
}
|
||||
|
||||
std::string_view
|
||||
name() const override {
|
||||
return "ElementFilterBits";
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return fmt::format(
|
||||
"ElementFilterBitsNode:\n\t[struct_name:{}, element_filter:{}]",
|
||||
struct_name_,
|
||||
element_filter_->ToString());
|
||||
}
|
||||
|
||||
expr::ExprInfo
|
||||
GatherInfo() const override {
|
||||
expr::ExprInfo info;
|
||||
element_filter_->GatherInfo(info);
|
||||
return info;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<PlanNodePtr> sources_;
|
||||
const expr::TypedExprPtr element_filter_;
|
||||
const std::string struct_name_;
|
||||
};
|
||||
|
||||
class MvccNode : public PlanNode {
|
||||
public:
|
||||
MvccNode(const PlanNodeId& id,
|
||||
|
||||
@ -55,7 +55,12 @@ ExecPlanNodeVisitor::ExecuteTask(
|
||||
for (;;) {
|
||||
auto result = task->Next();
|
||||
if (!result) {
|
||||
Assert(processed_num == query_context->get_active_count());
|
||||
if (query_context->get_active_element_count() > 0) {
|
||||
Assert(processed_num ==
|
||||
query_context->get_active_element_count());
|
||||
} else {
|
||||
Assert(processed_num == query_context->get_active_count());
|
||||
}
|
||||
break;
|
||||
}
|
||||
auto childrens = result->childrens();
|
||||
|
||||
@ -31,29 +31,37 @@ ParsePlaceholderGroup(const Plan* plan,
|
||||
}
|
||||
|
||||
bool
|
||||
check_data_type(const FieldMeta& field_meta,
|
||||
const milvus::proto::common::PlaceholderType type) {
|
||||
check_data_type(
|
||||
const FieldMeta& field_meta,
|
||||
const milvus::proto::common::PlaceholderValue& placeholder_value) {
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||
if (field_meta.get_element_type() == DataType::VECTOR_FLOAT) {
|
||||
return type ==
|
||||
milvus::proto::common::PlaceholderType::EmbListFloatVector;
|
||||
if (placeholder_value.element_level()) {
|
||||
return placeholder_value.type() ==
|
||||
milvus::proto::common::PlaceholderType::FloatVector;
|
||||
} else {
|
||||
return placeholder_value.type() ==
|
||||
milvus::proto::common::PlaceholderType::
|
||||
EmbListFloatVector;
|
||||
}
|
||||
} else if (field_meta.get_element_type() == DataType::VECTOR_FLOAT16) {
|
||||
return type ==
|
||||
return placeholder_value.type() ==
|
||||
milvus::proto::common::PlaceholderType::EmbListFloat16Vector;
|
||||
} else if (field_meta.get_element_type() == DataType::VECTOR_BFLOAT16) {
|
||||
return type == milvus::proto::common::PlaceholderType::
|
||||
EmbListBFloat16Vector;
|
||||
return placeholder_value.type() ==
|
||||
milvus::proto::common::PlaceholderType::
|
||||
EmbListBFloat16Vector;
|
||||
} else if (field_meta.get_element_type() == DataType::VECTOR_BINARY) {
|
||||
return type ==
|
||||
return placeholder_value.type() ==
|
||||
milvus::proto::common::PlaceholderType::EmbListBinaryVector;
|
||||
} else if (field_meta.get_element_type() == DataType::VECTOR_INT8) {
|
||||
return type ==
|
||||
return placeholder_value.type() ==
|
||||
milvus::proto::common::PlaceholderType::EmbListInt8Vector;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return static_cast<int>(field_meta.get_data_type()) ==
|
||||
static_cast<int>(type);
|
||||
static_cast<int>(placeholder_value.type());
|
||||
}
|
||||
|
||||
std::unique_ptr<PlaceholderGroup>
|
||||
@ -64,30 +72,33 @@ ParsePlaceholderGroup(const Plan* plan,
|
||||
milvus::proto::common::PlaceholderGroup ph_group;
|
||||
auto ok = ph_group.ParseFromArray(blob, blob_len);
|
||||
Assert(ok);
|
||||
for (auto& info : ph_group.placeholders()) {
|
||||
for (auto& ph : ph_group.placeholders()) {
|
||||
Placeholder element;
|
||||
element.tag_ = info.tag();
|
||||
element.tag_ = ph.tag();
|
||||
element.element_level_ = ph.element_level();
|
||||
Assert(plan->tag2field_.count(element.tag_));
|
||||
auto field_id = plan->tag2field_.at(element.tag_);
|
||||
auto& field_meta = plan->schema_->operator[](field_id);
|
||||
AssertInfo(check_data_type(field_meta, info.type()),
|
||||
AssertInfo(check_data_type(field_meta, ph),
|
||||
"vector type must be the same, field {} - type {}, search "
|
||||
"info type {}",
|
||||
"ph type {}",
|
||||
field_meta.get_name().get(),
|
||||
field_meta.get_data_type(),
|
||||
static_cast<DataType>(info.type()));
|
||||
element.num_of_queries_ = info.values_size();
|
||||
static_cast<DataType>(ph.type()));
|
||||
element.num_of_queries_ = ph.values_size();
|
||||
AssertInfo(element.num_of_queries_ > 0, "must have queries");
|
||||
if (info.type() ==
|
||||
if (ph.type() ==
|
||||
milvus::proto::common::PlaceholderType::SparseFloatVector) {
|
||||
element.sparse_matrix_ =
|
||||
SparseBytesToRows(info.values(), /*validate=*/true);
|
||||
SparseBytesToRows(ph.values(), /*validate=*/true);
|
||||
} else {
|
||||
auto line_size = info.values().Get(0).size();
|
||||
auto line_size = ph.values().Get(0).size();
|
||||
auto& target = element.blob_;
|
||||
|
||||
if (field_meta.get_data_type() != DataType::VECTOR_ARRAY) {
|
||||
if (field_meta.get_sizeof() != line_size) {
|
||||
if (field_meta.get_data_type() != DataType::VECTOR_ARRAY ||
|
||||
ph.element_level()) {
|
||||
if (field_meta.get_sizeof() != line_size &&
|
||||
!ph.element_level()) {
|
||||
ThrowInfo(DimNotMatch,
|
||||
fmt::format(
|
||||
"vector dimension mismatch, expected vector "
|
||||
@ -96,7 +107,7 @@ ParsePlaceholderGroup(const Plan* plan,
|
||||
line_size));
|
||||
}
|
||||
target.reserve(line_size * element.num_of_queries_);
|
||||
for (auto& line : info.values()) {
|
||||
for (auto& line : ph.values()) {
|
||||
AssertInfo(line_size == line.size(),
|
||||
"vector dimension mismatch, expected vector "
|
||||
"size(byte) {}, actual {}.",
|
||||
@ -118,7 +129,7 @@ ParsePlaceholderGroup(const Plan* plan,
|
||||
|
||||
auto bytes_per_vec = milvus::vector_bytes_per_element(
|
||||
field_meta.get_element_type(), dim);
|
||||
for (auto& line : info.values()) {
|
||||
for (auto& line : ph.values()) {
|
||||
target.insert(target.end(), line.begin(), line.end());
|
||||
AssertInfo(
|
||||
line.size() % bytes_per_vec == 0,
|
||||
|
||||
@ -85,6 +85,7 @@ struct Placeholder {
|
||||
sparse_matrix_;
|
||||
// offsets for embedding list
|
||||
aligned_vector<size_t> offsets_;
|
||||
bool element_level_{false};
|
||||
|
||||
const void*
|
||||
get_blob() const {
|
||||
|
||||
@ -63,162 +63,201 @@ MergeExprWithNamespace(const SchemaPtr schema,
|
||||
return and_expr;
|
||||
}
|
||||
|
||||
SearchInfo
|
||||
ProtoParser::ParseSearchInfo(const planpb::VectorANNS& anns_proto) {
|
||||
SearchInfo search_info;
|
||||
auto& query_info_proto = anns_proto.query_info();
|
||||
auto field_id = FieldId(anns_proto.field_id());
|
||||
search_info.field_id_ = field_id;
|
||||
|
||||
search_info.metric_type_ = query_info_proto.metric_type();
|
||||
search_info.topk_ = query_info_proto.topk();
|
||||
search_info.round_decimal_ = query_info_proto.round_decimal();
|
||||
search_info.search_params_ =
|
||||
nlohmann::json::parse(query_info_proto.search_params());
|
||||
search_info.materialized_view_involved =
|
||||
query_info_proto.materialized_view_involved();
|
||||
// currently, iterative filter does not support range search
|
||||
if (!search_info.search_params_.contains(RADIUS)) {
|
||||
if (query_info_proto.hints() != "") {
|
||||
if (query_info_proto.hints() == "disable") {
|
||||
search_info.iterative_filter_execution = false;
|
||||
} else if (query_info_proto.hints() == ITERATIVE_FILTER) {
|
||||
search_info.iterative_filter_execution = true;
|
||||
} else {
|
||||
// check if hints is valid
|
||||
ThrowInfo(ConfigInvalid,
|
||||
"hints: {} not supported",
|
||||
query_info_proto.hints());
|
||||
}
|
||||
} else if (search_info.search_params_.contains(HINTS)) {
|
||||
if (search_info.search_params_[HINTS] == ITERATIVE_FILTER) {
|
||||
search_info.iterative_filter_execution = true;
|
||||
} else {
|
||||
// check if hints is valid
|
||||
ThrowInfo(ConfigInvalid,
|
||||
"hints: {} not supported",
|
||||
search_info.search_params_[HINTS]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (query_info_proto.bm25_avgdl() > 0) {
|
||||
search_info.search_params_[knowhere::meta::BM25_AVGDL] =
|
||||
query_info_proto.bm25_avgdl();
|
||||
}
|
||||
|
||||
if (query_info_proto.group_by_field_id() > 0) {
|
||||
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
|
||||
search_info.group_by_field_id_ = group_by_field_id;
|
||||
search_info.group_size_ = query_info_proto.group_size() > 0
|
||||
? query_info_proto.group_size()
|
||||
: 1;
|
||||
search_info.strict_group_size_ = query_info_proto.strict_group_size();
|
||||
// Always set json_path to distinguish between unset and empty string
|
||||
// Empty string means accessing the entire JSON object
|
||||
search_info.json_path_ = query_info_proto.json_path();
|
||||
if (query_info_proto.json_type() !=
|
||||
milvus::proto::schema::DataType::None) {
|
||||
search_info.json_type_ =
|
||||
static_cast<milvus::DataType>(query_info_proto.json_type());
|
||||
}
|
||||
search_info.strict_cast_ = query_info_proto.strict_cast();
|
||||
}
|
||||
|
||||
if (query_info_proto.has_search_iterator_v2_info()) {
|
||||
auto& iterator_v2_info_proto =
|
||||
query_info_proto.search_iterator_v2_info();
|
||||
search_info.iterator_v2_info_ = SearchIteratorV2Info{
|
||||
.token = iterator_v2_info_proto.token(),
|
||||
.batch_size = iterator_v2_info_proto.batch_size(),
|
||||
};
|
||||
if (iterator_v2_info_proto.has_last_bound()) {
|
||||
search_info.iterator_v2_info_->last_bound =
|
||||
iterator_v2_info_proto.last_bound();
|
||||
}
|
||||
}
|
||||
|
||||
return search_info;
|
||||
}
|
||||
|
||||
std::unique_ptr<VectorPlanNode>
|
||||
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
||||
Assert(plan_node_proto.has_vector_anns());
|
||||
auto& anns_proto = plan_node_proto.vector_anns();
|
||||
|
||||
auto expr_parser = [&]() -> plan::PlanNodePtr {
|
||||
auto expr = ParseExprs(anns_proto.predicates());
|
||||
if (plan_node_proto.has_namespace_()) {
|
||||
expr = MergeExprWithNamespace(
|
||||
schema, expr, plan_node_proto.namespace_());
|
||||
}
|
||||
return std::make_shared<plan::FilterBitsNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), expr);
|
||||
};
|
||||
|
||||
auto search_info_parser = [&]() -> SearchInfo {
|
||||
SearchInfo search_info;
|
||||
auto& query_info_proto = anns_proto.query_info();
|
||||
auto field_id = FieldId(anns_proto.field_id());
|
||||
search_info.field_id_ = field_id;
|
||||
|
||||
search_info.metric_type_ = query_info_proto.metric_type();
|
||||
search_info.topk_ = query_info_proto.topk();
|
||||
search_info.round_decimal_ = query_info_proto.round_decimal();
|
||||
search_info.search_params_ =
|
||||
nlohmann::json::parse(query_info_proto.search_params());
|
||||
search_info.materialized_view_involved =
|
||||
query_info_proto.materialized_view_involved();
|
||||
// currently, iterative filter does not support range search
|
||||
if (!search_info.search_params_.contains(RADIUS)) {
|
||||
if (query_info_proto.hints() != "") {
|
||||
if (query_info_proto.hints() == "disable") {
|
||||
search_info.iterative_filter_execution = false;
|
||||
} else if (query_info_proto.hints() == ITERATIVE_FILTER) {
|
||||
search_info.iterative_filter_execution = true;
|
||||
} else {
|
||||
// check if hints is valid
|
||||
ThrowInfo(ConfigInvalid,
|
||||
"hints: {} not supported",
|
||||
query_info_proto.hints());
|
||||
}
|
||||
} else if (search_info.search_params_.contains(HINTS)) {
|
||||
if (search_info.search_params_[HINTS] == ITERATIVE_FILTER) {
|
||||
search_info.iterative_filter_execution = true;
|
||||
} else {
|
||||
// check if hints is valid
|
||||
ThrowInfo(ConfigInvalid,
|
||||
"hints: {} not supported",
|
||||
search_info.search_params_[HINTS]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (query_info_proto.bm25_avgdl() > 0) {
|
||||
search_info.search_params_[knowhere::meta::BM25_AVGDL] =
|
||||
query_info_proto.bm25_avgdl();
|
||||
}
|
||||
|
||||
if (query_info_proto.group_by_field_id() > 0) {
|
||||
auto group_by_field_id =
|
||||
FieldId(query_info_proto.group_by_field_id());
|
||||
search_info.group_by_field_id_ = group_by_field_id;
|
||||
search_info.group_size_ = query_info_proto.group_size() > 0
|
||||
? query_info_proto.group_size()
|
||||
: 1;
|
||||
search_info.strict_group_size_ =
|
||||
query_info_proto.strict_group_size();
|
||||
// Always set json_path to distinguish between unset and empty string
|
||||
// Empty string means accessing the entire JSON object
|
||||
search_info.json_path_ = query_info_proto.json_path();
|
||||
if (query_info_proto.json_type() !=
|
||||
milvus::proto::schema::DataType::None) {
|
||||
search_info.json_type_ =
|
||||
static_cast<milvus::DataType>(query_info_proto.json_type());
|
||||
}
|
||||
search_info.strict_cast_ = query_info_proto.strict_cast();
|
||||
}
|
||||
|
||||
if (query_info_proto.has_search_iterator_v2_info()) {
|
||||
auto& iterator_v2_info_proto =
|
||||
query_info_proto.search_iterator_v2_info();
|
||||
search_info.iterator_v2_info_ = SearchIteratorV2Info{
|
||||
.token = iterator_v2_info_proto.token(),
|
||||
.batch_size = iterator_v2_info_proto.batch_size(),
|
||||
};
|
||||
if (iterator_v2_info_proto.has_last_bound()) {
|
||||
search_info.iterator_v2_info_->last_bound =
|
||||
iterator_v2_info_proto.last_bound();
|
||||
}
|
||||
}
|
||||
|
||||
return search_info;
|
||||
};
|
||||
|
||||
// Parse search information from proto
|
||||
auto plan_node = std::make_unique<VectorPlanNode>();
|
||||
plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
|
||||
plan_node->search_info_ = std::move(search_info_parser());
|
||||
plan_node->search_info_ = ParseSearchInfo(anns_proto);
|
||||
|
||||
milvus::plan::PlanNodePtr plannode;
|
||||
std::vector<milvus::plan::PlanNodePtr> sources;
|
||||
|
||||
// mvcc node -> vector search node -> iterative filter node
|
||||
auto iterative_filter_plan = [&]() {
|
||||
plannode = std::make_shared<milvus::plan::MvccNode>(
|
||||
milvus::plan::GetNextPlanNodeId());
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
|
||||
auto expr = ParseExprs(anns_proto.predicates());
|
||||
if (plan_node_proto.has_namespace_()) {
|
||||
expr = MergeExprWithNamespace(
|
||||
schema, expr, plan_node_proto.namespace_());
|
||||
}
|
||||
plannode = std::make_shared<plan::FilterNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), expr, sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
};
|
||||
|
||||
// pre filter node -> mvcc node -> vector search node
|
||||
auto pre_filter_plan = [&]() {
|
||||
plannode = std::move(expr_parser());
|
||||
if (plan_node->search_info_.materialized_view_involved) {
|
||||
const auto expr_info = plannode->GatherInfo();
|
||||
knowhere::MaterializedViewSearchInfo materialized_view_search_info;
|
||||
for (const auto& [expr_field_id, vals] :
|
||||
expr_info.field_id_to_values) {
|
||||
materialized_view_search_info
|
||||
.field_id_to_touched_categories_cnt[expr_field_id] =
|
||||
vals.size();
|
||||
}
|
||||
materialized_view_search_info.is_pure_and = expr_info.is_pure_and;
|
||||
materialized_view_search_info.has_not = expr_info.has_not;
|
||||
|
||||
plan_node->search_info_
|
||||
.search_params_[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] =
|
||||
materialized_view_search_info;
|
||||
}
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
plannode = std::make_shared<milvus::plan::MvccNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
|
||||
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
};
|
||||
|
||||
// Build plan node chain based on predicate and filter execution strategy
|
||||
if (anns_proto.has_predicates()) {
|
||||
// currently limit iterative filter scope to search only
|
||||
if (plan_node->search_info_.iterative_filter_execution &&
|
||||
plan_node->search_info_.group_by_field_id_ == std::nullopt) {
|
||||
iterative_filter_plan();
|
||||
auto* predicate_proto = &anns_proto.predicates();
|
||||
bool is_element_level = predicate_proto->expr_case() ==
|
||||
proto::plan::Expr::kElementFilterExpr;
|
||||
|
||||
// Parse expressions based on filter type (similar to RandomSampleExpr pattern)
|
||||
expr::TypedExprPtr element_expr = nullptr;
|
||||
expr::TypedExprPtr doc_expr = nullptr;
|
||||
std::string struct_name;
|
||||
|
||||
if (is_element_level) {
|
||||
// Element-level query: extract both element_expr and optional doc-level predicate
|
||||
auto& element_filter_expr = predicate_proto->element_filter_expr();
|
||||
element_expr = ParseExprs(element_filter_expr.element_expr());
|
||||
struct_name = element_filter_expr.struct_name();
|
||||
|
||||
if (element_filter_expr.has_predicate()) {
|
||||
doc_expr = ParseExprs(element_filter_expr.predicate());
|
||||
if (plan_node_proto.has_namespace_()) {
|
||||
doc_expr = MergeExprWithNamespace(
|
||||
schema, doc_expr, plan_node_proto.namespace_());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pre_filter_plan();
|
||||
// Document-level query: only doc expr
|
||||
doc_expr = ParseExprs(anns_proto.predicates());
|
||||
if (plan_node_proto.has_namespace_()) {
|
||||
doc_expr = MergeExprWithNamespace(
|
||||
schema, doc_expr, plan_node_proto.namespace_());
|
||||
}
|
||||
}
|
||||
|
||||
bool is_iterative =
|
||||
plan_node->search_info_.iterative_filter_execution &&
|
||||
plan_node->search_info_.group_by_field_id_ == std::nullopt;
|
||||
if (is_iterative) {
|
||||
plannode = std::make_shared<milvus::plan::MvccNode>(
|
||||
milvus::plan::GetNextPlanNodeId());
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
|
||||
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
|
||||
// Add element-level filter if needed
|
||||
if (is_element_level) {
|
||||
plannode = std::make_shared<plan::ElementFilterNode>(
|
||||
milvus::plan::GetNextPlanNodeId(),
|
||||
element_expr,
|
||||
struct_name,
|
||||
sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
|
||||
// Add doc-level filter if present
|
||||
if (doc_expr) {
|
||||
plannode = std::make_shared<plan::FilterNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), doc_expr, sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
} else {
|
||||
if (doc_expr) {
|
||||
plannode = std::make_shared<plan::FilterBitsNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), doc_expr);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
|
||||
if (!is_element_level &&
|
||||
plan_node->search_info_.materialized_view_involved) {
|
||||
const auto expr_info = plannode->GatherInfo();
|
||||
knowhere::MaterializedViewSearchInfo
|
||||
materialized_view_search_info;
|
||||
for (const auto& [expr_field_id, vals] :
|
||||
expr_info.field_id_to_values) {
|
||||
materialized_view_search_info
|
||||
.field_id_to_touched_categories_cnt[expr_field_id] =
|
||||
vals.size();
|
||||
}
|
||||
materialized_view_search_info.is_pure_and =
|
||||
expr_info.is_pure_and;
|
||||
materialized_view_search_info.has_not = expr_info.has_not;
|
||||
|
||||
plan_node->search_info_.search_params_
|
||||
[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] =
|
||||
materialized_view_search_info;
|
||||
}
|
||||
}
|
||||
|
||||
plannode = std::make_shared<milvus::plan::MvccNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
|
||||
if (is_element_level) {
|
||||
plannode = std::make_shared<plan::ElementFilterBitsNode>(
|
||||
milvus::plan::GetNextPlanNodeId(),
|
||||
element_expr,
|
||||
struct_name,
|
||||
sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
|
||||
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
|
||||
milvus::plan::GetNextPlanNodeId(), sources);
|
||||
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
|
||||
}
|
||||
} else {
|
||||
// no filter, force set iterative filter hint to false, go with normal vector search path
|
||||
@ -400,8 +439,16 @@ expr::TypedExprPtr
|
||||
ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.column_info();
|
||||
auto field_id = FieldId(column_info.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
std::vector<::milvus::proto::plan::GenericValue> extra_values;
|
||||
for (auto val : expr_pb.extra_values()) {
|
||||
extra_values.emplace_back(val);
|
||||
@ -417,8 +464,16 @@ expr::TypedExprPtr
|
||||
ProtoParser::ParseNullExprs(const proto::plan::NullExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.column_info();
|
||||
auto field_id = FieldId(column_info.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
return std::make_shared<milvus::expr::NullExpr>(
|
||||
expr::ColumnInfo(column_info), expr_pb.op());
|
||||
}
|
||||
@ -428,8 +483,15 @@ ProtoParser::ParseBinaryRangeExprs(
|
||||
const proto::plan::BinaryRangeExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.column_info();
|
||||
auto field_id = FieldId(columnInfo.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
return std::make_shared<expr::BinaryRangeFilterExpr>(
|
||||
columnInfo,
|
||||
expr_pb.lower_value(),
|
||||
@ -443,8 +505,15 @@ ProtoParser::ParseTimestamptzArithCompareExprs(
|
||||
const proto::plan::TimestamptzArithCompareExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.timestamptz_column();
|
||||
auto field_id = FieldId(columnInfo.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
return std::make_shared<expr::TimestamptzArithCompareExpr>(
|
||||
columnInfo,
|
||||
expr_pb.arith_op(),
|
||||
@ -453,6 +522,17 @@ ProtoParser::ParseTimestamptzArithCompareExprs(
|
||||
expr_pb.compare_value());
|
||||
}
|
||||
|
||||
expr::TypedExprPtr
|
||||
ProtoParser::ParseElementFilterExprs(
|
||||
const proto::plan::ElementFilterExpr& expr_pb) {
|
||||
// ElementFilterExpr is not a regular expression that can be evaluated directly.
|
||||
// It should be handled at the PlanNode level (in PlanNodeFromProto).
|
||||
// This method should never be called.
|
||||
ThrowInfo(ExprInvalid,
|
||||
"ParseElementFilterExprs should not be called directly. "
|
||||
"ElementFilterExpr must be handled at PlanNode level.");
|
||||
}
|
||||
|
||||
expr::TypedExprPtr
|
||||
ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) {
|
||||
std::vector<expr::TypedExprPtr> parameters;
|
||||
@ -480,15 +560,31 @@ expr::TypedExprPtr
|
||||
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
|
||||
auto& left_column_info = expr_pb.left_column_info();
|
||||
auto left_field_id = FieldId(left_column_info.field_id());
|
||||
auto left_data_type = schema->operator[](left_field_id).get_data_type();
|
||||
Assert(left_data_type ==
|
||||
static_cast<DataType>(left_column_info.data_type()));
|
||||
auto& left_field = schema->operator[](left_field_id);
|
||||
auto left_data_type = left_field.get_data_type();
|
||||
|
||||
if (left_column_info.is_element_level()) {
|
||||
Assert(left_data_type == DataType::ARRAY);
|
||||
Assert(left_field.get_element_type() ==
|
||||
static_cast<DataType>(left_column_info.data_type()));
|
||||
} else {
|
||||
Assert(left_data_type ==
|
||||
static_cast<DataType>(left_column_info.data_type()));
|
||||
}
|
||||
|
||||
auto& right_column_info = expr_pb.right_column_info();
|
||||
auto right_field_id = FieldId(right_column_info.field_id());
|
||||
auto right_data_type = schema->operator[](right_field_id).get_data_type();
|
||||
Assert(right_data_type ==
|
||||
static_cast<DataType>(right_column_info.data_type()));
|
||||
auto& right_field = schema->operator[](right_field_id);
|
||||
auto right_data_type = right_field.get_data_type();
|
||||
|
||||
if (right_column_info.is_element_level()) {
|
||||
Assert(right_data_type == DataType::ARRAY);
|
||||
Assert(right_field.get_element_type() ==
|
||||
static_cast<DataType>(right_column_info.data_type()));
|
||||
} else {
|
||||
Assert(right_data_type ==
|
||||
static_cast<DataType>(right_column_info.data_type()));
|
||||
}
|
||||
|
||||
return std::make_shared<expr::CompareExpr>(left_field_id,
|
||||
right_field_id,
|
||||
@ -501,8 +597,15 @@ expr::TypedExprPtr
|
||||
ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.column_info();
|
||||
auto field_id = FieldId(columnInfo.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
std::vector<::milvus::proto::plan::GenericValue> values;
|
||||
for (size_t i = 0; i < expr_pb.values_size(); i++) {
|
||||
values.emplace_back(expr_pb.values(i));
|
||||
@ -532,8 +635,16 @@ ProtoParser::ParseBinaryArithOpEvalRangeExprs(
|
||||
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.column_info();
|
||||
auto field_id = FieldId(column_info.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
return std::make_shared<expr::BinaryArithOpEvalRangeExpr>(
|
||||
column_info,
|
||||
expr_pb.op(),
|
||||
@ -546,8 +657,16 @@ expr::TypedExprPtr
|
||||
ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
|
||||
auto& column_info = expr_pb.info();
|
||||
auto field_id = FieldId(column_info.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
return std::make_shared<expr::ExistsExpr>(column_info);
|
||||
}
|
||||
|
||||
@ -556,8 +675,15 @@ ProtoParser::ParseJsonContainsExprs(
|
||||
const proto::plan::JSONContainsExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.column_info();
|
||||
auto field_id = FieldId(columnInfo.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
std::vector<::milvus::proto::plan::GenericValue> values;
|
||||
for (size_t i = 0; i < expr_pb.elements_size(); i++) {
|
||||
values.emplace_back(expr_pb.elements(i));
|
||||
@ -584,8 +710,15 @@ ProtoParser::ParseGISFunctionFilterExprs(
|
||||
const proto::plan::GISFunctionFilterExpr& expr_pb) {
|
||||
auto& columnInfo = expr_pb.column_info();
|
||||
auto field_id = FieldId(columnInfo.field_id());
|
||||
auto data_type = schema->operator[](field_id).get_data_type();
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
auto& field = schema->operator[](field_id);
|
||||
auto data_type = field.get_data_type();
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
|
||||
auto expr = std::make_shared<expr::GISFunctionFilterExpr>(
|
||||
columnInfo, expr_pb.op(), expr_pb.wkt_string(), expr_pb.distance());
|
||||
@ -671,6 +804,11 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
|
||||
expr_pb.timestamptz_arith_compare_expr());
|
||||
break;
|
||||
}
|
||||
case ppe::kElementFilterExpr: {
|
||||
ThrowInfo(ExprInvalid,
|
||||
"ElementFilterExpr should be handled at PlanNode level, "
|
||||
"not in ParseExprs");
|
||||
}
|
||||
default: {
|
||||
std::string s;
|
||||
google::protobuf::TextFormat::PrintToString(expr_pb, &s);
|
||||
|
||||
@ -106,6 +106,9 @@ class ProtoParser {
|
||||
ParseTimestamptzArithCompareExprs(
|
||||
const proto::plan::TimestamptzArithCompareExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseElementFilterExprs(const proto::plan::ElementFilterExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseValueExprs(const proto::plan::ValueExpr& expr_pb);
|
||||
|
||||
@ -113,6 +116,9 @@ class ProtoParser {
|
||||
PlanOptionsFromProto(const proto::plan::PlanOption& plan_option_proto,
|
||||
PlanOptions& plan_options);
|
||||
|
||||
SearchInfo
|
||||
ParseSearchInfo(const proto::plan::VectorANNS& anns_proto);
|
||||
|
||||
private:
|
||||
const SchemaPtr schema;
|
||||
};
|
||||
|
||||
@ -140,7 +140,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
// not gurantee to return exactly `range_search_k` results, which may be more or less.
|
||||
// set it to -1 will return all results in the range.
|
||||
search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk;
|
||||
sub_result.mutable_seg_offsets().resize(nq * topk);
|
||||
sub_result.mutable_offsets().resize(nq * topk);
|
||||
sub_result.mutable_distances().resize(nq * topk);
|
||||
|
||||
// For vector array (embedding list), element type is used to determine how to operate search.
|
||||
@ -196,8 +196,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
auto result =
|
||||
ReGenRangeSearchResult(res.value(), topk, nq, query_ds.metric_type);
|
||||
milvus::tracer::AddEvent("ReGenRangeSearchResult");
|
||||
std::copy_n(
|
||||
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
|
||||
std::copy_n(GetDatasetIDs(result), nq * topk, sub_result.get_offsets());
|
||||
std::copy_n(
|
||||
GetDatasetDistance(result), nq * topk, sub_result.get_distances());
|
||||
} else {
|
||||
@ -206,7 +205,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
stat = knowhere::BruteForce::SearchWithBuf<float>(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset,
|
||||
@ -215,7 +214,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
stat = knowhere::BruteForce::SearchWithBuf<float16>(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset,
|
||||
@ -224,7 +223,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
stat = knowhere::BruteForce::SearchWithBuf<bfloat16>(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset,
|
||||
@ -233,7 +232,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
stat = knowhere::BruteForce::SearchWithBuf<bin1>(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset,
|
||||
@ -242,7 +241,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
stat = knowhere::BruteForce::SearchSparseWithBuf(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset,
|
||||
@ -251,7 +250,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
|
||||
stat = knowhere::BruteForce::SearchWithBuf<int8>(
|
||||
base_dataset,
|
||||
query_dataset,
|
||||
sub_result.mutable_seg_offsets().data(),
|
||||
sub_result.mutable_offsets().data(),
|
||||
sub_result.mutable_distances().data(),
|
||||
search_cfg,
|
||||
bitset,
|
||||
|
||||
@ -129,7 +129,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
nullptr);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk);
|
||||
auto ans = result.get_seg_offsets() + i * topk;
|
||||
auto ans = result.get_offsets() + i * topk;
|
||||
AssertMatch(ref, ans);
|
||||
}
|
||||
|
||||
@ -146,7 +146,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
for (int i = 0; i < nq; i++) {
|
||||
auto ref = RangeSearchRef(
|
||||
base.get(), *(query.get() + i), nb, 0.1, 0.5, topk);
|
||||
auto ans = result2.get_seg_offsets() + i * topk;
|
||||
auto ans = result2.get_offsets() + i * topk;
|
||||
AssertMatch(ref, ans);
|
||||
}
|
||||
|
||||
|
||||
@ -151,7 +151,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
|
||||
dim,
|
||||
topk,
|
||||
metric_type);
|
||||
auto ans = result.get_seg_offsets() + i * topk;
|
||||
auto ans = result.get_offsets() + i * topk;
|
||||
AssertMatch(ref, ans);
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,7 +81,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
SearchResult& search_result) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& record = segment.get_insert_record();
|
||||
auto active_count =
|
||||
auto active_row_count =
|
||||
std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
|
||||
|
||||
// step 1.1: get meta
|
||||
@ -162,7 +162,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
|
||||
CachedSearchIterator cached_iter(search_dataset,
|
||||
vec_ptr,
|
||||
active_count,
|
||||
active_row_count,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
@ -172,23 +172,30 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
}
|
||||
|
||||
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
|
||||
auto max_chunk = upper_div(active_count, vec_size_per_chunk);
|
||||
auto max_chunk = upper_div(active_row_count, vec_size_per_chunk);
|
||||
|
||||
// embedding search embedding on embedding list
|
||||
bool embedding_search = false;
|
||||
if (data_type == DataType::VECTOR_ARRAY &&
|
||||
info.array_offsets_ != nullptr) {
|
||||
embedding_search = true;
|
||||
}
|
||||
|
||||
for (int chunk_id = current_chunk_id; chunk_id < max_chunk;
|
||||
++chunk_id) {
|
||||
auto chunk_data = vec_ptr->get_chunk_data(chunk_id);
|
||||
|
||||
auto element_begin = chunk_id * vec_size_per_chunk;
|
||||
auto element_end =
|
||||
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto size_per_chunk = element_end - element_begin;
|
||||
auto row_begin = chunk_id * vec_size_per_chunk;
|
||||
auto row_end =
|
||||
std::min(active_row_count, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto size_per_chunk = row_end - row_begin;
|
||||
|
||||
query::dataset::RawDataset sub_data;
|
||||
std::unique_ptr<uint8_t[]> buf = nullptr;
|
||||
std::vector<size_t> offsets;
|
||||
if (data_type != DataType::VECTOR_ARRAY) {
|
||||
sub_data = query::dataset::RawDataset{
|
||||
element_begin, dim, size_per_chunk, chunk_data};
|
||||
row_begin, dim, size_per_chunk, chunk_data};
|
||||
} else {
|
||||
// TODO(SpadeA): For VectorArray(Embedding List), data is
|
||||
// discreted stored in FixedVector which means we will copy the
|
||||
@ -201,43 +208,59 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
}
|
||||
|
||||
buf = std::make_unique<uint8_t[]>(size);
|
||||
offsets.reserve(size_per_chunk + 1);
|
||||
offsets.push_back(0);
|
||||
|
||||
auto offset = 0;
|
||||
auto ptr = buf.get();
|
||||
for (int i = 0; i < size_per_chunk; ++i) {
|
||||
memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
|
||||
ptr += vec_ptr[i].byte_size();
|
||||
if (embedding_search) {
|
||||
auto count = 0;
|
||||
auto ptr = buf.get();
|
||||
for (int i = 0; i < size_per_chunk; ++i) {
|
||||
memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
|
||||
ptr += vec_ptr[i].byte_size();
|
||||
count += vec_ptr[i].length();
|
||||
}
|
||||
sub_data = query::dataset::RawDataset{
|
||||
row_begin, dim, count, buf.get()};
|
||||
} else {
|
||||
offsets.reserve(size_per_chunk + 1);
|
||||
offsets.push_back(0);
|
||||
|
||||
offset += vec_ptr[i].length();
|
||||
offsets.push_back(offset);
|
||||
auto offset = 0;
|
||||
auto ptr = buf.get();
|
||||
for (int i = 0; i < size_per_chunk; ++i) {
|
||||
memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
|
||||
ptr += vec_ptr[i].byte_size();
|
||||
|
||||
offset += vec_ptr[i].length();
|
||||
offsets.push_back(offset);
|
||||
}
|
||||
sub_data = query::dataset::RawDataset{row_begin,
|
||||
dim,
|
||||
size_per_chunk,
|
||||
buf.get(),
|
||||
offsets.data()};
|
||||
}
|
||||
sub_data = query::dataset::RawDataset{element_begin,
|
||||
dim,
|
||||
size_per_chunk,
|
||||
buf.get(),
|
||||
offsets.data()};
|
||||
}
|
||||
|
||||
if (data_type == DataType::VECTOR_ARRAY) {
|
||||
AssertInfo(
|
||||
query_offsets != nullptr,
|
||||
"query_offsets is nullptr, but data_type is vector array");
|
||||
auto vector_type = data_type;
|
||||
if (embedding_search) {
|
||||
vector_type = element_type;
|
||||
}
|
||||
|
||||
if (milvus::exec::UseVectorIterator(info)) {
|
||||
AssertInfo(data_type != DataType::VECTOR_ARRAY,
|
||||
AssertInfo(vector_type != DataType::VECTOR_ARRAY,
|
||||
"vector array(embedding list) is not supported for "
|
||||
"vector iterator");
|
||||
|
||||
if (buf != nullptr) {
|
||||
search_result.chunk_buffers_.emplace_back(std::move(buf));
|
||||
}
|
||||
|
||||
auto sub_qr =
|
||||
PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
|
||||
sub_data,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
vector_type);
|
||||
final_qr.merge(sub_qr);
|
||||
} else {
|
||||
auto sub_qr = BruteForceSearch(search_dataset,
|
||||
@ -245,7 +268,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type,
|
||||
vector_type,
|
||||
element_type,
|
||||
op_context);
|
||||
final_qr.merge(sub_qr);
|
||||
@ -259,9 +282,18 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
search_result.AssembleChunkVectorIterators(
|
||||
num_queries, max_chunk, chunk_rows, final_qr.chunk_iterators());
|
||||
} else {
|
||||
if (info.array_offsets_ != nullptr) {
|
||||
auto [seg_offsets, elem_indicies] =
|
||||
final_qr.convert_to_element_offsets(
|
||||
info.array_offsets_.get());
|
||||
search_result.seg_offsets_ = std::move(seg_offsets);
|
||||
search_result.element_indices_ = std::move(elem_indicies);
|
||||
search_result.element_level_ = true;
|
||||
} else {
|
||||
search_result.seg_offsets_ =
|
||||
std::move(final_qr.mutable_offsets());
|
||||
}
|
||||
search_result.distances_ = std::move(final_qr.mutable_distances());
|
||||
search_result.seg_offsets_ =
|
||||
std::move(final_qr.mutable_seg_offsets());
|
||||
}
|
||||
search_result.unity_topK_ = topk;
|
||||
search_result.total_nq_ = num_queries;
|
||||
|
||||
@ -96,6 +96,28 @@ SearchOnSealedIndex(const Schema& schema,
|
||||
std::round(distances[i] * multiplier) / multiplier;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle element-level conversion if needed
|
||||
if (search_info.array_offsets_ != nullptr) {
|
||||
std::vector<int64_t> element_ids =
|
||||
std::move(search_result.seg_offsets_);
|
||||
search_result.seg_offsets_.resize(element_ids.size());
|
||||
search_result.element_indices_.resize(element_ids.size());
|
||||
|
||||
for (size_t i = 0; i < element_ids.size(); i++) {
|
||||
if (element_ids[i] == INVALID_SEG_OFFSET) {
|
||||
search_result.seg_offsets_[i] = INVALID_SEG_OFFSET;
|
||||
search_result.element_indices_[i] = -1;
|
||||
} else {
|
||||
auto [doc_id, elem_index] =
|
||||
search_info.array_offsets_->ElementIDToRowID(
|
||||
element_ids[i]);
|
||||
search_result.seg_offsets_[i] = doc_id;
|
||||
search_result.element_indices_[i] = elem_index;
|
||||
}
|
||||
}
|
||||
search_result.element_level_ = true;
|
||||
}
|
||||
}
|
||||
search_result.total_nq_ = num_queries;
|
||||
search_result.unity_topK_ = topK;
|
||||
@ -150,6 +172,17 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
search_info.metric_type_,
|
||||
search_info.round_decimal_);
|
||||
|
||||
// For element-level search (embedding-search-embedding), we need to use
|
||||
// element count instead of row count
|
||||
bool is_element_level_search =
|
||||
field.get_data_type() == DataType::VECTOR_ARRAY &&
|
||||
query_offsets == nullptr;
|
||||
|
||||
if (is_element_level_search) {
|
||||
// embedding-search-embedding on embedding list pattern
|
||||
data_type = element_type;
|
||||
}
|
||||
|
||||
auto offset = 0;
|
||||
|
||||
auto vector_chunks = column->GetAllChunks(op_context);
|
||||
@ -157,6 +190,14 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
auto pw = vector_chunks[i];
|
||||
auto vec_data = pw.get()->Data();
|
||||
auto chunk_size = column->chunk_row_nums(i);
|
||||
|
||||
// For element-level search, get element count from VectorArrayOffsets
|
||||
if (is_element_level_search) {
|
||||
auto elem_offsets_pw = column->VectorArrayOffsets(op_context, i);
|
||||
// offsets[row_count] gives total element count in this chunk
|
||||
chunk_size = elem_offsets_pw.get()[chunk_size];
|
||||
}
|
||||
|
||||
auto raw_dataset =
|
||||
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
|
||||
|
||||
@ -201,8 +242,17 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
column->GetNumRowsUntilChunk(),
|
||||
final_qr.chunk_iterators());
|
||||
} else {
|
||||
if (search_info.array_offsets_ != nullptr) {
|
||||
auto [seg_offsets, elem_indicies] =
|
||||
final_qr.convert_to_element_offsets(
|
||||
search_info.array_offsets_.get());
|
||||
result.seg_offsets_ = std::move(seg_offsets);
|
||||
result.element_indices_ = std::move(elem_indicies);
|
||||
result.element_level_ = true;
|
||||
} else {
|
||||
result.seg_offsets_ = std::move(final_qr.mutable_offsets());
|
||||
}
|
||||
result.distances_ = std::move(final_qr.mutable_distances());
|
||||
result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
|
||||
}
|
||||
result.unity_topK_ = query_dataset.topk;
|
||||
result.total_nq_ = query_dataset.num_queries;
|
||||
|
||||
@ -30,7 +30,7 @@ SubSearchResult::merge_impl(const SubSearchResult& right) {
|
||||
for (int64_t qn = 0; qn < num_queries_; ++qn) {
|
||||
auto offset = qn * topk_;
|
||||
|
||||
int64_t* __restrict__ left_ids = this->get_seg_offsets() + offset;
|
||||
int64_t* __restrict__ left_ids = this->get_offsets() + offset;
|
||||
float* __restrict__ left_distances = this->get_distances() + offset;
|
||||
|
||||
auto right_ids = right.get_ids() + offset;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "common/Types.h"
|
||||
#include "common/Utils.h"
|
||||
#include "knowhere/index/index_node.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
|
||||
namespace milvus::query {
|
||||
class SubSearchResult {
|
||||
@ -31,7 +32,7 @@ class SubSearchResult {
|
||||
topk_(topk),
|
||||
round_decimal_(round_decimal),
|
||||
metric_type_(metric_type),
|
||||
seg_offsets_(num_queries * topk, INVALID_SEG_OFFSET),
|
||||
offsets_(num_queries * topk, INVALID_SEG_OFFSET),
|
||||
distances_(num_queries * topk, init_value(metric_type)),
|
||||
chunk_iterators_(std::move(iters)) {
|
||||
}
|
||||
@ -52,7 +53,7 @@ class SubSearchResult {
|
||||
topk_(other.topk_),
|
||||
round_decimal_(other.round_decimal_),
|
||||
metric_type_(std::move(other.metric_type_)),
|
||||
seg_offsets_(std::move(other.seg_offsets_)),
|
||||
offsets_(std::move(other.offsets_)),
|
||||
distances_(std::move(other.distances_)),
|
||||
chunk_iterators_(std::move(other.chunk_iterators_)) {
|
||||
}
|
||||
@ -77,12 +78,12 @@ class SubSearchResult {
|
||||
|
||||
const int64_t*
|
||||
get_ids() const {
|
||||
return seg_offsets_.data();
|
||||
return offsets_.data();
|
||||
}
|
||||
|
||||
int64_t*
|
||||
get_seg_offsets() {
|
||||
return seg_offsets_.data();
|
||||
get_offsets() {
|
||||
return offsets_.data();
|
||||
}
|
||||
|
||||
const float*
|
||||
@ -96,8 +97,8 @@ class SubSearchResult {
|
||||
}
|
||||
|
||||
auto&
|
||||
mutable_seg_offsets() {
|
||||
return seg_offsets_;
|
||||
mutable_offsets() {
|
||||
return offsets_;
|
||||
}
|
||||
|
||||
auto&
|
||||
@ -116,6 +117,27 @@ class SubSearchResult {
|
||||
return this->chunk_iterators_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<int64_t>, std::vector<int32_t>>
|
||||
convert_to_element_offsets(const IArrayOffsets* array_offsets) {
|
||||
std::vector<int64_t> doc_offsets;
|
||||
std::vector<int32_t> element_indices;
|
||||
doc_offsets.reserve(offsets_.size());
|
||||
element_indices.reserve(offsets_.size());
|
||||
for (size_t i = 0; i < offsets_.size(); i++) {
|
||||
if (offsets_[i] == INVALID_SEG_OFFSET) {
|
||||
doc_offsets.push_back(INVALID_SEG_OFFSET);
|
||||
element_indices.push_back(-1);
|
||||
} else {
|
||||
auto [doc_id, elem_index] =
|
||||
array_offsets->ElementIDToRowID(offsets_[i]);
|
||||
doc_offsets.push_back(doc_id);
|
||||
element_indices.push_back(elem_index);
|
||||
}
|
||||
}
|
||||
return std::make_pair(std::move(doc_offsets),
|
||||
std::move(element_indices));
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool is_desc>
|
||||
void
|
||||
@ -126,7 +148,7 @@ class SubSearchResult {
|
||||
int64_t topk_;
|
||||
int64_t round_decimal_;
|
||||
knowhere::MetricType metric_type_;
|
||||
std::vector<int64_t> seg_offsets_;
|
||||
std::vector<int64_t> offsets_;
|
||||
std::vector<float> distances_;
|
||||
std::vector<knowhere::IndexNode::IteratorPtr> chunk_iterators_;
|
||||
};
|
||||
|
||||
@ -55,7 +55,7 @@ GenSubSearchResult(const int64_t nq,
|
||||
}
|
||||
}
|
||||
sub_result->mutable_distances() = std::move(distances);
|
||||
sub_result->mutable_seg_offsets() = std::move(ids);
|
||||
sub_result->mutable_offsets() = std::move(ids);
|
||||
return sub_result;
|
||||
}
|
||||
|
||||
@ -72,7 +72,7 @@ CheckSubSearchResult(const int64_t nq,
|
||||
auto ref_x = result_ref[n].top();
|
||||
result_ref[n].pop();
|
||||
auto index = n * topk + topk - 1 - k;
|
||||
auto id = result.get_seg_offsets()[index];
|
||||
auto id = result.get_offsets()[index];
|
||||
auto distance = result.get_distances()[index];
|
||||
ASSERT_EQ(id, ref_x);
|
||||
ASSERT_EQ(distance, ref_x);
|
||||
|
||||
@ -2676,22 +2676,62 @@ ChunkedSegmentSealedImpl::load_field_data_common(
|
||||
|
||||
bool generated_interim_index = generate_interim_index(field_id, num_rows);
|
||||
|
||||
std::unique_lock lck(mutex_);
|
||||
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
|
||||
"field {} data already loaded",
|
||||
field_id.get());
|
||||
set_bit(field_data_ready_bitset_, field_id, true);
|
||||
update_row_count(num_rows);
|
||||
if (generated_interim_index) {
|
||||
auto column = get_column(field_id);
|
||||
if (column) {
|
||||
column->ManualEvictCache();
|
||||
std::string struct_name;
|
||||
const FieldMeta* field_meta_ptr = nullptr;
|
||||
|
||||
{
|
||||
std::unique_lock lck(mutex_);
|
||||
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
|
||||
"field {} data already loaded",
|
||||
field_id.get());
|
||||
set_bit(field_data_ready_bitset_, field_id, true);
|
||||
update_row_count(num_rows);
|
||||
if (generated_interim_index) {
|
||||
auto column = get_column(field_id);
|
||||
if (column) {
|
||||
column->ManualEvictCache();
|
||||
}
|
||||
}
|
||||
if (data_type == DataType::GEOMETRY &&
|
||||
segcore_config_.get_enable_geometry_cache()) {
|
||||
// Construct GeometryCache for the entire field
|
||||
LoadGeometryCache(field_id, column);
|
||||
}
|
||||
|
||||
// Check if need to build ArrayOffsetsSealed for struct array fields
|
||||
if (data_type == DataType::ARRAY ||
|
||||
data_type == DataType::VECTOR_ARRAY) {
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
const std::string& field_name = field_meta.get_name().get();
|
||||
|
||||
if (field_name.find('[') != std::string::npos &&
|
||||
field_name.find(']') != std::string::npos) {
|
||||
struct_name = field_name.substr(0, field_name.find('['));
|
||||
|
||||
auto it = struct_to_array_offsets_.find(struct_name);
|
||||
if (it != struct_to_array_offsets_.end()) {
|
||||
array_offsets_map_[field_id] = it->second;
|
||||
} else {
|
||||
field_meta_ptr = &field_meta; // need to build
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (data_type == DataType::GEOMETRY &&
|
||||
segcore_config_.get_enable_geometry_cache()) {
|
||||
// Construct GeometryCache for the entire field
|
||||
LoadGeometryCache(field_id, column);
|
||||
|
||||
// Build ArrayOffsetsSealed outside lock (expensive operation)
|
||||
if (field_meta_ptr) {
|
||||
auto new_offsets =
|
||||
ArrayOffsetsSealed::BuildFromSegment(this, *field_meta_ptr);
|
||||
|
||||
std::unique_lock lck(mutex_);
|
||||
// Double-check after re-acquiring lock
|
||||
auto it = struct_to_array_offsets_.find(struct_name);
|
||||
if (it == struct_to_array_offsets_.end()) {
|
||||
struct_to_array_offsets_[struct_name] = new_offsets;
|
||||
array_offsets_map_[field_id] = new_offsets;
|
||||
} else {
|
||||
array_offsets_map_[field_id] = it->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -169,6 +169,15 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
|
||||
FieldId field_id,
|
||||
const std::string& nested_path) const override;
|
||||
|
||||
std::shared_ptr<const IArrayOffsets>
|
||||
GetArrayOffsets(FieldId field_id) const override {
|
||||
auto it = array_offsets_map_.find(field_id);
|
||||
if (it != array_offsets_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void
|
||||
BulkGetJsonData(milvus::OpContext* op_ctx,
|
||||
FieldId field_id,
|
||||
@ -1049,6 +1058,14 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
|
||||
|
||||
// milvus storage internal api reader instance
|
||||
std::unique_ptr<milvus_storage::api::Reader> reader_;
|
||||
|
||||
// ArrayOffsetsSealed for element-level filtering on array fields
|
||||
// field_id -> ArrayOffsetsSealed mapping
|
||||
std::unordered_map<FieldId, std::shared_ptr<ArrayOffsetsSealed>>
|
||||
array_offsets_map_;
|
||||
// struct_name -> ArrayOffsetsSealed mapping (temporary during load)
|
||||
std::unordered_map<std::string, std::shared_ptr<ArrayOffsetsSealed>>
|
||||
struct_to_array_offsets_;
|
||||
};
|
||||
|
||||
inline SegmentSealedUPtr
|
||||
|
||||
@ -56,6 +56,167 @@ namespace milvus::segcore {
|
||||
|
||||
using namespace milvus::cachinglayer;
|
||||
|
||||
namespace {
|
||||
|
||||
void
|
||||
ExtractArrayLengthsFromFieldData(const std::vector<FieldDataPtr>& field_data,
|
||||
const FieldMeta& field_meta,
|
||||
int32_t* array_lengths) {
|
||||
auto data_type = field_meta.get_data_type();
|
||||
int64_t offset = 0;
|
||||
|
||||
for (const auto& data : field_data) {
|
||||
auto num_rows = data->get_num_rows();
|
||||
|
||||
if (data_type == DataType::VECTOR_ARRAY) {
|
||||
// Get raw pointer to VectorArray data
|
||||
auto* raw_data = static_cast<const VectorArray*>(data->Data());
|
||||
for (int64_t i = 0; i < num_rows; ++i) {
|
||||
array_lengths[offset + i] = raw_data[i].length();
|
||||
}
|
||||
} else {
|
||||
// For regular array types (INT32, FLOAT, etc.)
|
||||
auto* raw_data = static_cast<const ArrayView*>(data->Data());
|
||||
for (int64_t i = 0; i < num_rows; ++i) {
|
||||
array_lengths[offset + i] = raw_data[i].length();
|
||||
}
|
||||
}
|
||||
offset += num_rows;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ExtractArrayLengths(const proto::schema::FieldData& field_data,
|
||||
const FieldMeta& field_meta,
|
||||
int64_t num_rows,
|
||||
int32_t* array_lengths) {
|
||||
auto data_type = field_meta.get_data_type();
|
||||
if (data_type == DataType::VECTOR_ARRAY) {
|
||||
const auto& vector_array = field_data.vectors().vector_array();
|
||||
int64_t dim = field_meta.get_dim();
|
||||
auto element_type = field_meta.get_element_type();
|
||||
|
||||
for (int i = 0; i < num_rows; ++i) {
|
||||
const auto& vec_field = vector_array.data(i);
|
||||
int32_t array_len = 0;
|
||||
|
||||
switch (element_type) {
|
||||
case DataType::VECTOR_FLOAT:
|
||||
array_len = vec_field.float_vector().data_size() / dim;
|
||||
break;
|
||||
case DataType::VECTOR_FLOAT16:
|
||||
array_len = vec_field.float16_vector().size() / (dim * 2);
|
||||
break;
|
||||
case DataType::VECTOR_BFLOAT16:
|
||||
array_len = vec_field.bfloat16_vector().size() / (dim * 2);
|
||||
break;
|
||||
case DataType::VECTOR_BINARY:
|
||||
array_len = vec_field.binary_vector().size() / (dim / 8);
|
||||
break;
|
||||
case DataType::VECTOR_INT8:
|
||||
array_len = vec_field.int8_vector().size() / dim;
|
||||
break;
|
||||
default:
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"Unexpected VECTOR_ARRAY element type: {}",
|
||||
element_type);
|
||||
}
|
||||
|
||||
array_lengths[i] = array_len;
|
||||
}
|
||||
} else {
|
||||
// ARRAY: extract from scalars().array_data().data(i)
|
||||
const auto& array_data = field_data.scalars().array_data();
|
||||
auto element_type = field_meta.get_element_type();
|
||||
|
||||
for (int i = 0; i < num_rows; ++i) {
|
||||
int32_t array_len = 0;
|
||||
|
||||
switch (element_type) {
|
||||
case DataType::BOOL:
|
||||
array_len = array_data.data(i).bool_data().data_size();
|
||||
break;
|
||||
case DataType::INT8:
|
||||
case DataType::INT16:
|
||||
case DataType::INT32:
|
||||
array_len = array_data.data(i).int_data().data_size();
|
||||
break;
|
||||
case DataType::INT64:
|
||||
array_len = array_data.data(i).long_data().data_size();
|
||||
break;
|
||||
case DataType::FLOAT:
|
||||
array_len = array_data.data(i).float_data().data_size();
|
||||
break;
|
||||
case DataType::DOUBLE:
|
||||
array_len = array_data.data(i).double_data().data_size();
|
||||
break;
|
||||
case DataType::STRING:
|
||||
case DataType::VARCHAR:
|
||||
array_len = array_data.data(i).string_data().data_size();
|
||||
break;
|
||||
default:
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"Unexpected array type: {}",
|
||||
element_type);
|
||||
}
|
||||
|
||||
array_lengths[i] = array_len;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if (field_meta.is_nullable() && field_data.valid_data_size() > 0) {
|
||||
const auto& valid_data = field_data.valid_data();
|
||||
for (int i = 0; i < num_rows; ++i) {
|
||||
if (!valid_data[i]) {
|
||||
array_lengths[i] = 0; // null → empty array
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void
|
||||
SegmentGrowingImpl::InitializeArrayOffsets() {
|
||||
// Group fields by struct_name
|
||||
std::unordered_map<std::string, std::vector<FieldId>> struct_fields;
|
||||
|
||||
for (const auto& [field_id, field_meta] : schema_->get_fields()) {
|
||||
const auto& field_name = field_meta.get_name().get();
|
||||
|
||||
// Check if field belongs to a struct: format = "struct_name[field_name]"
|
||||
size_t bracket_pos = field_name.find('[');
|
||||
if (bracket_pos != std::string::npos && bracket_pos > 0) {
|
||||
std::string struct_name = field_name.substr(0, bracket_pos);
|
||||
struct_fields[struct_name].push_back(field_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Create one ArrayOffsetsGrowing per struct, shared by all its fields
|
||||
for (const auto& [struct_name, field_ids] : struct_fields) {
|
||||
auto array_offsets = std::make_shared<ArrayOffsetsGrowing>();
|
||||
|
||||
// Pick the first field as representative (any field works since array lengths are identical)
|
||||
FieldId representative_field = field_ids[0];
|
||||
|
||||
// Map all field_ids from this struct to the same ArrayOffsetsGrowing
|
||||
for (auto field_id : field_ids) {
|
||||
array_offsets_map_[field_id] = array_offsets;
|
||||
}
|
||||
|
||||
// Record representative field for Insert-time updates
|
||||
struct_representative_fields_.insert(representative_field);
|
||||
|
||||
LOG_INFO(
|
||||
"Created ArrayOffsetsGrowing for struct '{}' with {} fields, "
|
||||
"representative field_id={}",
|
||||
struct_name,
|
||||
field_ids.size(),
|
||||
representative_field.get());
|
||||
}
|
||||
}
|
||||
|
||||
int64_t
|
||||
SegmentGrowingImpl::PreInsert(int64_t size) {
|
||||
auto reserved_begin = insert_record_.reserved.fetch_add(size);
|
||||
@ -174,6 +335,22 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
|
||||
insert_record_);
|
||||
}
|
||||
|
||||
// update ArrayOffsetsGrowing for struct fields
|
||||
if (struct_representative_fields_.count(field_id) > 0) {
|
||||
const auto& field_data =
|
||||
insert_record_proto->fields_data(data_offset);
|
||||
|
||||
std::vector<int32_t> array_lengths(num_rows);
|
||||
ExtractArrayLengths(
|
||||
field_data, field_meta, num_rows, array_lengths.data());
|
||||
|
||||
auto offsets_it = array_offsets_map_.find(field_id);
|
||||
if (offsets_it != array_offsets_map_.end()) {
|
||||
offsets_it->second->Insert(
|
||||
reserved_offset, array_lengths.data(), num_rows);
|
||||
}
|
||||
}
|
||||
|
||||
// index text.
|
||||
if (field_meta.enable_match()) {
|
||||
// TODO: iterate texts and call `AddText` instead of `AddTexts`. This may cost much more memory.
|
||||
@ -381,6 +558,23 @@ SegmentGrowingImpl::load_field_data_common(
|
||||
index->Reload();
|
||||
}
|
||||
|
||||
// update ArrayOffsetsGrowing for struct fields
|
||||
if (struct_representative_fields_.count(field_id) > 0) {
|
||||
std::vector<int32_t> array_lengths(num_rows);
|
||||
ExtractArrayLengthsFromFieldData(
|
||||
field_data, field_meta, array_lengths.data());
|
||||
|
||||
auto offsets_it = array_offsets_map_.find(field_id);
|
||||
if (offsets_it != array_offsets_map_.end()) {
|
||||
offsets_it->second->Insert(
|
||||
reserved_offset, array_lengths.data(), num_rows);
|
||||
}
|
||||
|
||||
LOG_INFO("Updated ArrayOffsetsGrowing for field {} with {} rows",
|
||||
field_id.get(),
|
||||
num_rows);
|
||||
}
|
||||
|
||||
// update the mem size
|
||||
stats_.mem_size += storage::GetByteSizeOfFieldDatas(field_data);
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "common/Types.h"
|
||||
#include "query/PlanNode.h"
|
||||
#include "common/GeometryCache.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
@ -330,6 +331,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
||||
},
|
||||
segment_id) {
|
||||
this->CreateTextIndexes();
|
||||
this->InitializeArrayOffsets();
|
||||
}
|
||||
|
||||
~SegmentGrowingImpl() {
|
||||
@ -490,6 +492,15 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
||||
"RemoveJsonStats not implemented for SegmentGrowingImpl");
|
||||
}
|
||||
|
||||
std::shared_ptr<const IArrayOffsets>
|
||||
GetArrayOffsets(FieldId field_id) const override {
|
||||
auto it = array_offsets_map_.find(field_id);
|
||||
if (it != array_offsets_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
protected:
|
||||
int64_t
|
||||
num_chunk(FieldId field_id) const override;
|
||||
@ -586,6 +597,9 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
||||
const std::shared_ptr<milvus_storage::api::Properties>& properties,
|
||||
int64_t index);
|
||||
|
||||
void
|
||||
InitializeArrayOffsets();
|
||||
|
||||
private:
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor_ = nullptr;
|
||||
SegcoreConfig segcore_config_;
|
||||
@ -609,6 +623,15 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
||||
|
||||
// milvus storage internal api reader instance
|
||||
std::unique_ptr<milvus_storage::api::Reader> reader_;
|
||||
|
||||
// field_id -> ArrayOffsetsGrowing (for fast lookup via GetArrayOffsets)
|
||||
// Multiple field_ids from the same struct point to the same ArrayOffsetsGrowing
|
||||
std::unordered_map<FieldId, std::shared_ptr<ArrayOffsetsGrowing>>
|
||||
array_offsets_map_;
|
||||
|
||||
// Representative field_id for each struct (used to extract array lengths during Insert)
|
||||
// One field_id per struct, since all fields in the same struct have identical array lengths
|
||||
std::unordered_set<FieldId> struct_representative_fields_;
|
||||
};
|
||||
|
||||
inline SegmentGrowingPtr
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
#include <index/ScalarIndex.h>
|
||||
|
||||
#include "cachinglayer/CacheSlot.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Json.h"
|
||||
#include "common/OpContext.h"
|
||||
@ -243,6 +244,11 @@ class SegmentInterface {
|
||||
|
||||
virtual void
|
||||
Load(milvus::tracer::TraceContext& trace_ctx) = 0;
|
||||
|
||||
// Get IArrayOffsets for element-level filtering on array fields
|
||||
// Returns nullptr if the field doesn't have IArrayOffsets
|
||||
virtual std::shared_ptr<const IArrayOffsets>
|
||||
GetArrayOffsets(FieldId field_id) const = 0;
|
||||
};
|
||||
|
||||
// internal API for DSL calculation
|
||||
|
||||
@ -67,6 +67,10 @@ GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
|
||||
std::vector<float> distances(size);
|
||||
std::vector<int64_t> seg_offsets(size);
|
||||
std::vector<GroupByValueType> group_by_values(size);
|
||||
std::vector<int32_t> element_indices;
|
||||
if (search_result->element_level_) {
|
||||
element_indices.resize(size);
|
||||
}
|
||||
|
||||
uint32_t index = 0;
|
||||
for (int j = 0; j < total_nq_; j++) {
|
||||
@ -76,6 +80,10 @@ GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
|
||||
seg_offsets[index] = search_result->seg_offsets_[offset];
|
||||
group_by_values[index] =
|
||||
search_result->group_by_values_.value()[offset];
|
||||
if (search_result->element_level_) {
|
||||
element_indices[index] =
|
||||
search_result->element_indices_[offset];
|
||||
}
|
||||
index++;
|
||||
real_topks[j]++;
|
||||
}
|
||||
@ -84,6 +92,9 @@ GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
|
||||
search_result->distances_.swap(distances);
|
||||
search_result->seg_offsets_.swap(seg_offsets);
|
||||
search_result->group_by_values_.value().swap(group_by_values);
|
||||
if (search_result->element_level_) {
|
||||
search_result->element_indices_.swap(element_indices);
|
||||
}
|
||||
AssertInfo(search_result->primary_keys_.size() ==
|
||||
search_result->group_by_values_.value().size(),
|
||||
"Wrong size for group_by_values size after refresh:{}, "
|
||||
|
||||
@ -116,12 +116,17 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
|
||||
real_topks[i]++;
|
||||
offsets[valid_index] = offsets[index];
|
||||
distances[valid_index] = distances[index];
|
||||
if (search_result->element_level_)
|
||||
search_result->element_indices_[valid_index] =
|
||||
search_result->element_indices_[index];
|
||||
valid_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
offsets.resize(valid_index);
|
||||
distances.resize(valid_index);
|
||||
if (search_result->element_level_)
|
||||
search_result->element_indices_.resize(valid_index);
|
||||
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
|
||||
std::partial_sum(real_topks.begin(),
|
||||
real_topks.end(),
|
||||
@ -207,6 +212,10 @@ ReduceHelper::SortEqualScoresOneNQ(size_t nq_begin,
|
||||
PkType temp_pk =
|
||||
std::move(search_result->primary_keys_[start + i]);
|
||||
int64_t temp_offset = search_result->seg_offsets_[start + i];
|
||||
int32_t temp_elem_idx =
|
||||
search_result->element_level_
|
||||
? search_result->element_indices_[start + i]
|
||||
: -1;
|
||||
|
||||
size_t curr = i;
|
||||
while (indices[curr] != i) {
|
||||
@ -215,12 +224,20 @@ ReduceHelper::SortEqualScoresOneNQ(size_t nq_begin,
|
||||
std::move(search_result->primary_keys_[start + next]);
|
||||
search_result->seg_offsets_[start + curr] =
|
||||
search_result->seg_offsets_[start + next];
|
||||
if (search_result->element_level_) {
|
||||
search_result->element_indices_[start + curr] =
|
||||
search_result->element_indices_[start + next];
|
||||
}
|
||||
indices[curr] = curr; // Mark as processed
|
||||
curr = next;
|
||||
}
|
||||
|
||||
search_result->primary_keys_[start + curr] = std::move(temp_pk);
|
||||
search_result->seg_offsets_[start + curr] = temp_offset;
|
||||
if (search_result->element_level_) {
|
||||
search_result->element_indices_[start + curr] =
|
||||
temp_elem_idx;
|
||||
}
|
||||
indices[curr] = curr;
|
||||
}
|
||||
}
|
||||
@ -258,6 +275,10 @@ ReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
|
||||
search_result->distances_[offset];
|
||||
search_result->seg_offsets_[index] =
|
||||
search_result->seg_offsets_[offset];
|
||||
if (search_result->element_level_) {
|
||||
search_result->element_indices_[index] =
|
||||
search_result->element_indices_[offset];
|
||||
}
|
||||
index++;
|
||||
real_topks[j]++;
|
||||
}
|
||||
@ -265,6 +286,9 @@ ReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
|
||||
search_result->primary_keys_.resize(index);
|
||||
search_result->distances_.resize(index);
|
||||
search_result->seg_offsets_.resize(index);
|
||||
if (search_result->element_level_) {
|
||||
search_result->element_indices_.resize(index);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
@ -451,6 +475,15 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index,
|
||||
// reserve space for distances
|
||||
search_result_data->mutable_scores()->Resize(result_count, 0);
|
||||
|
||||
for (auto search_result : search_results_) {
|
||||
if (search_result->element_level_) {
|
||||
search_result_data->mutable_element_indices()
|
||||
->mutable_data()
|
||||
->Resize(result_count, -1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// fill pks and distances
|
||||
for (auto qi = nq_begin; qi < nq_end; qi++) {
|
||||
int64_t topk_count = 0;
|
||||
@ -499,6 +532,13 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index,
|
||||
|
||||
search_result_data->mutable_scores()->Set(
|
||||
loc, search_result->distances_[ki]);
|
||||
|
||||
if (search_result->element_level_) {
|
||||
search_result_data->mutable_element_indices()
|
||||
->mutable_data()
|
||||
->Set(loc, search_result->element_indices_[ki]);
|
||||
}
|
||||
|
||||
// set result offset to fill output fields data
|
||||
result_pairs[loc] = {&search_result->output_fields_data_, ki};
|
||||
}
|
||||
|
||||
@ -50,6 +50,7 @@ set(MILVUS_TEST_FILES
|
||||
test_rust_result.cpp
|
||||
test_storage_v2_index_raw_data.cpp
|
||||
test_group_by_json.cpp
|
||||
test_element_filter.cpp
|
||||
)
|
||||
|
||||
if ( NOT (INDEX_ENGINE STREQUAL "cardinal") )
|
||||
|
||||
977
internal/core/unittest/test_element_filter.cpp
Normal file
977
internal/core/unittest/test_element_filter.cpp
Normal file
@ -0,0 +1,977 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <boost/format.hpp>
|
||||
|
||||
#include "common/Schema.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "query/Plan.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
#include "test_utils/storage_test_utils.h"
|
||||
#include "test_utils/cachinglayer_test_utils.h"
|
||||
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
|
||||
class ElementFilterSealed
|
||||
: public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||
protected:
|
||||
bool
|
||||
use_hints() const {
|
||||
return std::get<0>(GetParam());
|
||||
}
|
||||
bool
|
||||
load_index() const {
|
||||
return std::get<1>(GetParam());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ElementFilterSealed, RangeExpr) {
|
||||
bool with_hints = use_hints();
|
||||
bool with_load_index = load_index();
|
||||
// Step 1: Prepare schema with array field
|
||||
int dim = 4;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
|
||||
DataType::VECTOR_FLOAT,
|
||||
dim,
|
||||
knowhere::metric::L2);
|
||||
auto int_array_fid = schema->AddDebugArrayField(
|
||||
"structA[price_array]", DataType::INT32, false);
|
||||
|
||||
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
|
||||
schema->set_primary_field_id(int64_fid);
|
||||
|
||||
size_t N = 500;
|
||||
int array_len = 3;
|
||||
|
||||
// Step 2: Generate test data
|
||||
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
|
||||
|
||||
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
|
||||
auto* field_data = raw_data.raw_->mutable_fields_data(i);
|
||||
if (field_data->field_id() == int_array_fid.get()) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Clear();
|
||||
|
||||
for (int row = 0; row < N; row++) {
|
||||
auto* array_data = field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Add();
|
||||
|
||||
for (int elem = 0; elem < array_len; elem++) {
|
||||
int value = row * array_len + elem + 1;
|
||||
array_data->mutable_int_data()->mutable_data()->Add(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Create sealed segment with field data
|
||||
auto segment = CreateSealedWithFieldDataLoaded(schema, raw_data);
|
||||
|
||||
// Step 4: Load vector index for element-level search
|
||||
auto array_vec_values = raw_data.get_col<VectorFieldProto>(vec_fid);
|
||||
|
||||
// DataGen generates VECTOR_ARRAY with data in float_vector (flattened),
|
||||
// not in vector_array (nested structure)
|
||||
std::vector<float> vector_data(dim * N * array_len);
|
||||
for (int i = 0; i < N; i++) {
|
||||
const auto& float_vec = array_vec_values[i].float_vector().data();
|
||||
// float_vec contains array_len * dim floats
|
||||
for (int j = 0; j < array_len * dim; j++) {
|
||||
vector_data[i * array_len * dim + j] = float_vec[j];
|
||||
}
|
||||
}
|
||||
|
||||
// For element-level search, index all elements (N * array_len vectors)
|
||||
auto indexing = GenVecIndexing(N * array_len,
|
||||
dim,
|
||||
vector_data.data(),
|
||||
knowhere::IndexEnum::INDEX_HNSW);
|
||||
LoadIndexInfo load_index_info;
|
||||
load_index_info.field_id = vec_fid.get();
|
||||
load_index_info.index_params = GenIndexParams(indexing.get());
|
||||
load_index_info.cache_index =
|
||||
CreateTestCacheIndex("test", std::move(indexing));
|
||||
load_index_info.index_params["metric_type"] = knowhere::metric::L2;
|
||||
load_index_info.field_type = DataType::VECTOR_ARRAY;
|
||||
load_index_info.element_type = DataType::VECTOR_FLOAT;
|
||||
if (with_load_index) {
|
||||
segment->LoadIndex(load_index_info);
|
||||
}
|
||||
|
||||
int topK = 5;
|
||||
|
||||
// Step 5: Test with element-level filter
|
||||
// Query: Search array elements, filter by element_value in (100, 400)
|
||||
{
|
||||
std::string hints_line =
|
||||
with_hints ? R"(hints: "iterative_filter")" : "";
|
||||
std::string raw_plan = boost::str(boost::format(R"(vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
element_filter_expr: <
|
||||
element_expr: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: Int32
|
||||
element_type: Int32
|
||||
is_element_level: true
|
||||
>
|
||||
lower_inclusive: false
|
||||
upper_inclusive: false
|
||||
lower_value: <
|
||||
int64_val: 100
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 400
|
||||
>
|
||||
>
|
||||
>
|
||||
predicate: <
|
||||
binary_arith_op_eval_range_expr: <
|
||||
column_info: <
|
||||
field_id: %3%
|
||||
data_type: Int64
|
||||
>
|
||||
arith_op: Mod
|
||||
right_operand: <
|
||||
int64_val: 2
|
||||
>
|
||||
op: Equal
|
||||
value: <
|
||||
int64_val: 0
|
||||
>
|
||||
>
|
||||
>
|
||||
struct_name: "structA"
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
%4%
|
||||
search_params: "{\"ef\": 50}"
|
||||
>
|
||||
placeholder_tag: "$0">)") %
|
||||
vec_fid.get() % int_array_fid.get() %
|
||||
int64_fid.get() % hints_line);
|
||||
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
ASSERT_TRUE(ok) << "Failed to parse element-level filter plan";
|
||||
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
|
||||
ASSERT_NE(plan, nullptr);
|
||||
|
||||
auto num_queries = 1;
|
||||
auto seed = 1024;
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup(num_queries, dim, seed, true);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
|
||||
// Verify results
|
||||
ASSERT_NE(search_result, nullptr);
|
||||
|
||||
// In element-level mode, results should be element indices, not doc offsets
|
||||
ASSERT_TRUE(search_result->element_level_);
|
||||
ASSERT_FALSE(search_result->element_indices_.empty());
|
||||
// Also check seg_offsets_ which stores the doc IDs
|
||||
ASSERT_FALSE(search_result->seg_offsets_.empty());
|
||||
ASSERT_EQ(search_result->element_indices_.size(),
|
||||
search_result->seg_offsets_.size());
|
||||
|
||||
// Should have topK results per query
|
||||
ASSERT_LE(search_result->element_indices_.size(), topK * num_queries);
|
||||
|
||||
std::cout << "Element-level search returned:" << std::endl;
|
||||
for (auto i = 0; i < search_result->seg_offsets_.size(); i++) {
|
||||
int64_t doc_id = search_result->seg_offsets_[i];
|
||||
int32_t elem_idx = search_result->element_indices_[i];
|
||||
float distance = search_result->distances_[i];
|
||||
|
||||
std::cout << "doc_id: " << doc_id << ", element_index: " << elem_idx
|
||||
<< ", distance: " << distance << std::endl;
|
||||
|
||||
// Verify the doc_id satisfies the predicate (id % 2 == 0)
|
||||
ASSERT_EQ(doc_id % 2, 0) << "Result doc_id " << doc_id
|
||||
<< " should satisfy (id % 2 == 0)";
|
||||
|
||||
// Verify element value is in range (100, 400)
|
||||
// Element value = doc_id * array_len + elem_idx + 1
|
||||
int element_value = doc_id * array_len + elem_idx + 1;
|
||||
ASSERT_GT(element_value, 100)
|
||||
<< "Element value " << element_value << " should be > 100";
|
||||
ASSERT_LT(element_value, 400)
|
||||
<< "Element value " << element_value << " should be < 400";
|
||||
}
|
||||
|
||||
// Verify distances are sorted (ascending for L2)
|
||||
for (size_t i = 1; i < search_result->distances_.size(); ++i) {
|
||||
ASSERT_LE(search_result->distances_[i - 1],
|
||||
search_result->distances_[i])
|
||||
<< "Distances should be sorted in ascending order";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(ElementFilterSealed, UnaryExpr) {
|
||||
bool with_hints = use_hints();
|
||||
bool with_load_index = load_index();
|
||||
// Step 1: Prepare schema with array field
|
||||
int dim = 4;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
|
||||
DataType::VECTOR_FLOAT,
|
||||
dim,
|
||||
knowhere::metric::L2);
|
||||
auto int_array_fid = schema->AddDebugArrayField(
|
||||
"structA[price_array]", DataType::INT32, false);
|
||||
|
||||
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
|
||||
schema->set_primary_field_id(int64_fid);
|
||||
|
||||
size_t N = 500;
|
||||
int array_len = 3;
|
||||
|
||||
// Step 2: Generate test data
|
||||
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
|
||||
|
||||
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
|
||||
auto* field_data = raw_data.raw_->mutable_fields_data(i);
|
||||
if (field_data->field_id() == int_array_fid.get()) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Clear();
|
||||
|
||||
for (int row = 0; row < N; row++) {
|
||||
auto* array_data = field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Add();
|
||||
|
||||
for (int elem = 0; elem < array_len; elem++) {
|
||||
int value = row * array_len + elem + 1;
|
||||
array_data->mutable_int_data()->mutable_data()->Add(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Create sealed segment with field data
|
||||
auto segment = CreateSealedWithFieldDataLoaded(schema, raw_data);
|
||||
|
||||
// Step 4: Load vector index for element-level search
|
||||
auto array_vec_values = raw_data.get_col<VectorFieldProto>(vec_fid);
|
||||
|
||||
// DataGen generates VECTOR_ARRAY with data in float_vector (flattened),
|
||||
// not in vector_array (nested structure)
|
||||
std::vector<float> vector_data(dim * N * array_len);
|
||||
for (int i = 0; i < N; i++) {
|
||||
const auto& float_vec = array_vec_values[i].float_vector().data();
|
||||
// float_vec contains array_len * dim floats
|
||||
for (int j = 0; j < array_len * dim; j++) {
|
||||
vector_data[i * array_len * dim + j] = float_vec[j];
|
||||
}
|
||||
}
|
||||
|
||||
// For element-level search, index all elements (N * array_len vectors)
|
||||
auto indexing = GenVecIndexing(N * array_len,
|
||||
dim,
|
||||
vector_data.data(),
|
||||
knowhere::IndexEnum::INDEX_HNSW);
|
||||
LoadIndexInfo load_index_info;
|
||||
load_index_info.field_id = vec_fid.get();
|
||||
load_index_info.index_params = GenIndexParams(indexing.get());
|
||||
load_index_info.cache_index =
|
||||
CreateTestCacheIndex("test", std::move(indexing));
|
||||
load_index_info.index_params["metric_type"] = knowhere::metric::L2;
|
||||
load_index_info.field_type = DataType::VECTOR_ARRAY;
|
||||
load_index_info.element_type = DataType::VECTOR_FLOAT;
|
||||
if (with_load_index) {
|
||||
segment->LoadIndex(load_index_info);
|
||||
}
|
||||
|
||||
int topK = 5;
|
||||
|
||||
// Step 5: Test with element-level filter
|
||||
// Query: Search array elements, filter by element_value < 10
|
||||
{
|
||||
std::string hints_line =
|
||||
with_hints ? R"(hints: "iterative_filter")" : "";
|
||||
std::string raw_plan = boost::str(boost::format(R"(vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
element_filter_expr: <
|
||||
element_expr: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: Int32
|
||||
element_type: Int32
|
||||
is_element_level: true
|
||||
>
|
||||
op: GreaterThan
|
||||
value: <
|
||||
int64_val: 10
|
||||
>
|
||||
>
|
||||
>
|
||||
predicate: <
|
||||
binary_arith_op_eval_range_expr: <
|
||||
column_info: <
|
||||
field_id: %3%
|
||||
data_type: Int64
|
||||
>
|
||||
arith_op: Mod
|
||||
right_operand: <
|
||||
int64_val: 2
|
||||
>
|
||||
op: Equal
|
||||
value: <
|
||||
int64_val: 0
|
||||
>
|
||||
>
|
||||
>
|
||||
struct_name: "structA"
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
%4%
|
||||
search_params: "{\"ef\": 50}"
|
||||
>
|
||||
placeholder_tag: "$0">)") %
|
||||
vec_fid.get() % int_array_fid.get() %
|
||||
int64_fid.get() % hints_line);
|
||||
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
ASSERT_TRUE(ok) << "Failed to parse element-level filter plan";
|
||||
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
|
||||
ASSERT_NE(plan, nullptr);
|
||||
|
||||
auto num_queries = 1;
|
||||
auto seed = 1024;
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup(num_queries, dim, seed, true);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
|
||||
// Verify results
|
||||
ASSERT_NE(search_result, nullptr);
|
||||
|
||||
// In element-level mode, results should be element indices, not doc offsets
|
||||
ASSERT_TRUE(search_result->element_level_);
|
||||
ASSERT_FALSE(search_result->element_indices_.empty());
|
||||
// Also check seg_offsets_ which stores the doc IDs
|
||||
ASSERT_FALSE(search_result->seg_offsets_.empty());
|
||||
ASSERT_EQ(search_result->element_indices_.size(),
|
||||
search_result->seg_offsets_.size());
|
||||
|
||||
// Should have topK results per query
|
||||
ASSERT_LE(search_result->element_indices_.size(), topK * num_queries);
|
||||
|
||||
std::cout << "Element-level search returned:" << std::endl;
|
||||
for (auto i = 0; i < search_result->seg_offsets_.size(); i++) {
|
||||
std::cout << "doc_id: " << search_result->seg_offsets_[i]
|
||||
<< ", element_index: "
|
||||
<< search_result->element_indices_[i] << std::endl;
|
||||
std::cout << "distance: " << search_result->distances_[i]
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Verify distances are sorted (ascending for L2)
|
||||
for (size_t i = 1; i < search_result->distances_.size(); ++i) {
|
||||
ASSERT_LE(search_result->distances_[i - 1],
|
||||
search_result->distances_[i])
|
||||
<< "Distances should be sorted in ascending order";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ElementFilter,
|
||||
ElementFilterSealed,
|
||||
::testing::Combine(::testing::Bool(), // with_hints: true/false
|
||||
::testing::Bool() // with_load_index: true/false
|
||||
),
|
||||
[](const ::testing::TestParamInfo<ElementFilterSealed::ParamType>& info) {
|
||||
bool with_hints = std::get<0>(info.param);
|
||||
bool with_load_index = std::get<1>(info.param);
|
||||
std::string name = "";
|
||||
name += with_hints ? "WithHints" : "WithoutHints";
|
||||
name += "_";
|
||||
name += with_load_index ? "WithLoadIndex" : "WithoutLoadIndex";
|
||||
return name;
|
||||
});
|
||||
|
||||
TEST(ElementFilter, GrowingSegmentArrayOffsetsGrowing) {
|
||||
int dim = 4;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
|
||||
DataType::VECTOR_FLOAT,
|
||||
dim,
|
||||
knowhere::metric::L2);
|
||||
auto int_array_fid = schema->AddDebugArrayField(
|
||||
"structA[price_array]", DataType::INT32, false);
|
||||
|
||||
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
|
||||
schema->set_primary_field_id(int64_fid);
|
||||
|
||||
size_t N = 500;
|
||||
int array_len = 3;
|
||||
|
||||
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
|
||||
|
||||
// Customize int_array data: doc i has elements [i*3+1, i*3+2, i*3+3]
|
||||
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
|
||||
auto* field_data = raw_data.raw_->mutable_fields_data(i);
|
||||
if (field_data->field_id() == int_array_fid.get()) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Clear();
|
||||
|
||||
for (int row = 0; row < N; row++) {
|
||||
auto* array_data = field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Add();
|
||||
|
||||
for (int elem = 0; elem < array_len; elem++) {
|
||||
int value = row * array_len + elem + 1;
|
||||
array_data->mutable_int_data()->mutable_data()->Add(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0,
|
||||
N,
|
||||
raw_data.row_ids_.data(),
|
||||
raw_data.timestamps_.data(),
|
||||
raw_data.raw_);
|
||||
|
||||
auto growing_impl = dynamic_cast<SegmentGrowingImpl*>(segment.get());
|
||||
ASSERT_NE(growing_impl, nullptr);
|
||||
|
||||
// Both fields should share the same ArrayOffsetsGrowing
|
||||
auto offsets_vec = growing_impl->GetArrayOffsets(vec_fid);
|
||||
auto offsets_int = growing_impl->GetArrayOffsets(int_array_fid);
|
||||
ASSERT_NE(offsets_vec, nullptr);
|
||||
ASSERT_NE(offsets_int, nullptr);
|
||||
|
||||
// Should point to the same object (shared)
|
||||
ASSERT_EQ(offsets_vec, offsets_int)
|
||||
<< "Fields in same struct should share ArrayOffsetsGrowing";
|
||||
|
||||
// Verify counts
|
||||
ASSERT_EQ(offsets_vec->GetRowCount(), N)
|
||||
<< "Should have " << N << " documents";
|
||||
ASSERT_EQ(offsets_vec->GetTotalElementCount(), N * array_len)
|
||||
<< "Should have " << N * array_len << " total elements";
|
||||
|
||||
for (int64_t doc_id = 0; doc_id < N; ++doc_id) {
|
||||
for (int32_t elem_idx = 0; elem_idx < array_len; ++elem_idx) {
|
||||
int64_t elem_id = doc_id * array_len + elem_idx;
|
||||
auto [mapped_doc, mapped_idx] =
|
||||
offsets_vec->ElementIDToRowID(elem_id);
|
||||
|
||||
ASSERT_EQ(mapped_doc, doc_id)
|
||||
<< "Element " << elem_id << " should map to doc " << doc_id;
|
||||
ASSERT_EQ(mapped_idx, elem_idx)
|
||||
<< "Element " << elem_id << " should have index " << elem_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ElementFilter, GrowingSegmentOutOfOrderInsert) {
|
||||
// Test out-of-order Insert handling in ArrayOffsetsGrowing
|
||||
int dim = 4;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
|
||||
DataType::VECTOR_FLOAT,
|
||||
dim,
|
||||
knowhere::metric::L2);
|
||||
auto int_array_fid = schema->AddDebugArrayField(
|
||||
"structA[price_array]", DataType::INT32, false);
|
||||
|
||||
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
|
||||
schema->set_primary_field_id(int64_fid);
|
||||
|
||||
int array_len = 3;
|
||||
|
||||
// Create growing segment
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
|
||||
// Simulate out-of-order inserts
|
||||
// Insert docs [10-19], [0-9], [20-29]
|
||||
auto gen_batch = [&](int64_t start, int64_t count) {
|
||||
auto batch = DataGen(schema, count, 42 + start, start, 1, array_len);
|
||||
|
||||
// Customize int_array data
|
||||
for (int i = 0; i < batch.raw_->fields_data_size(); i++) {
|
||||
auto* field_data = batch.raw_->mutable_fields_data(i);
|
||||
if (field_data->field_id() == int_array_fid.get()) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Clear();
|
||||
|
||||
for (int row = 0; row < count; row++) {
|
||||
int64_t global_row = start + row;
|
||||
auto* array_data = field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Add();
|
||||
|
||||
for (int elem = 0; elem < array_len; elem++) {
|
||||
int value = global_row * array_len + elem + 1;
|
||||
array_data->mutable_int_data()->mutable_data()->Add(
|
||||
value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return batch;
|
||||
};
|
||||
|
||||
// Insert batch 2 first (docs 10-19) - should be cached
|
||||
auto batch2 = gen_batch(10, 10);
|
||||
segment->PreInsert(10);
|
||||
segment->Insert(
|
||||
10, 10, batch2.row_ids_.data(), batch2.timestamps_.data(), batch2.raw_);
|
||||
|
||||
// Insert batch 1 (docs 0-9) - should trigger drain of batch 2
|
||||
auto batch1 = gen_batch(0, 10);
|
||||
segment->PreInsert(10);
|
||||
segment->Insert(
|
||||
0, 10, batch1.row_ids_.data(), batch1.timestamps_.data(), batch1.raw_);
|
||||
|
||||
// Insert batch 3 (docs 25-34) - should be cached (gap at 20-24)
|
||||
auto batch3 = gen_batch(25, 10);
|
||||
segment->PreInsert(10);
|
||||
segment->Insert(
|
||||
25, 10, batch3.row_ids_.data(), batch3.timestamps_.data(), batch3.raw_);
|
||||
|
||||
// Verify ArrayOffsetsGrowing
|
||||
auto growing_impl = dynamic_cast<SegmentGrowingImpl*>(segment.get());
|
||||
ASSERT_NE(growing_impl, nullptr);
|
||||
|
||||
auto offsets = growing_impl->GetArrayOffsets(vec_fid);
|
||||
ASSERT_NE(offsets, nullptr);
|
||||
|
||||
// After inserting docs [0-19] (batch3 cached due to gap), committed count should be 20
|
||||
ASSERT_EQ(offsets->GetRowCount(), 20)
|
||||
<< "Should have committed docs 0-19, batch3 cached";
|
||||
ASSERT_EQ(offsets->GetTotalElementCount(), 20 * array_len)
|
||||
<< "Should have 20 docs worth of elements";
|
||||
|
||||
// Verify mapping for committed docs
|
||||
for (int64_t doc_id = 0; doc_id < 20; ++doc_id) {
|
||||
for (int32_t elem_idx = 0; elem_idx < array_len; ++elem_idx) {
|
||||
int64_t elem_id = doc_id * array_len + elem_idx;
|
||||
auto [mapped_doc, mapped_idx] = offsets->ElementIDToRowID(elem_id);
|
||||
|
||||
ASSERT_EQ(mapped_doc, doc_id)
|
||||
<< "Element " << elem_id << " should map to doc " << doc_id;
|
||||
ASSERT_EQ(mapped_idx, elem_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parameterized test fixture for GrowingIterativeRangeExpr
|
||||
class ElementFilterGrowing : public ::testing::TestWithParam<bool> {
|
||||
protected:
|
||||
bool
|
||||
use_hints() const {
|
||||
return GetParam();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ElementFilterGrowing, RangeExpr) {
|
||||
bool with_hints = use_hints();
|
||||
int dim = 4;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
|
||||
DataType::VECTOR_FLOAT,
|
||||
dim,
|
||||
knowhere::metric::L2);
|
||||
auto int_array_fid = schema->AddDebugArrayField(
|
||||
"structA[price_array]", DataType::INT32, false);
|
||||
|
||||
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
|
||||
schema->set_primary_field_id(int64_fid);
|
||||
|
||||
size_t N = 500;
|
||||
int array_len = 3;
|
||||
|
||||
// Generate test data
|
||||
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
|
||||
|
||||
// Customize int_array data: doc i has elements [i*3+1, i*3+2, i*3+3]
|
||||
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
|
||||
auto* field_data = raw_data.raw_->mutable_fields_data(i);
|
||||
if (field_data->field_id() == int_array_fid.get()) {
|
||||
field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Clear();
|
||||
|
||||
for (int row = 0; row < N; row++) {
|
||||
auto* array_data = field_data->mutable_scalars()
|
||||
->mutable_array_data()
|
||||
->mutable_data()
|
||||
->Add();
|
||||
|
||||
for (int elem = 0; elem < array_len; elem++) {
|
||||
int value = row * array_len + elem + 1;
|
||||
array_data->mutable_int_data()->mutable_data()->Add(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Create growing segment and insert data
|
||||
auto segment = CreateGrowingSegment(schema, empty_index_meta);
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0,
|
||||
N,
|
||||
raw_data.row_ids_.data(),
|
||||
raw_data.timestamps_.data(),
|
||||
raw_data.raw_);
|
||||
|
||||
// Verify ArrayOffsetsGrowing was built
|
||||
auto growing_impl = dynamic_cast<SegmentGrowingImpl*>(segment.get());
|
||||
ASSERT_NE(growing_impl, nullptr);
|
||||
auto offsets = growing_impl->GetArrayOffsets(vec_fid);
|
||||
ASSERT_NE(offsets, nullptr);
|
||||
ASSERT_EQ(offsets->GetRowCount(), N);
|
||||
ASSERT_EQ(offsets->GetTotalElementCount(), N * array_len);
|
||||
|
||||
int topK = 5;
|
||||
|
||||
// Execute element-level search with iterative filter
|
||||
// Query: Search array elements where (id % 2 == 0) AND (price_array element in range (100, 400))
|
||||
{
|
||||
std::string hints_line =
|
||||
with_hints ? R"(hints: "iterative_filter")" : "";
|
||||
std::string raw_plan = boost::str(boost::format(R"(vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
element_filter_expr: <
|
||||
element_expr: <
|
||||
binary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: Int32
|
||||
element_type: Int32
|
||||
is_element_level: true
|
||||
>
|
||||
lower_inclusive: false
|
||||
upper_inclusive: false
|
||||
lower_value: <
|
||||
int64_val: 100
|
||||
>
|
||||
upper_value: <
|
||||
int64_val: 400
|
||||
>
|
||||
>
|
||||
>
|
||||
predicate: <
|
||||
binary_arith_op_eval_range_expr: <
|
||||
column_info: <
|
||||
field_id: %3%
|
||||
data_type: Int64
|
||||
>
|
||||
arith_op: Mod
|
||||
right_operand: <
|
||||
int64_val: 2
|
||||
>
|
||||
op: Equal
|
||||
value: <
|
||||
int64_val: 0
|
||||
>
|
||||
>
|
||||
>
|
||||
struct_name: "structA"
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 5
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
%4%
|
||||
search_params: "{\"ef\": 50}"
|
||||
>
|
||||
placeholder_tag: "$0">)") %
|
||||
vec_fid.get() % int_array_fid.get() %
|
||||
int64_fid.get() % hints_line);
|
||||
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
ASSERT_TRUE(ok) << "Failed to parse element-level filter plan";
|
||||
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
|
||||
ASSERT_NE(plan, nullptr);
|
||||
|
||||
auto num_queries = 1;
|
||||
auto seed = 1024;
|
||||
auto ph_group_raw =
|
||||
CreatePlaceholderGroup(num_queries, dim, seed, true);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
auto search_result =
|
||||
segment->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
|
||||
// Verify results
|
||||
ASSERT_NE(search_result, nullptr);
|
||||
|
||||
// In element-level mode, results should contain element indices
|
||||
ASSERT_TRUE(search_result->element_level_)
|
||||
<< "Search should be in element-level mode";
|
||||
ASSERT_FALSE(search_result->element_indices_.empty())
|
||||
<< "Should have element indices";
|
||||
ASSERT_FALSE(search_result->seg_offsets_.empty())
|
||||
<< "Should have doc offsets";
|
||||
ASSERT_EQ(search_result->element_indices_.size(),
|
||||
search_result->seg_offsets_.size())
|
||||
<< "Element indices and doc offsets should match in size";
|
||||
|
||||
// Should have topK results per query
|
||||
ASSERT_LE(search_result->element_indices_.size(), topK * num_queries)
|
||||
<< "Should not exceed topK results";
|
||||
|
||||
std::cout << "Growing segment element-level search results:"
|
||||
<< std::endl;
|
||||
for (size_t i = 0; i < search_result->seg_offsets_.size(); i++) {
|
||||
int64_t doc_id = search_result->seg_offsets_[i];
|
||||
int32_t elem_idx = search_result->element_indices_[i];
|
||||
float distance = search_result->distances_[i];
|
||||
|
||||
std::cout << " [" << i << "] doc_id=" << doc_id
|
||||
<< ", element_index=" << elem_idx
|
||||
<< ", distance=" << distance << std::endl;
|
||||
|
||||
// Verify the doc_id satisfies the predicate (id % 2 == 0)
|
||||
ASSERT_EQ(doc_id % 2, 0) << "Result doc_id " << doc_id
|
||||
<< " should satisfy (id % 2 == 0)";
|
||||
|
||||
// Verify element_idx is valid
|
||||
ASSERT_GE(elem_idx, 0) << "Element index should be >= 0";
|
||||
ASSERT_LT(elem_idx, array_len)
|
||||
<< "Element index should be < array_len";
|
||||
|
||||
// Verify element value is in range (100, 400)
|
||||
// Element value = doc_id * array_len + elem_idx + 1
|
||||
int element_value = doc_id * array_len + elem_idx + 1;
|
||||
ASSERT_GT(element_value, 100)
|
||||
<< "Element value " << element_value << " should be > 100";
|
||||
ASSERT_LT(element_value, 400)
|
||||
<< "Element value " << element_value << " should be < 400";
|
||||
}
|
||||
|
||||
// Verify distances are sorted (ascending for L2)
|
||||
for (size_t i = 1; i < search_result->distances_.size(); ++i) {
|
||||
ASSERT_LE(search_result->distances_[i - 1],
|
||||
search_result->distances_[i])
|
||||
<< "Distances should be sorted in ascending order";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ElementFilter,
|
||||
ElementFilterGrowing,
|
||||
::testing::Bool(), // with_hints: true/false
|
||||
[](const ::testing::TestParamInfo<ElementFilterGrowing::ParamType>& info) {
|
||||
bool with_hints = info.param;
|
||||
return with_hints ? "WithHints" : "WithoutHints";
|
||||
});
|
||||
|
||||
// Unit tests for ArrayOffsetsGrowing
|
||||
TEST(ArrayOffsetsGrowing, PurePendingThenDrain) {
|
||||
// Test: first insert goes entirely to pending, second insert triggers drain
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// First insert: rows 2-4, all go to pending (committed_row_count_ = 0)
|
||||
std::vector<int32_t> lens1 = {
|
||||
3, 2, 4}; // row 2: 3 elems, row 3: 2 elems, row 4: 4 elems
|
||||
offsets.Insert(2, lens1.data(), 3);
|
||||
|
||||
ASSERT_EQ(offsets.GetRowCount(), 0) << "No rows should be committed yet";
|
||||
ASSERT_EQ(offsets.GetTotalElementCount(), 0)
|
||||
<< "No elements should exist yet";
|
||||
|
||||
// Second insert: rows 0-1, triggers drain of pending rows 2-4
|
||||
std::vector<int32_t> lens2 = {2, 3}; // row 0: 2 elems, row 1: 3 elems
|
||||
offsets.Insert(0, lens2.data(), 2);
|
||||
|
||||
ASSERT_EQ(offsets.GetRowCount(), 5) << "All 5 rows should be committed";
|
||||
// Total elements: 2 + 3 + 3 + 2 + 4 = 14
|
||||
ASSERT_EQ(offsets.GetTotalElementCount(), 14);
|
||||
|
||||
// Verify ElementIDToRowID mapping
|
||||
// Row 0: elem 0-1, Row 1: elem 2-4, Row 2: elem 5-7, Row 3: elem 8-9, Row 4: elem 10-13
|
||||
std::vector<std::pair<int32_t, int32_t>> expected = {
|
||||
{0, 0},
|
||||
{0, 1}, // row 0
|
||||
{1, 0},
|
||||
{1, 1},
|
||||
{1, 2}, // row 1
|
||||
{2, 0},
|
||||
{2, 1},
|
||||
{2, 2}, // row 2
|
||||
{3, 0},
|
||||
{3, 1}, // row 3
|
||||
{4, 0},
|
||||
{4, 1},
|
||||
{4, 2},
|
||||
{4, 3} // row 4
|
||||
};
|
||||
|
||||
for (int32_t elem_id = 0; elem_id < 14; ++elem_id) {
|
||||
auto [row_id, elem_idx] = offsets.ElementIDToRowID(elem_id);
|
||||
ASSERT_EQ(row_id, expected[elem_id].first)
|
||||
<< "elem_id " << elem_id << " should map to row "
|
||||
<< expected[elem_id].first;
|
||||
ASSERT_EQ(elem_idx, expected[elem_id].second)
|
||||
<< "elem_id " << elem_id << " should have elem_idx "
|
||||
<< expected[elem_id].second;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ArrayOffsetsGrowing, ElementIDRangeOfRow) {
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert 4 rows with varying element counts
|
||||
std::vector<int32_t> lens = {3, 0, 2, 5}; // includes empty array
|
||||
offsets.Insert(0, lens.data(), 4);
|
||||
|
||||
ASSERT_EQ(offsets.GetRowCount(), 4);
|
||||
ASSERT_EQ(offsets.GetTotalElementCount(), 10); // 3 + 0 + 2 + 5
|
||||
|
||||
// Verify ElementIDRangeOfRow
|
||||
auto [start0, end0] = offsets.ElementIDRangeOfRow(0);
|
||||
ASSERT_EQ(start0, 0);
|
||||
ASSERT_EQ(end0, 3);
|
||||
|
||||
auto [start1, end1] = offsets.ElementIDRangeOfRow(1);
|
||||
ASSERT_EQ(start1, 3);
|
||||
ASSERT_EQ(end1, 3); // empty array
|
||||
|
||||
auto [start2, end2] = offsets.ElementIDRangeOfRow(2);
|
||||
ASSERT_EQ(start2, 3);
|
||||
ASSERT_EQ(end2, 5);
|
||||
|
||||
auto [start3, end3] = offsets.ElementIDRangeOfRow(3);
|
||||
ASSERT_EQ(start3, 5);
|
||||
ASSERT_EQ(end3, 10);
|
||||
|
||||
// Boundary: row_id == row_count returns (total, total)
|
||||
auto [start4, end4] = offsets.ElementIDRangeOfRow(4);
|
||||
ASSERT_EQ(start4, 10);
|
||||
ASSERT_EQ(end4, 10);
|
||||
}
|
||||
|
||||
TEST(ArrayOffsetsGrowing, MultiplePendingBatches) {
|
||||
// Test multiple pending batches being drained in order
|
||||
ArrayOffsetsGrowing offsets;
|
||||
|
||||
// Insert row 5 first
|
||||
std::vector<int32_t> lens5 = {2};
|
||||
offsets.Insert(5, lens5.data(), 1);
|
||||
ASSERT_EQ(offsets.GetRowCount(), 0);
|
||||
|
||||
// Insert row 3
|
||||
std::vector<int32_t> lens3 = {3};
|
||||
offsets.Insert(3, lens3.data(), 1);
|
||||
ASSERT_EQ(offsets.GetRowCount(), 0);
|
||||
|
||||
// Insert row 1
|
||||
std::vector<int32_t> lens1 = {1};
|
||||
offsets.Insert(1, lens1.data(), 1);
|
||||
ASSERT_EQ(offsets.GetRowCount(), 0);
|
||||
|
||||
// Insert row 0 - should drain row 1, but not 3 or 5 (gap at 2)
|
||||
std::vector<int32_t> lens0 = {2};
|
||||
offsets.Insert(0, lens0.data(), 1);
|
||||
ASSERT_EQ(offsets.GetRowCount(), 2) << "Should commit rows 0-1";
|
||||
ASSERT_EQ(offsets.GetTotalElementCount(), 3); // 2 + 1
|
||||
|
||||
// Insert row 2 - should drain rows 3, but not 5 (gap at 4)
|
||||
std::vector<int32_t> lens2 = {1};
|
||||
offsets.Insert(2, lens2.data(), 1);
|
||||
ASSERT_EQ(offsets.GetRowCount(), 4) << "Should commit rows 0-3";
|
||||
ASSERT_EQ(offsets.GetTotalElementCount(), 7); // 2 + 1 + 1 + 3
|
||||
|
||||
// Insert row 4 - should drain row 5
|
||||
std::vector<int32_t> lens4 = {2};
|
||||
offsets.Insert(4, lens4.data(), 1);
|
||||
ASSERT_EQ(offsets.GetRowCount(), 6) << "Should commit rows 0-5";
|
||||
ASSERT_EQ(offsets.GetTotalElementCount(), 11); // 2 + 1 + 1 + 3 + 2 + 2
|
||||
|
||||
// Verify final mapping
|
||||
// Row 0: elem 0-1, Row 1: elem 2, Row 2: elem 3, Row 3: elem 4-6, Row 4: elem 7-8, Row 5: elem 9-10
|
||||
auto [r0, i0] = offsets.ElementIDToRowID(0);
|
||||
ASSERT_EQ(r0, 0);
|
||||
ASSERT_EQ(i0, 0);
|
||||
|
||||
auto [r2, i2] = offsets.ElementIDToRowID(2);
|
||||
ASSERT_EQ(r2, 1);
|
||||
ASSERT_EQ(i2, 0);
|
||||
|
||||
auto [r4, i4] = offsets.ElementIDToRowID(4);
|
||||
ASSERT_EQ(r4, 3);
|
||||
ASSERT_EQ(i4, 0);
|
||||
|
||||
auto [r7, i7] = offsets.ElementIDToRowID(7);
|
||||
ASSERT_EQ(r7, 4);
|
||||
ASSERT_EQ(i7, 0);
|
||||
|
||||
auto [r10, i10] = offsets.ElementIDToRowID(10);
|
||||
ASSERT_EQ(r10, 5);
|
||||
ASSERT_EQ(i10, 1);
|
||||
}
|
||||
@ -191,7 +191,7 @@ TEST(Indexing, BinaryBruteForce) {
|
||||
SearchResult sr;
|
||||
sr.total_nq_ = num_queries;
|
||||
sr.unity_topK_ = topk;
|
||||
sr.seg_offsets_ = std::move(sub_result.mutable_seg_offsets());
|
||||
sr.seg_offsets_ = std::move(sub_result.mutable_offsets());
|
||||
sr.distances_ = std::move(sub_result.mutable_distances());
|
||||
|
||||
auto json = SearchResultToJson(sr);
|
||||
|
||||
@ -1579,7 +1579,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
|
||||
for (auto q = 0; q < num_queries; q++) {
|
||||
for (auto k = 0; k < topk; k++) {
|
||||
auto offset = q * topk + k;
|
||||
auto seg_offset = sub_result.get_seg_offsets()[offset];
|
||||
auto seg_offset = sub_result.get_offsets()[offset];
|
||||
ASSERT_EQ(std::get<std::string>(sr->primary_keys_[offset]),
|
||||
str_col[seg_offset]);
|
||||
ASSERT_EQ(retrieved_str_col[offset], str_col[seg_offset]);
|
||||
|
||||
@ -1151,7 +1151,10 @@ CreatePlaceholderGroup(int64_t num_queries,
|
||||
|
||||
template <class TraitType = milvus::FloatVector>
|
||||
auto
|
||||
CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
|
||||
CreatePlaceholderGroup(int64_t num_queries,
|
||||
int dim,
|
||||
int64_t seed = 42,
|
||||
bool element_level = false) {
|
||||
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
|
||||
assert(dim % 8 == 0);
|
||||
}
|
||||
@ -1162,6 +1165,7 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(TraitType::placeholder_type);
|
||||
value->set_element_level(element_level);
|
||||
// TODO caiyd: need update for Int8Vector
|
||||
std::normal_distribution<double> dis(0, 1);
|
||||
std::default_random_engine e(seed);
|
||||
|
||||
@ -3,53 +3,55 @@ grammar Plan;
|
||||
expr:
|
||||
Identifier (op1=(ADD | SUB) INTERVAL interval_string=StringLiteral)? op2=(LT | LE | GT | GE | EQ | NE) ISO compare_string=StringLiteral # TimestamptzCompareForward
|
||||
| ISO compare_string=StringLiteral op2=(LT | LE | GT | GE | EQ | NE) Identifier (op1=(ADD | SUB) INTERVAL interval_string=StringLiteral)? # TimestamptzCompareReverse
|
||||
| IntegerConstant # Integer
|
||||
| FloatingConstant # Floating
|
||||
| BooleanConstant # Boolean
|
||||
| StringLiteral # String
|
||||
| (Identifier|Meta) # Identifier
|
||||
| JSONIdentifier # JSONIdentifier
|
||||
| LBRACE Identifier RBRACE # TemplateVariable
|
||||
| '(' expr ')' # Parens
|
||||
| '[' expr (',' expr)* ','? ']' # Array
|
||||
| EmptyArray # EmptyArray
|
||||
| EXISTS expr # Exists
|
||||
| expr LIKE StringLiteral # Like
|
||||
| TEXTMATCH'('Identifier',' StringLiteral (',' textMatchOption)? ')' # TextMatch
|
||||
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
|
||||
| RANDOMSAMPLE'(' expr ')' # RandomSample
|
||||
| expr POW expr # Power
|
||||
| op = (ADD | SUB | BNOT | NOT) expr # Unary
|
||||
// | '(' typeName ')' expr # Cast
|
||||
| expr op = (MUL | DIV | MOD) expr # MulDivMod
|
||||
| expr op = (ADD | SUB) expr # AddSub
|
||||
| expr op = (SHL | SHR) expr # Shift
|
||||
| expr op = NOT? IN expr # Term
|
||||
| (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains
|
||||
| (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll
|
||||
| (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny
|
||||
| STEuqals'('Identifier','StringLiteral')' # STEuqals
|
||||
| STTouches'('Identifier','StringLiteral')' # STTouches
|
||||
| STOverlaps'('Identifier','StringLiteral')' # STOverlaps
|
||||
| STCrosses'('Identifier','StringLiteral')' # STCrosses
|
||||
| STContains'('Identifier','StringLiteral')' # STContains
|
||||
| STIntersects'('Identifier','StringLiteral')' # STIntersects
|
||||
| STWithin'('Identifier','StringLiteral')' # STWithin
|
||||
| STDWithin'('Identifier','StringLiteral',' expr')' # STDWithin
|
||||
| STIsValid'('Identifier')' # STIsValid
|
||||
| ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength
|
||||
| Identifier '(' ( expr (',' expr )* ','? )? ')' # Call
|
||||
| expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range
|
||||
| expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange
|
||||
| expr op = (LT | LE | GT | GE) expr # Relational
|
||||
| expr op = (EQ | NE) expr # Equality
|
||||
| expr BAND expr # BitAnd
|
||||
| expr BXOR expr # BitXor
|
||||
| expr BOR expr # BitOr
|
||||
| expr AND expr # LogicalAnd
|
||||
| expr OR expr # LogicalOr
|
||||
| (Identifier | JSONIdentifier) ISNULL # IsNull
|
||||
| (Identifier | JSONIdentifier) ISNOTNULL # IsNotNull;
|
||||
| IntegerConstant # Integer
|
||||
| FloatingConstant # Floating
|
||||
| BooleanConstant # Boolean
|
||||
| StringLiteral # String
|
||||
| (Identifier|Meta) # Identifier
|
||||
| JSONIdentifier # JSONIdentifier
|
||||
| StructSubFieldIdentifier # StructSubField
|
||||
| LBRACE Identifier RBRACE # TemplateVariable
|
||||
| '(' expr ')' # Parens
|
||||
| '[' expr (',' expr)* ','? ']' # Array
|
||||
| EmptyArray # EmptyArray
|
||||
| EXISTS expr # Exists
|
||||
| expr LIKE StringLiteral # Like
|
||||
| TEXTMATCH'('Identifier',' StringLiteral (',' textMatchOption)? ')' # TextMatch
|
||||
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
|
||||
| RANDOMSAMPLE'(' expr ')' # RandomSample
|
||||
| ElementFilter'('Identifier',' expr')' # ElementFilter
|
||||
| expr POW expr # Power
|
||||
| op = (ADD | SUB | BNOT | NOT) expr # Unary
|
||||
// | '(' typeName ')' expr # Cast
|
||||
| expr op = (MUL | DIV | MOD) expr # MulDivMod
|
||||
| expr op = (ADD | SUB) expr # AddSub
|
||||
| expr op = (SHL | SHR) expr # Shift
|
||||
| expr op = NOT? IN expr # Term
|
||||
| (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains
|
||||
| (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll
|
||||
| (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny
|
||||
| STEuqals'('Identifier','StringLiteral')' # STEuqals
|
||||
| STTouches'('Identifier','StringLiteral')' # STTouches
|
||||
| STOverlaps'('Identifier','StringLiteral')' # STOverlaps
|
||||
| STCrosses'('Identifier','StringLiteral')' # STCrosses
|
||||
| STContains'('Identifier','StringLiteral')' # STContains
|
||||
| STIntersects'('Identifier','StringLiteral')' # STIntersects
|
||||
| STWithin'('Identifier','StringLiteral')' # STWithin
|
||||
| STDWithin'('Identifier','StringLiteral',' expr')' # STDWithin
|
||||
| STIsValid'('Identifier')' # STIsValid
|
||||
| ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength
|
||||
| Identifier '(' ( expr (',' expr )* ','? )? ')' # Call
|
||||
| expr op1 = (LT | LE) (Identifier | JSONIdentifier | StructSubFieldIdentifier) op2 = (LT | LE) expr # Range
|
||||
| expr op1 = (GT | GE) (Identifier | JSONIdentifier | StructSubFieldIdentifier) op2 = (GT | GE) expr # ReverseRange
|
||||
| expr op = (LT | LE | GT | GE) expr # Relational
|
||||
| expr op = (EQ | NE) expr # Equality
|
||||
| expr BAND expr # BitAnd
|
||||
| expr BXOR expr # BitXor
|
||||
| expr BOR expr # BitOr
|
||||
| expr AND expr # LogicalAnd
|
||||
| expr OR expr # LogicalOr
|
||||
| (Identifier | JSONIdentifier) ISNULL # IsNull
|
||||
| (Identifier | JSONIdentifier) ISNOTNULL # IsNotNull;
|
||||
|
||||
textMatchOption:
|
||||
MINIMUM_SHOULD_MATCH ASSIGN IntegerConstant;
|
||||
@ -115,6 +117,7 @@ ArrayContains: 'array_contains' | 'ARRAY_CONTAINS';
|
||||
ArrayContainsAll: 'array_contains_all' | 'ARRAY_CONTAINS_ALL';
|
||||
ArrayContainsAny: 'array_contains_any' | 'ARRAY_CONTAINS_ANY';
|
||||
ArrayLength: 'array_length' | 'ARRAY_LENGTH';
|
||||
ElementFilter: 'element_filter' | 'ELEMENT_FILTER';
|
||||
|
||||
STEuqals:'st_equals' | 'ST_EQUALS';
|
||||
STTouches:'st_touches' | 'ST_TOUCHES';
|
||||
@ -143,6 +146,7 @@ Meta: '$meta';
|
||||
|
||||
StringLiteral: EncodingPrefix? ('"' DoubleSCharSequence? '"' | '\'' SingleSCharSequence? '\'');
|
||||
JSONIdentifier: (Identifier | Meta)('[' (StringLiteral | DecimalConstant) ']')+;
|
||||
StructSubFieldIdentifier: '$[' Identifier ']';
|
||||
|
||||
fragment EncodingPrefix: 'u8' | 'u' | 'U' | 'L';
|
||||
|
||||
|
||||
@ -40,6 +40,14 @@ func FillExpressionValue(expr *planpb.Expr, templateValues map[string]*planpb.Ge
|
||||
return FillJSONContainsExpressionValue(e.JsonContainsExpr, templateValues)
|
||||
case *planpb.Expr_RandomSampleExpr:
|
||||
return FillExpressionValue(expr.GetExpr().(*planpb.Expr_RandomSampleExpr).RandomSampleExpr.GetPredicate(), templateValues)
|
||||
case *planpb.Expr_ElementFilterExpr:
|
||||
if err := FillExpressionValue(e.ElementFilterExpr.GetElementExpr(), templateValues); err != nil {
|
||||
return err
|
||||
}
|
||||
if e.ElementFilterExpr.GetPredicate() != nil {
|
||||
return FillExpressionValue(e.ElementFilterExpr.GetPredicate(), templateValues)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("this expression no need to fill placeholder with expr type: %T", e)
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -46,24 +46,26 @@ ArrayContains=45
|
||||
ArrayContainsAll=46
|
||||
ArrayContainsAny=47
|
||||
ArrayLength=48
|
||||
STEuqals=49
|
||||
STTouches=50
|
||||
STOverlaps=51
|
||||
STCrosses=52
|
||||
STContains=53
|
||||
STIntersects=54
|
||||
STWithin=55
|
||||
STDWithin=56
|
||||
STIsValid=57
|
||||
BooleanConstant=58
|
||||
IntegerConstant=59
|
||||
FloatingConstant=60
|
||||
Identifier=61
|
||||
Meta=62
|
||||
StringLiteral=63
|
||||
JSONIdentifier=64
|
||||
Whitespace=65
|
||||
Newline=66
|
||||
ElementFilter=49
|
||||
STEuqals=50
|
||||
STTouches=51
|
||||
STOverlaps=52
|
||||
STCrosses=53
|
||||
STContains=54
|
||||
STIntersects=55
|
||||
STWithin=56
|
||||
STDWithin=57
|
||||
STIsValid=58
|
||||
BooleanConstant=59
|
||||
IntegerConstant=60
|
||||
FloatingConstant=61
|
||||
Identifier=62
|
||||
Meta=63
|
||||
StringLiteral=64
|
||||
JSONIdentifier=65
|
||||
StructSubFieldIdentifier=66
|
||||
Whitespace=67
|
||||
Newline=68
|
||||
'('=1
|
||||
')'=2
|
||||
'['=3
|
||||
@ -90,4 +92,4 @@ Newline=66
|
||||
'|'=32
|
||||
'^'=33
|
||||
'~'=38
|
||||
'$meta'=62
|
||||
'$meta'=63
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -46,24 +46,26 @@ ArrayContains=45
|
||||
ArrayContainsAll=46
|
||||
ArrayContainsAny=47
|
||||
ArrayLength=48
|
||||
STEuqals=49
|
||||
STTouches=50
|
||||
STOverlaps=51
|
||||
STCrosses=52
|
||||
STContains=53
|
||||
STIntersects=54
|
||||
STWithin=55
|
||||
STDWithin=56
|
||||
STIsValid=57
|
||||
BooleanConstant=58
|
||||
IntegerConstant=59
|
||||
FloatingConstant=60
|
||||
Identifier=61
|
||||
Meta=62
|
||||
StringLiteral=63
|
||||
JSONIdentifier=64
|
||||
Whitespace=65
|
||||
Newline=66
|
||||
ElementFilter=49
|
||||
STEuqals=50
|
||||
STTouches=51
|
||||
STOverlaps=52
|
||||
STCrosses=53
|
||||
STContains=54
|
||||
STIntersects=55
|
||||
STWithin=56
|
||||
STDWithin=57
|
||||
STIsValid=58
|
||||
BooleanConstant=59
|
||||
IntegerConstant=60
|
||||
FloatingConstant=61
|
||||
Identifier=62
|
||||
Meta=63
|
||||
StringLiteral=64
|
||||
JSONIdentifier=65
|
||||
StructSubFieldIdentifier=66
|
||||
Whitespace=67
|
||||
Newline=68
|
||||
'('=1
|
||||
')'=2
|
||||
'['=3
|
||||
@ -90,4 +92,4 @@ Newline=66
|
||||
'|'=32
|
||||
'^'=33
|
||||
'~'=38
|
||||
'$meta'=62
|
||||
'$meta'=63
|
||||
|
||||
@ -7,18 +7,6 @@ type BasePlanVisitor struct {
|
||||
*antlr.BaseParseTreeVisitor
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitRandomSample(ctx *RandomSampleContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitString(ctx *StringContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -27,22 +15,10 @@ func (v *BasePlanVisitor) VisitFloating(ctx *FloatingContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitLogicalOr(ctx *LogicalOrContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitIsNotNull(ctx *IsNotNullContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMulDivMod(ctx *MulDivModContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitIdentifier(ctx *IdentifierContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -55,14 +31,6 @@ func (v *BasePlanVisitor) VisitLike(ctx *LikeContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitLogicalAnd(ctx *LogicalAndContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTemplateVariable(ctx *TemplateVariableContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitEquality(ctx *EqualityContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -71,14 +39,6 @@ func (v *BasePlanVisitor) VisitBoolean(ctx *BooleanContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTDWithin(ctx *STDWithinContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -87,54 +47,26 @@ func (v *BasePlanVisitor) VisitTimestamptzCompareForward(ctx *TimestamptzCompare
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTCrosses(ctx *STCrossesContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitBitOr(ctx *BitOrContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitEmptyArray(ctx *EmptyArrayContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitAddSub(ctx *AddSubContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitPhraseMatch(ctx *PhraseMatchContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitRelational(ctx *RelationalContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTextMatch(ctx *TextMatchContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTTouches(ctx *STTouchesContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTContains(ctx *STContainsContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTerm(ctx *TermContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -151,6 +83,94 @@ func (v *BasePlanVisitor) VisitRange(ctx *RangeContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitBitXor(ctx *BitXorContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitElementFilter(ctx *ElementFilterContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitBitAnd(ctx *BitAndContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitRandomSample(ctx *RandomSampleContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitLogicalOr(ctx *LogicalOrContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMulDivMod(ctx *MulDivModContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitLogicalAnd(ctx *LogicalAndContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTemplateVariable(ctx *TemplateVariableContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTDWithin(ctx *STDWithinContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTCrosses(ctx *STCrossesContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitBitOr(ctx *BitOrContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitAddSub(ctx *AddSubContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitRelational(ctx *RelationalContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitTextMatch(ctx *TextMatchContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTContains(ctx *STContainsContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitUnary(ctx *UnaryContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -167,22 +187,10 @@ func (v *BasePlanVisitor) VisitJSONContainsAny(ctx *JSONContainsAnyContext) inte
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitBitXor(ctx *BitXorContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitExists(ctx *ExistsContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitBitAnd(ctx *BitAndContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTEuqals(ctx *STEuqalsContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -191,11 +199,11 @@ func (v *BasePlanVisitor) VisitIsNull(ctx *IsNullContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitPower(ctx *PowerContext) interface{} {
|
||||
func (v *BasePlanVisitor) VisitStructSubField(ctx *StructSubFieldContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
|
||||
func (v *BasePlanVisitor) VisitPower(ctx *PowerContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -7,33 +7,15 @@ import "github.com/antlr4-go/antlr/v4"
|
||||
type PlanVisitor interface {
|
||||
antlr.ParseTreeVisitor
|
||||
|
||||
// Visit a parse tree produced by PlanParser#JSONIdentifier.
|
||||
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#RandomSample.
|
||||
VisitRandomSample(ctx *RandomSampleContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Parens.
|
||||
VisitParens(ctx *ParensContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#String.
|
||||
VisitString(ctx *StringContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Floating.
|
||||
VisitFloating(ctx *FloatingContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#JSONContainsAll.
|
||||
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#LogicalOr.
|
||||
VisitLogicalOr(ctx *LogicalOrContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#IsNotNull.
|
||||
VisitIsNotNull(ctx *IsNotNullContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MulDivMod.
|
||||
VisitMulDivMod(ctx *MulDivModContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Identifier.
|
||||
VisitIdentifier(ctx *IdentifierContext) interface{}
|
||||
|
||||
@ -43,66 +25,33 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#Like.
|
||||
VisitLike(ctx *LikeContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#LogicalAnd.
|
||||
VisitLogicalAnd(ctx *LogicalAndContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TemplateVariable.
|
||||
VisitTemplateVariable(ctx *TemplateVariableContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Equality.
|
||||
VisitEquality(ctx *EqualityContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Boolean.
|
||||
VisitBoolean(ctx *BooleanContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TimestamptzCompareReverse.
|
||||
VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STDWithin.
|
||||
VisitSTDWithin(ctx *STDWithinContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Shift.
|
||||
VisitShift(ctx *ShiftContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TimestamptzCompareForward.
|
||||
VisitTimestamptzCompareForward(ctx *TimestamptzCompareForwardContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Call.
|
||||
VisitCall(ctx *CallContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STCrosses.
|
||||
VisitSTCrosses(ctx *STCrossesContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#ReverseRange.
|
||||
VisitReverseRange(ctx *ReverseRangeContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#BitOr.
|
||||
VisitBitOr(ctx *BitOrContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#EmptyArray.
|
||||
VisitEmptyArray(ctx *EmptyArrayContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#AddSub.
|
||||
VisitAddSub(ctx *AddSubContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#PhraseMatch.
|
||||
VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Relational.
|
||||
VisitRelational(ctx *RelationalContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#ArrayLength.
|
||||
VisitArrayLength(ctx *ArrayLengthContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TextMatch.
|
||||
VisitTextMatch(ctx *TextMatchContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STTouches.
|
||||
VisitSTTouches(ctx *STTouchesContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STContains.
|
||||
VisitSTContains(ctx *STContainsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Term.
|
||||
VisitTerm(ctx *TermContext) interface{}
|
||||
|
||||
@ -115,6 +64,72 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#Range.
|
||||
VisitRange(ctx *RangeContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STIsValid.
|
||||
VisitSTIsValid(ctx *STIsValidContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#BitXor.
|
||||
VisitBitXor(ctx *BitXorContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#ElementFilter.
|
||||
VisitElementFilter(ctx *ElementFilterContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#BitAnd.
|
||||
VisitBitAnd(ctx *BitAndContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STOverlaps.
|
||||
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#JSONIdentifier.
|
||||
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#RandomSample.
|
||||
VisitRandomSample(ctx *RandomSampleContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Parens.
|
||||
VisitParens(ctx *ParensContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#JSONContainsAll.
|
||||
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#LogicalOr.
|
||||
VisitLogicalOr(ctx *LogicalOrContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MulDivMod.
|
||||
VisitMulDivMod(ctx *MulDivModContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#LogicalAnd.
|
||||
VisitLogicalAnd(ctx *LogicalAndContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TemplateVariable.
|
||||
VisitTemplateVariable(ctx *TemplateVariableContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TimestamptzCompareReverse.
|
||||
VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STDWithin.
|
||||
VisitSTDWithin(ctx *STDWithinContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Call.
|
||||
VisitCall(ctx *CallContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STCrosses.
|
||||
VisitSTCrosses(ctx *STCrossesContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#BitOr.
|
||||
VisitBitOr(ctx *BitOrContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#AddSub.
|
||||
VisitAddSub(ctx *AddSubContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Relational.
|
||||
VisitRelational(ctx *RelationalContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#TextMatch.
|
||||
VisitTextMatch(ctx *TextMatchContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STContains.
|
||||
VisitSTContains(ctx *STContainsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Unary.
|
||||
VisitUnary(ctx *UnaryContext) interface{}
|
||||
|
||||
@ -127,30 +142,21 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#JSONContainsAny.
|
||||
VisitJSONContainsAny(ctx *JSONContainsAnyContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STIsValid.
|
||||
VisitSTIsValid(ctx *STIsValidContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#BitXor.
|
||||
VisitBitXor(ctx *BitXorContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Exists.
|
||||
VisitExists(ctx *ExistsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#BitAnd.
|
||||
VisitBitAnd(ctx *BitAndContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STEuqals.
|
||||
VisitSTEuqals(ctx *STEuqalsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#IsNull.
|
||||
VisitIsNull(ctx *IsNullContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#StructSubField.
|
||||
VisitStructSubField(ctx *StructSubFieldContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Power.
|
||||
VisitPower(ctx *PowerContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STOverlaps.
|
||||
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#textMatchOption.
|
||||
VisitTextMatchOption(ctx *TextMatchOptionContext) interface{}
|
||||
}
|
||||
|
||||
@ -42,6 +42,8 @@ type ParserVisitor struct {
|
||||
parser.BasePlanVisitor
|
||||
schema *typeutil.SchemaHelper
|
||||
args *ParserVisitorArgs
|
||||
// currentStructArrayField stores the struct array field name when processing ElementFilter
|
||||
currentStructArrayField string
|
||||
}
|
||||
|
||||
func NewParserVisitor(schema *typeutil.SchemaHelper, args *ParserVisitorArgs) *ParserVisitor {
|
||||
@ -658,6 +660,10 @@ func isRandomSampleExpr(expr *ExprWithType) bool {
|
||||
return expr.expr.GetRandomSampleExpr() != nil
|
||||
}
|
||||
|
||||
func isElementFilterExpr(expr *ExprWithType) bool {
|
||||
return expr.expr.GetElementFilterExpr() != nil
|
||||
}
|
||||
|
||||
const EPSILON = 1e-10
|
||||
|
||||
func (v *ParserVisitor) VisitRandomSample(ctx *parser.RandomSampleContext) interface{} {
|
||||
@ -773,7 +779,47 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode) (*planpb.ColumnInfo, error) {
|
||||
func isValidStructSubField(tokenText string) bool {
|
||||
return len(tokenText) >= 4 && tokenText[:2] == "$[" && tokenText[len(tokenText)-1] == ']'
|
||||
}
|
||||
|
||||
func (v *ParserVisitor) getColumnInfoFromStructSubField(tokenText string) (*planpb.ColumnInfo, error) {
|
||||
if !isValidStructSubField(tokenText) {
|
||||
return nil, fmt.Errorf("invalid struct sub-field syntax: %s", tokenText)
|
||||
}
|
||||
// Remove "$[" prefix and "]" suffix
|
||||
fieldName := tokenText[2 : len(tokenText)-1]
|
||||
|
||||
// Check if we're inside an ElementFilter context
|
||||
if v.currentStructArrayField == "" {
|
||||
return nil, fmt.Errorf("$[%s] syntax can only be used inside ElementFilter", fieldName)
|
||||
}
|
||||
|
||||
// Construct full field name for struct array field
|
||||
fullFieldName := v.currentStructArrayField + "[" + fieldName + "]"
|
||||
// Get the struct array field info
|
||||
field, err := v.schema.GetFieldFromName(fullFieldName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("array field not found: %s, error: %s", fullFieldName, err)
|
||||
}
|
||||
|
||||
// In element-level context, data_type should be the element type
|
||||
elementType := field.GetElementType()
|
||||
|
||||
return &planpb.ColumnInfo{
|
||||
FieldId: field.FieldID,
|
||||
DataType: elementType, // Use element type, not storage type
|
||||
IsPrimaryKey: field.IsPrimaryKey,
|
||||
IsAutoID: field.AutoID,
|
||||
IsPartitionKey: field.IsPartitionKey,
|
||||
IsClusteringKey: field.IsClusteringKey,
|
||||
ElementType: elementType,
|
||||
Nullable: field.GetNullable(),
|
||||
IsElementLevel: true, // Mark as element-level access
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *ParserVisitor) getChildColumnInfo(identifier, child, structSubField antlr.TerminalNode) (*planpb.ColumnInfo, error) {
|
||||
if identifier != nil {
|
||||
childExpr, err := v.translateIdentifier(identifier.GetText())
|
||||
if err != nil {
|
||||
@ -782,6 +828,10 @@ func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode)
|
||||
return toColumnInfo(childExpr), nil
|
||||
}
|
||||
|
||||
if structSubField != nil {
|
||||
return v.getColumnInfoFromStructSubField(structSubField.GetText())
|
||||
}
|
||||
|
||||
return v.getColumnInfoFromJSONIdentifier(child.GetText())
|
||||
}
|
||||
|
||||
@ -812,7 +862,7 @@ func (v *ParserVisitor) VisitCall(ctx *parser.CallContext) interface{} {
|
||||
|
||||
// VisitRange translates expr to range plan.
|
||||
func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} {
|
||||
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
|
||||
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), ctx.StructSubFieldIdentifier())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -893,7 +943,7 @@ func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} {
|
||||
|
||||
// VisitReverseRange parses the expression like "1 > a > 0".
|
||||
func (v *ParserVisitor) VisitReverseRange(ctx *parser.ReverseRangeContext) interface{} {
|
||||
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
|
||||
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), ctx.StructSubFieldIdentifier())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1081,6 +1131,10 @@ func (v *ParserVisitor) VisitLogicalOr(ctx *parser.LogicalOrContext) interface{}
|
||||
return errors.New("random sample expression cannot be used in logical and expression")
|
||||
}
|
||||
|
||||
if isElementFilterExpr(leftExpr) {
|
||||
return errors.New("element filter expression can only be the last expression in the logical or expression")
|
||||
}
|
||||
|
||||
if !canBeExecuted(leftExpr) || !canBeExecuted(rightExpr) {
|
||||
return errors.New("'or' can only be used between boolean expressions")
|
||||
}
|
||||
@ -1133,6 +1187,10 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface
|
||||
return errors.New("random sample expression can only be the last expression in the logical and expression")
|
||||
}
|
||||
|
||||
if isElementFilterExpr(leftExpr) {
|
||||
return errors.New("element filter expression can only be the last expression in the logical and expression")
|
||||
}
|
||||
|
||||
if !canBeExecuted(leftExpr) || !canBeExecuted(rightExpr) {
|
||||
return errors.New("'and' can only be used between boolean expressions")
|
||||
}
|
||||
@ -1146,6 +1204,15 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface
|
||||
RandomSampleExpr: randomSampleExpr,
|
||||
},
|
||||
}
|
||||
} else if isElementFilterExpr(rightExpr) {
|
||||
// Similar to RandomSampleExpr, extract doc-level predicate
|
||||
elementFilterExpr := rightExpr.expr.GetElementFilterExpr()
|
||||
elementFilterExpr.Predicate = leftExpr.expr
|
||||
expr = &planpb.Expr{
|
||||
Expr: &planpb.Expr_ElementFilterExpr{
|
||||
ElementFilterExpr: elementFilterExpr,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
expr = &planpb.Expr{
|
||||
Expr: &planpb.Expr_BinaryExpr{
|
||||
@ -1410,7 +1477,7 @@ func (v *ParserVisitor) VisitEmptyArray(ctx *parser.EmptyArrayContext) interface
|
||||
}
|
||||
|
||||
func (v *ParserVisitor) VisitIsNotNull(ctx *parser.IsNotNullContext) interface{} {
|
||||
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
|
||||
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1450,7 +1517,7 @@ func (v *ParserVisitor) VisitIsNotNull(ctx *parser.IsNotNullContext) interface{}
|
||||
}
|
||||
|
||||
func (v *ParserVisitor) VisitIsNull(ctx *parser.IsNullContext) interface{} {
|
||||
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
|
||||
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1653,7 +1720,7 @@ func (v *ParserVisitor) VisitJSONContainsAny(ctx *parser.JSONContainsAnyContext)
|
||||
}
|
||||
|
||||
func (v *ParserVisitor) VisitArrayLength(ctx *parser.ArrayLengthContext) interface{} {
|
||||
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
|
||||
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -2182,3 +2249,90 @@ func validateAndExtractMinShouldMatch(minShouldMatchExpr interface{}) ([]*planpb
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// VisitElementFilter handles ElementFilter(structArrayField, elementExpr) syntax.
|
||||
func (v *ParserVisitor) VisitElementFilter(ctx *parser.ElementFilterContext) interface{} {
|
||||
// Check for nested ElementFilter - not allowed
|
||||
if v.currentStructArrayField != "" {
|
||||
return fmt.Errorf("nested ElementFilter is not supported, already inside ElementFilter for field: %s", v.currentStructArrayField)
|
||||
}
|
||||
|
||||
// Get struct array field name (first parameter)
|
||||
arrayFieldName := ctx.Identifier().GetText()
|
||||
|
||||
// Set current context for element expression parsing
|
||||
v.currentStructArrayField = arrayFieldName
|
||||
defer func() { v.currentStructArrayField = "" }()
|
||||
|
||||
elementExpr := ctx.Expr().Accept(v)
|
||||
if err := getError(elementExpr); err != nil {
|
||||
return fmt.Errorf("cannot parse element expression: %s, error: %s", ctx.Expr().GetText(), err)
|
||||
}
|
||||
|
||||
exprWithType := getExpr(elementExpr)
|
||||
if exprWithType == nil {
|
||||
return fmt.Errorf("invalid element expression: %s", ctx.Expr().GetText())
|
||||
}
|
||||
|
||||
// Build ElementFilterExpr proto
|
||||
return &ExprWithType{
|
||||
expr: &planpb.Expr{
|
||||
Expr: &planpb.Expr_ElementFilterExpr{
|
||||
ElementFilterExpr: &planpb.ElementFilterExpr{
|
||||
ElementExpr: exprWithType.expr,
|
||||
StructName: arrayFieldName,
|
||||
},
|
||||
},
|
||||
},
|
||||
dataType: schemapb.DataType_Bool,
|
||||
}
|
||||
}
|
||||
|
||||
// VisitStructSubField handles $[fieldName] syntax within ElementFilter.
|
||||
func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) interface{} {
|
||||
// Extract the field name from $[fieldName]
|
||||
tokenText := ctx.StructSubFieldIdentifier().GetText()
|
||||
if !isValidStructSubField(tokenText) {
|
||||
return fmt.Errorf("invalid struct sub-field syntax: %s", tokenText)
|
||||
}
|
||||
// Remove "$[" prefix and "]" suffix
|
||||
fieldName := tokenText[2 : len(tokenText)-1]
|
||||
|
||||
// Check if we're inside an ElementFilter context
|
||||
if v.currentStructArrayField == "" {
|
||||
return fmt.Errorf("$[%s] syntax can only be used inside ElementFilter", fieldName)
|
||||
}
|
||||
|
||||
// Construct full field name for struct array field
|
||||
fullFieldName := v.currentStructArrayField + "[" + fieldName + "]"
|
||||
// Get the struct array field info
|
||||
field, err := v.schema.GetFieldFromName(fullFieldName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("array field not found: %s, error: %s", fullFieldName, err)
|
||||
}
|
||||
|
||||
// In element-level context, data_type should be the element type
|
||||
elementType := field.GetElementType()
|
||||
|
||||
return &ExprWithType{
|
||||
expr: &planpb.Expr{
|
||||
Expr: &planpb.Expr_ColumnExpr{
|
||||
ColumnExpr: &planpb.ColumnExpr{
|
||||
Info: &planpb.ColumnInfo{
|
||||
FieldId: field.FieldID,
|
||||
DataType: elementType, // Use element type, not storage type
|
||||
IsPrimaryKey: field.IsPrimaryKey,
|
||||
IsAutoID: field.AutoID,
|
||||
IsPartitionKey: field.IsPartitionKey,
|
||||
IsClusteringKey: field.IsClusteringKey,
|
||||
ElementType: elementType,
|
||||
Nullable: field.GetNullable(),
|
||||
IsElementLevel: true, // Mark as element-level access
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
dataType: elementType, // Expression evaluates to element type
|
||||
nodeDependent: true,
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,11 +48,27 @@ func newTestSchema(EnableDynamicField bool) *schemapb.CollectionSchema {
|
||||
ElementType: schemapb.DataType_VarChar,
|
||||
})
|
||||
|
||||
structArrayField := &schemapb.StructArrayFieldSchema{
|
||||
FieldID: 132, Name: "struct_array", Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 133, Name: "struct_array[sub_str]", IsPrimaryKey: false, Description: "sub struct array field for string",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_VarChar,
|
||||
},
|
||||
{
|
||||
FieldID: 134, Name: "struct_array[sub_int]", IsPrimaryKey: false, Description: "sub struct array field for int",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Description: "schema for test used",
|
||||
AutoID: true,
|
||||
Fields: fields,
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{structArrayField},
|
||||
EnableDynamicField: EnableDynamicField,
|
||||
}
|
||||
}
|
||||
@ -2487,3 +2503,57 @@ func TestExpr_GISFunctionsInvalidParameterTypes(t *testing.T) {
|
||||
assertInvalidExpr(t, schema, expr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpr_ElementFilter(t *testing.T) {
|
||||
schema := newTestSchema(true)
|
||||
helper, err := typeutil.CreateSchemaHelper(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Valid expressions
|
||||
validExprs := []string{
|
||||
`element_filter(struct_array, 2 > $[sub_int] > 1)`,
|
||||
`element_filter(struct_array, $[sub_int] > 1)`,
|
||||
`element_filter(struct_array, $[sub_int] == 100)`,
|
||||
`element_filter(struct_array, $[sub_int] >= 0)`,
|
||||
`element_filter(struct_array, $[sub_int] <= 1000)`,
|
||||
`element_filter(struct_array, $[sub_int] != 0)`,
|
||||
|
||||
`element_filter(struct_array, $[sub_str] == "1")`,
|
||||
`element_filter(struct_array, $[sub_str] != "")`,
|
||||
|
||||
`element_filter(struct_array, $[sub_str] == "1" || $[sub_int] > 1)`,
|
||||
`element_filter(struct_array, $[sub_str] == "1" && $[sub_int] > 1)`,
|
||||
`element_filter(struct_array, $[sub_int] > 0 && $[sub_int] < 100)`,
|
||||
|
||||
`element_filter(struct_array, ($[sub_int] > 0 && $[sub_int] < 100) || $[sub_str] == "default")`,
|
||||
`element_filter(struct_array, !($[sub_int] < 0))`,
|
||||
|
||||
`Int64Field > 0 && element_filter(struct_array, $[sub_int] > 1)`,
|
||||
}
|
||||
|
||||
for _, expr := range validExprs {
|
||||
assertValidExpr(t, helper, expr)
|
||||
}
|
||||
|
||||
// Invalid expressions
|
||||
invalidExprs := []string{
|
||||
`element_filter(struct_array, element_filter(struct_array, $[sub_int] > 1))`,
|
||||
`element_filter(struct_array, $[sub_int] > 1 && element_filter(struct_array, $[sub_str] == "1"))`,
|
||||
|
||||
`$[sub_int] > 1`,
|
||||
`Int64Field > 0 && $[sub_int] > 1`,
|
||||
|
||||
`element_filter(struct_array, $[non_existent_field] > 1)`,
|
||||
`element_filter(non_existent_array, $[sub_int] > 1)`,
|
||||
|
||||
`element_filter(struct_array)`, // missing element expression
|
||||
`element_filter()`, // missing all parameters
|
||||
|
||||
`element_filter(struct_array, $[sub_int] > 1) || element_filter(struct_array, $[sub_str] == "test")`,
|
||||
`element_filter(struct_array, $[sub_int] > 1) && Int64Field > 0`,
|
||||
}
|
||||
|
||||
for _, expr := range invalidExprs {
|
||||
assertInvalidExpr(t, helper, expr)
|
||||
}
|
||||
}
|
||||
|
||||
@ -150,6 +150,18 @@ func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.S
|
||||
gpFieldBuilder.Add(groupByVal)
|
||||
typeutil.AppendPKs(ret.Results.Ids, pk)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
|
||||
// Handle ElementIndices if present
|
||||
if subData.ElementIndices != nil {
|
||||
if ret.Results.ElementIndices == nil {
|
||||
ret.Results.ElementIndices = &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
}
|
||||
}
|
||||
elemIdx := subData.ElementIndices.GetData()[innerIdx]
|
||||
ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
|
||||
}
|
||||
|
||||
dataCount += 1
|
||||
}
|
||||
}
|
||||
@ -308,6 +320,18 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
||||
}
|
||||
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
|
||||
|
||||
// Handle ElementIndices if present
|
||||
if subResData.ElementIndices != nil {
|
||||
if ret.Results.ElementIndices == nil {
|
||||
ret.Results.ElementIndices = &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
}
|
||||
}
|
||||
elemIdx := subResData.ElementIndices.GetData()[groupEntity.resultIdx]
|
||||
ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
|
||||
}
|
||||
|
||||
gpFieldBuilder.Add(groupVal)
|
||||
}
|
||||
}
|
||||
@ -436,6 +460,18 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
|
||||
}
|
||||
typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
|
||||
// Handle ElementIndices if present
|
||||
if subSearchResultData[subSearchIdx].ElementIndices != nil {
|
||||
if ret.Results.ElementIndices == nil {
|
||||
ret.Results.ElementIndices = &schemapb.LongArray{
|
||||
Data: make([]int64, 0, limit),
|
||||
}
|
||||
}
|
||||
elemIdx := subSearchResultData[subSearchIdx].ElementIndices.GetData()[resultDataIdx]
|
||||
ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
|
||||
}
|
||||
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
if realTopK != -1 && realTopK != j {
|
||||
|
||||
@ -421,10 +421,22 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
|
||||
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType)
|
||||
}
|
||||
} else if typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) {
|
||||
// TODO(SpadeA): adjust it when more metric types are supported. Especially, when different metric types
|
||||
// are supported for different element types.
|
||||
if !funcutil.SliceContain(indexparamcheck.EmbListMetrics, metricType) {
|
||||
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType)
|
||||
if typeutil.IsDenseFloatVectorType(cit.fieldSchema.ElementType) {
|
||||
if !funcutil.SliceContain(indexparamcheck.FloatVectorMetrics, metricType) {
|
||||
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector with float element type does not support metric type: "+metricType)
|
||||
}
|
||||
} else if typeutil.IsBinaryVectorType(cit.fieldSchema.ElementType) {
|
||||
if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) {
|
||||
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector with binary element type does not support metric type: "+metricType)
|
||||
}
|
||||
} else if typeutil.IsIntVectorType(cit.fieldSchema.ElementType) {
|
||||
if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) {
|
||||
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector with int element type does not support metric type: "+metricType)
|
||||
}
|
||||
} else {
|
||||
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1229,37 +1229,6 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("metrics wrong for embedding list index", func(t *testing.T) {
|
||||
cit := &createIndexTask{
|
||||
Condition: nil,
|
||||
req: &milvuspb.CreateIndexRequest{
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.IndexTypeKey,
|
||||
Value: "HNSW",
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metric.L2,
|
||||
},
|
||||
},
|
||||
IndexName: "",
|
||||
},
|
||||
fieldSchema: &schemapb.FieldSchema{
|
||||
FieldID: 101,
|
||||
Name: "EmbListFloat",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_ArrayOfVector,
|
||||
ElementType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: common.DimKey, Value: "128"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := cit.parseIndexParams(context.TODO())
|
||||
assert.True(t, strings.Contains(err.Error(), "array of vector index does not support metric type: L2"))
|
||||
})
|
||||
|
||||
t.Run("metric type wrong", func(t *testing.T) {
|
||||
cit := &createIndexTask{
|
||||
Condition: nil,
|
||||
|
||||
@ -45,6 +45,20 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
||||
Topks: make([]int64, 0),
|
||||
}
|
||||
|
||||
// Check element-level consistency: all results must have ElementIndices or none
|
||||
hasElementIndices := searchResultData[0].ElementIndices != nil
|
||||
for i, data := range searchResultData {
|
||||
if (data.ElementIndices != nil) != hasElementIndices {
|
||||
return nil, fmt.Errorf("inconsistent element-level flag in search results: result[0] has ElementIndices=%v, but result[%d] has ElementIndices=%v",
|
||||
hasElementIndices, i, data.ElementIndices != nil)
|
||||
}
|
||||
}
|
||||
if hasElementIndices {
|
||||
ret.ElementIndices = &schemapb.LongArray{
|
||||
Data: make([]int64, 0),
|
||||
}
|
||||
}
|
||||
|
||||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
@ -76,6 +90,9 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
|
||||
ret.ElementIndices.Data = append(ret.ElementIndices.Data, searchResultData[sel].ElementIndices.Data[idx])
|
||||
}
|
||||
idSet[id] = struct{}{}
|
||||
j++
|
||||
} else {
|
||||
@ -127,6 +144,20 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
||||
Topks: make([]int64, 0),
|
||||
}
|
||||
|
||||
// Check element-level consistency: all results must have ElementIndices or none
|
||||
hasElementIndices := searchResultData[0].ElementIndices != nil
|
||||
for i, data := range searchResultData {
|
||||
if (data.ElementIndices != nil) != hasElementIndices {
|
||||
return nil, fmt.Errorf("inconsistent element-level flag in search results: result[0] has ElementIndices=%v, but result[%d] has ElementIndices=%v",
|
||||
hasElementIndices, i, data.ElementIndices != nil)
|
||||
}
|
||||
}
|
||||
if hasElementIndices {
|
||||
ret.ElementIndices = &schemapb.LongArray{
|
||||
Data: make([]int64, 0),
|
||||
}
|
||||
}
|
||||
|
||||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
groupByValIterator := make([]func(int) any, len(searchResultData))
|
||||
for i := range searchResultData {
|
||||
@ -180,6 +211,9 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
|
||||
ret.ElementIndices.Data = append(ret.ElementIndices.Data, searchResultData[sel].ElementIndices.Data[idx])
|
||||
}
|
||||
gpFieldBuilder.Add(groupByVal)
|
||||
groupByValueMap[groupByVal] += 1
|
||||
idSet[id] = struct{}{}
|
||||
|
||||
@ -2120,8 +2120,18 @@ func estimateLoadingResourceUsageOfSegment(schema *schemapb.CollectionSchema, lo
|
||||
}
|
||||
}
|
||||
|
||||
// per struct memory size, used to keep mapping between row id and element id
|
||||
var structArrayOffsetsSize uint64
|
||||
// PART 6: calculate size of struct array offsets
|
||||
// The memory size is 4 * row_count + 4 * total_element_count
|
||||
// We cannot easily get the element count, so we estimate it by the row count * 10
|
||||
rowCount := uint64(loadInfo.GetNumOfRows())
|
||||
for range len(schema.GetStructArrayFields()) {
|
||||
structArrayOffsetsSize += 4*rowCount + 4*rowCount*10
|
||||
}
|
||||
|
||||
return &ResourceUsage{
|
||||
MemorySize: segMemoryLoadingSize + indexMemorySize,
|
||||
MemorySize: segMemoryLoadingSize + indexMemorySize + structArrayOffsetsSize,
|
||||
DiskSize: segDiskLoadingSize,
|
||||
MmapFieldCount: mmapFieldCount,
|
||||
FieldGpuMemorySize: fieldGpuMemorySize,
|
||||
|
||||
@ -59,7 +59,21 @@ func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, elementType sch
|
||||
}
|
||||
} else if typeutil.IsArrayOfVectorType(dataType) {
|
||||
if !CheckStrByValues(params, Metric, EmbListMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], EmbListMetrics)
|
||||
if typeutil.IsDenseFloatVectorType(elementType) {
|
||||
if !CheckStrByValues(params, Metric, FloatVectorMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported for array of vector with float element type, supported: %v", params[Metric], FloatVectorMetrics)
|
||||
}
|
||||
} else if typeutil.IsBinaryVectorType(elementType) {
|
||||
if !CheckStrByValues(params, Metric, BinaryVectorMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported for array of vector with binary element type, supported: %v", params[Metric], BinaryVectorMetrics)
|
||||
}
|
||||
} else if typeutil.IsIntVectorType(elementType) {
|
||||
if !CheckStrByValues(params, Metric, IntVectorMetrics) {
|
||||
return fmt.Errorf("metric type %s not found or not supported for array of vector with int element type, supported: %v", params[Metric], IntVectorMetrics)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("metric type %s not found or not supported for array of vector, supported: %v", params[Metric], EmbListMetrics)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -98,6 +98,7 @@ message ColumnInfo {
|
||||
schema.DataType element_type = 7;
|
||||
bool is_clustering_key = 8;
|
||||
bool nullable = 9;
|
||||
bool is_element_level = 10;
|
||||
}
|
||||
|
||||
message ColumnExpr {
|
||||
@ -245,6 +246,12 @@ message RandomSampleExpr {
|
||||
Expr predicate = 2;
|
||||
}
|
||||
|
||||
message ElementFilterExpr {
|
||||
Expr element_expr = 1;
|
||||
string struct_name = 2;
|
||||
Expr predicate = 3;
|
||||
}
|
||||
|
||||
message AlwaysTrueExpr {}
|
||||
|
||||
message Interval {
|
||||
@ -285,6 +292,7 @@ message Expr {
|
||||
RandomSampleExpr random_sample_expr = 16;
|
||||
GISFunctionFilterExpr gisfunction_filter_expr = 17;
|
||||
TimestamptzArithCompareExpr timestamptz_arith_compare_expr = 18;
|
||||
ElementFilterExpr element_filter_expr = 19;
|
||||
};
|
||||
bool is_template = 20;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user