enhance: support nullable group by keys (#41313)

See #36264

---------

Signed-off-by: Ted Xu <ted.xu@zilliz.com>
Ted Xu 2025-04-18 10:08:34 +08:00 committed by GitHub
parent 62293cb582
commit d50781c8cc
11 changed files with 144 additions and 112 deletions

View File

@ -176,10 +176,10 @@ struct SearchResult {
chunk_count, total_rows_until_chunk);
vector_iterators.emplace_back(vector_iterator);
}
auto kw_iterator = kw_iterators[i];
const auto& kw_iterator = kw_iterators[i];
vector_iterators[vec_iter_idx++]->AddIterator(kw_iterator);
}
for (auto vector_iter : vector_iterators) {
for (const auto& vector_iter : vector_iterators) {
vector_iter->seal();
}
this->vector_iterators_ = vector_iterators;
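Both loops previously copied a shared_ptr on every iteration, which costs an atomic reference-count increment/decrement per element; binding a const reference only aliases the element. A minimal standalone sketch of the pattern (the iterator type here is illustrative, not the Milvus one):

```cpp
#include <memory>
#include <vector>

struct VectorIterator {
    void seal() {}
};

void SealAll(const std::vector<std::shared_ptr<VectorIterator>>& iterators) {
    // `for (auto it : iterators)` would copy each shared_ptr, paying an
    // atomic increment/decrement per element; const auto& avoids that.
    for (const auto& iter : iterators) {
        iter->seal();
    }
}
```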

View File

@ -263,13 +263,13 @@ CalcPksSize(const PkType* data, size_t n) {
return size;
}
using GroupByValueType = std::variant<std::monostate,
int8_t,
int16_t,
int32_t,
int64_t,
bool,
std::string>;
using GroupByValueType = std::optional<std::variant<std::monostate,
int8_t,
int16_t,
int32_t,
int64_t,
bool,
std::string>>;
using ContainsType = proto::plan::JSONContainsExpr_JSONOp;
using NullExprType = proto::plan::NullExpr_NullOp;
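With the std::optional wrapper, a NULL group-by key is representable as std::nullopt, while std::monostate keeps its prior role as the empty placeholder inside the variant. A hedged sketch of how a caller can probe such a value (the alias and function below are simplified stand-ins, not the Milvus headers):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <variant>

// Simplified stand-in for the GroupByValueType alias above.
using GroupByValue = std::optional<std::variant<std::monostate, int8_t,
                                                int16_t, int32_t, int64_t,
                                                bool, std::string>>;

void Describe(const GroupByValue& v) {
    if (!v.has_value()) {
        std::cout << "null group key\n";  // the row's group-by field was NULL
    } else if (std::holds_alternative<int64_t>(*v)) {
        std::cout << "int64 key: " << std::get<int64_t>(*v) << '\n';
    } else {
        std::cout << "some other alternative\n";
    }
}

int main() {
    Describe(std::nullopt);               // the nullable path this patch adds
    Describe(GroupByValue{int64_t{42}});  // an ordinary concrete key
}
```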

View File

@ -180,7 +180,7 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
//2. iterate until the map is filled or all data is exhausted;
//note this may enumerate all data inside a segment and can block
//subsequent query and search operations
std::vector<std::tuple<int64_t, float, T>> res;
std::vector<std::tuple<int64_t, float, std::optional<T>>> res;
while (iterator->HasNext() && !groupMap.IsGroupResEnough()) {
auto offset_dis_pair = iterator->Next();
AssertInfo(
@ -189,7 +189,7 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
"tells hasNext, terminate groupBy operation");
auto offset = offset_dis_pair.value().first;
auto dis = offset_dis_pair.value().second;
T row_data = data_getter.Get(offset);
std::optional<T> row_data = data_getter.Get(offset);
if (groupMap.Push(row_data)) {
res.emplace_back(offset, dis, row_data);
}

View File

@ -16,7 +16,9 @@
#pragma once
#include <optional>
#include "common/QueryInfo.h"
#include "common/Types.h"
#include "knowhere/index/index_node.h"
#include "segcore/SegmentInterface.h"
#include "segcore/SegmentGrowingImpl.h"
@ -31,25 +33,30 @@ namespace exec {
template <typename T>
class DataGetter {
public:
virtual T
virtual std::optional<T>
Get(int64_t idx) const = 0;
};
template <typename T>
class GrowingDataGetter : public DataGetter<T> {
public:
const segcore::ConcurrentVector<T>* growing_raw_data_;
GrowingDataGetter(const segcore::SegmentGrowingImpl& segment,
FieldId fieldId) {
growing_raw_data_ = segment.get_insert_record().get_data<T>(fieldId);
valid_data_ = segment.get_insert_record().is_valid_data_exist(fieldId)
? segment.get_insert_record().get_valid_data(fieldId)
: nullptr;
}
GrowingDataGetter(const GrowingDataGetter<T>& other)
: growing_raw_data_(other.growing_raw_data_),
valid_data_(other.valid_data_) {
}
T
std::optional<T>
Get(int64_t idx) const {
if (valid_data_ && !valid_data_->is_valid(idx)) {
return std::nullopt;
}
if constexpr (std::is_same_v<std::string, T>) {
if (growing_raw_data_->is_mmap()) {
// when scalar data is mmap-ed, fetch the scalar data view and reconstruct the string from it
@ -58,6 +65,10 @@ class GrowingDataGetter : public DataGetter<T> {
}
return growing_raw_data_->operator[](idx);
}
protected:
const segcore::ConcurrentVector<T>* growing_raw_data_;
segcore::ThreadSafeValidDataPtr valid_data_;
};
template <typename T>
@ -69,6 +80,7 @@ class SealedDataGetter : public DataGetter<T> {
mutable std::unordered_map<int64_t, std::vector<std::string_view>>
str_view_map_;
mutable std::unordered_map<int64_t, FixedVector<bool>> valid_map_;
// getting str_view from the segment is CPU-costly; this map caches the views for performance
public:
SealedDataGetter(const segcore::SegmentSealed& segment, FieldId& field_id)
@ -84,7 +96,7 @@ class SealedDataGetter : public DataGetter<T> {
}
}
T
std::optional<T>
Get(int64_t idx) const {
if (from_data_) {
auto id_offset_pair = segment_.get_chunk_by_offset(field_id_, idx);
@ -92,22 +104,30 @@ class SealedDataGetter : public DataGetter<T> {
auto inner_offset = id_offset_pair.second;
if constexpr (std::is_same_v<T, std::string>) {
if (str_view_map_.find(chunk_id) == str_view_map_.end()) {
// for now, search_group_by does not handle null values
auto [str_chunk_view, _] =
auto [str_chunk_view, valid_data] =
segment_.chunk_view<std::string_view>(field_id_,
chunk_id);
valid_map_[chunk_id] = std::move(valid_data);
str_view_map_[chunk_id] = std::move(str_chunk_view);
}
auto& str_chunk_view = str_view_map_[chunk_id];
std::string_view str_val_view =
str_chunk_view.operator[](inner_offset);
auto& valid_data = valid_map_[chunk_id];
if (!valid_data.empty()) {
if (!valid_map_[chunk_id][inner_offset]) {
return std::nullopt;
}
}
auto str_val_view = str_view_map_[chunk_id][inner_offset];
return std::string(str_val_view.data(), str_val_view.length());
} else {
Span<T> span = segment_.chunk_data<T>(field_id_, chunk_id);
if (span.valid_data() && !span.valid_data()[inner_offset]) {
return std::nullopt;
}
auto raw = span.operator[](inner_offset);
return raw;
}
} else {
// null is not supported for indexed fields
auto& chunk_index = segment_.chunk_scalar_index<T>(field_id_, 0);
auto raw = chunk_index.Reverse_Lookup(idx);
AssertInfo(raw.has_value(), "field data not found");
@ -160,7 +180,7 @@ GroupIteratorsByType(
template <typename T>
struct GroupByMap {
private:
std::unordered_map<T, int> group_map_{};
std::unordered_map<std::optional<T>, int> group_map_{};
int group_capacity_{0};
int group_size_{0};
int enough_group_count_{0};
@ -185,7 +205,7 @@ struct GroupByMap {
return enough;
}
bool
Push(const T& t) {
Push(const std::optional<T>& t) {
if (group_map_.size() >= group_capacity_ &&
group_map_.find(t) == group_map_.end()) {
return false;
@ -211,7 +231,6 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t group_size,
bool strict_group_size,
const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets,
std::vector<float>& distances,
const knowhere::MetricType& metrics_type);
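std::hash is specialized for std::optional since C++17 (an engaged optional hashes like its contained value), which is what lets GroupByMap key its unordered_map on std::optional<T>: std::nullopt simply becomes one more group alongside the concrete values. A self-contained illustration of that behavior:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <unordered_map>
#include <vector>

int main() {
    // nullopt participates as an ordinary key, so NULL rows form their
    // own group next to the concrete values.
    std::unordered_map<std::optional<int8_t>, int> groups;
    std::vector<std::optional<int8_t>> keys = {1, std::nullopt, 1,
                                               std::nullopt, 2};
    for (const auto& k : keys) {
        ++groups[k];
    }
    std::cout << "distinct groups: " << groups.size() << '\n';        // 3
    std::cout << "null-key count: " << groups[std::nullopt] << '\n';  // 2
}
```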

View File

@ -14,6 +14,8 @@
//
#include "ReduceUtils.h"
#include "google/protobuf/repeated_field.h"
#include "pb/schema.pb.h"
namespace milvus::segcore {
@ -26,9 +28,12 @@ AssembleGroupByValues(
if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
auto group_by_values_field =
std::make_unique<milvus::proto::schema::ScalarField>();
auto valid_data =
std::make_unique<google::protobuf::RepeatedField<bool>>();
valid_data->Resize(group_by_vals.size(), true);
auto group_by_field =
plan->schema_.operator[](group_by_field_id.value());
DataType group_by_data_type = group_by_field.get_data_type();
auto group_by_data_type = group_by_field.get_data_type();
int group_by_val_size = group_by_vals.size();
switch (group_by_data_type) {
@ -36,8 +41,13 @@ AssembleGroupByValues(
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int8_t val = std::get<int8_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
if (group_by_vals[idx].has_value()) {
int8_t val =
std::get<int8_t>(group_by_vals[idx].value());
field_data->mutable_data()->Set(idx, val);
} else {
valid_data->Set(idx, false);
}
}
break;
}
@ -45,8 +55,13 @@ AssembleGroupByValues(
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int16_t val = std::get<int16_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
if (group_by_vals[idx].has_value()) {
int16_t val =
std::get<int16_t>(group_by_vals[idx].value());
field_data->mutable_data()->Set(idx, val);
} else {
valid_data->Set(idx, false);
}
}
break;
}
@ -54,8 +69,13 @@ AssembleGroupByValues(
auto field_data = group_by_values_field->mutable_int_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int32_t val = std::get<int32_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
if (group_by_vals[idx].has_value()) {
int32_t val =
std::get<int32_t>(group_by_vals[idx].value());
field_data->mutable_data()->Set(idx, val);
} else {
valid_data->Set(idx, false);
}
}
break;
}
@ -63,8 +83,13 @@ AssembleGroupByValues(
auto field_data = group_by_values_field->mutable_long_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
int64_t val = std::get<int64_t>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
if (group_by_vals[idx].has_value()) {
int64_t val =
std::get<int64_t>(group_by_vals[idx].value());
field_data->mutable_data()->Set(idx, val);
} else {
valid_data->Set(idx, false);
}
}
break;
}
@ -72,17 +97,25 @@ AssembleGroupByValues(
auto field_data = group_by_values_field->mutable_bool_data();
field_data->mutable_data()->Resize(group_by_val_size, 0);
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
bool val = std::get<bool>(group_by_vals[idx]);
field_data->mutable_data()->Set(idx, val);
if (group_by_vals[idx].has_value()) {
bool val = std::get<bool>(group_by_vals[idx].value());
field_data->mutable_data()->Set(idx, val);
} else {
valid_data->Set(idx, false);
}
}
break;
}
case DataType::VARCHAR: {
auto field_data = group_by_values_field->mutable_string_data();
for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
std::string val =
std::move(std::get<std::string>(group_by_vals[idx]));
*(field_data->mutable_data()->Add()) = val;
if (group_by_vals[idx].has_value()) {
std::string val =
std::get<std::string>(group_by_vals[idx].value());
*(field_data->mutable_data()->Add()) = val;
} else {
valid_data->Set(idx, false);
}
}
break;
}
@ -94,11 +127,13 @@ AssembleGroupByValues(
}
}
search_result->mutable_group_by_field_value()->set_type(
auto group_by_field_value =
search_result->mutable_group_by_field_value();
group_by_field_value->set_type(
milvus::proto::schema::DataType(group_by_data_type));
search_result->mutable_group_by_field_value()
->mutable_scalars()
->MergeFrom(*group_by_values_field.get());
group_by_field_value->mutable_valid_data()->MergeFrom(*valid_data);
group_by_field_value->mutable_scalars()->MergeFrom(
*group_by_values_field.get());
return;
}
}
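The assembly now carries a parallel validity mask: valid_data starts all-true and is flipped to false wherever a key was std::nullopt. A minimal sketch of the same pattern for the LONG case, using a bare protobuf RepeatedField (the function and parameter names are illustrative):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

#include <google/protobuf/repeated_field.h>

// Pack optional int64 keys into a value array plus a validity mask -- the
// same shape AssembleGroupByValues emits for DataType::INT64.
void PackLongKeys(const std::vector<std::optional<int64_t>>& keys,
                  google::protobuf::RepeatedField<int64_t>* values,
                  google::protobuf::RepeatedField<bool>* valid_data) {
    const int n = static_cast<int>(keys.size());
    values->Resize(n, 0);         // a slot for every row; null slots stay 0
    valid_data->Resize(n, true);  // assume valid until proven otherwise
    for (int i = 0; i < n; ++i) {
        if (keys[i].has_value()) {
            values->Set(i, *keys[i]);
        } else {
            valid_data->Set(i, false);  // mark the null group key
        }
    }
}
```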

View File

@ -64,16 +64,15 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan,
plan->schema_.get_dynamic_field_id().value() == field_id &&
!plan->target_dynamic_fields_.empty()) {
auto& target_dynamic_fields = plan->target_dynamic_fields_;
field_data = std::move(bulk_subscript(field_id,
results.seg_offsets_.data(),
size,
target_dynamic_fields));
field_data = bulk_subscript(field_id,
results.seg_offsets_.data(),
size,
target_dynamic_fields);
} else if (!is_field_exist(field_id)) {
field_data =
std::move(bulk_subscript_not_exist_field(field_meta, size));
field_data = bulk_subscript_not_exist_field(field_meta, size);
} else {
field_data = std::move(
bulk_subscript(field_id, results.seg_offsets_.data(), size));
field_data =
bulk_subscript(field_id, results.seg_offsets_.data(), size);
}
results.output_fields_data_[field_id] = std::move(field_data);
}
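The dropped std::move calls were redundant: bulk_subscript returns by value, so its result is already an rvalue, and the extra cast at best does nothing (clang-tidy flags the pattern as performance-move-const-arg). A minimal illustration with a hypothetical producer:

```cpp
#include <utility>
#include <vector>

// Hypothetical by-value producer standing in for bulk_subscript's result.
std::vector<int> MakeRows() { return std::vector<int>(1024, 0); }

std::vector<int> Caller() {
    auto a = std::move(MakeRows());  // redundant: the call is already an rvalue
    auto b = MakeRows();             // equivalent and cleaner
    a.swap(b);
    // In a return statement the habit is actively harmful: `return
    // std::move(b);` would disable NRVO and force a move that copy
    // elision would otherwise remove entirely.
    return b;
}
```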

View File

@ -251,7 +251,8 @@ class CachedSearchIteratorTest
size_t offset = 0;
for (size_t i = 0; i < num_chunks_; ++i) {
const size_t rows = std::min(nb_ - offset, kSizePerChunk);
const size_t rows =
std::min(static_cast<size_t>(nb_ - offset), kSizePerChunk);
const size_t buf_size = rows * dim_ * sizeof(float);
auto& chunk_data = column_data_[i];
chunk_data.resize(buf_size);
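std::min deduces a single template parameter from both arguments, so if nb_ - offset and kSizePerChunk come out as different integer types the call fails to compile; the explicit cast pins both sides to size_t. In isolation:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

int main() {
    int64_t remaining = 1000;
    const size_t kSizePerChunk = 128;
    // std::min(remaining, kSizePerChunk) does not compile: the single
    // template parameter T cannot deduce as both int64_t and size_t.
    size_t rows = std::min(static_cast<size_t>(remaining), kSizePerChunk);
    (void)rows;
    return 0;
}
```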

View File

@ -1408,7 +1408,7 @@ TEST_P(ExprTest, TestUnaryRangeJson) {
{
struct Testcase {
int64_t val;
double val;
std::vector<std::string> nested_path;
};
std::vector<Testcase> testcases{{1.1, {"double"}},
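The member type change matches the literals: list-initializing an int64_t member from 1.1 is a narrowing conversion, which brace initialization rejects. A minimal reproduction (struct names are illustrative):

```cpp
#include <cstdint>

struct NarrowCase { int64_t val; };
struct WideCase   { double  val; };

int main() {
    // NarrowCase n{1.1};  // ill-formed: narrowing double -> int64_t
    WideCase w{1.1};       // fine once the member is double
    (void)w;
    return 0;
}
```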

View File

@ -104,7 +104,7 @@ TEST(GroupBY, SealedIndex) {
//2. load raw data
auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false);
auto fields = schema->get_fields();
for (auto field_data : raw_data.raw_->fields_data()) {
for (const auto& field_data : raw_data.raw_->fields_data()) {
int64_t field_id = field_data.field_id();
auto info = FieldDataInfo(field_data.field_id(), N);
@ -167,8 +167,8 @@ TEST(GroupBY, SealedIndex) {
std::unordered_map<int8_t, int> i8_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int8_t>(group_by_values[i])) {
int8_t g_val = std::get<int8_t>(group_by_values[i]);
if (std::holds_alternative<int8_t>(group_by_values[i].value())) {
int8_t g_val = std::get<int8_t>(group_by_values[i].value());
i8_map[g_val] += 1;
ASSERT_TRUE(i8_map[g_val] <= group_size);
//for every group, the number of hits should not exceed group_size
@ -220,8 +220,8 @@ TEST(GroupBY, SealedIndex) {
std::unordered_map<int16_t, int> i16_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int16_t>(group_by_values[i])) {
int16_t g_val = std::get<int16_t>(group_by_values[i]);
if (std::holds_alternative<int16_t>(group_by_values[i].value())) {
int16_t g_val = std::get<int16_t>(group_by_values[i].value());
i16_map[g_val] += 1;
ASSERT_TRUE(i16_map[g_val] <= group_size);
auto distance = search_result->distances_.at(i);
@ -270,8 +270,8 @@ TEST(GroupBY, SealedIndex) {
std::unordered_map<int32_t, int> i32_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int32_t>(group_by_values[i])) {
int16_t g_val = std::get<int32_t>(group_by_values[i]);
if (std::holds_alternative<int32_t>(group_by_values[i].value())) {
int32_t g_val = std::get<int32_t>(group_by_values[i].value());
i32_map[g_val] += 1;
ASSERT_TRUE(i32_map[g_val] <= group_size);
auto distance = search_result->distances_.at(i);
@ -320,8 +320,8 @@ TEST(GroupBY, SealedIndex) {
std::unordered_map<int64_t, int> i64_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int64_t>(group_by_values[i])) {
int16_t g_val = std::get<int64_t>(group_by_values[i]);
if (std::holds_alternative<int64_t>(group_by_values[i].value())) {
int64_t g_val = std::get<int64_t>(group_by_values[i].value());
i64_map[g_val] += 1;
ASSERT_TRUE(i64_map[g_val] <= group_size);
auto distance = search_result->distances_.at(i);
@ -368,9 +368,10 @@ TEST(GroupBY, SealedIndex) {
std::unordered_map<std::string, int> strs_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<std::string>(group_by_values[i])) {
std::string g_val =
std::move(std::get<std::string>(group_by_values[i]));
if (std::holds_alternative<std::string>(
group_by_values[i].value())) {
std::string g_val = std::move(
std::get<std::string>(group_by_values[i].value()));
strs_map[g_val] += 1;
ASSERT_TRUE(strs_map[g_val] <= group_size);
auto distance = search_result->distances_.at(i);
@ -420,8 +421,8 @@ TEST(GroupBY, SealedIndex) {
std::unordered_map<bool, int> bools_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<bool>(group_by_values[i])) {
bool g_val = std::get<bool>(group_by_values[i]);
if (std::holds_alternative<bool>(group_by_values[i].value())) {
bool g_val = std::get<bool>(group_by_values[i].value());
bools_map[g_val] += 1;
ASSERT_TRUE(bools_map[g_val] <= group_size);
auto distance = search_result->distances_.at(i);
@ -445,7 +446,7 @@ TEST(GroupBY, SealedData) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
auto int8_fid = schema->AddDebugField("int8", DataType::INT8, true);
auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
@ -456,9 +457,9 @@ TEST(GroupBY, SealedData) {
size_t N = 100;
//2. load raw data
auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false);
auto raw_data = DataGen(schema, N, 42, 0, 20, 10, false, false);
auto fields = schema->get_fields();
for (auto field_data : raw_data.raw_->fields_data()) {
for (auto&& field_data : raw_data.raw_->fields_data()) {
int64_t field_id = field_data.field_id();
auto info = FieldDataInfo(field_data.field_id(), N);
@ -503,26 +504,29 @@ TEST(GroupBY, SealedData) {
auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size();
//as the repeated is 8, so there will be 13 groups and enough 10 * 5 = 50 results
ASSERT_EQ(50, size);
// groups are: (0, 1, 2, 3, 4, null); counts are: (10, 10, 10, 10, 10, 50)
ASSERT_EQ(30, size);
std::unordered_map<int8_t, int> i8_map;
float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int8_t>(group_by_values[i])) {
int8_t g_val = std::get<int8_t>(group_by_values[i]);
if (group_by_values[i].has_value()) {
int8_t g_val = std::get<int8_t>(group_by_values[i].value());
i8_map[g_val] += 1;
ASSERT_TRUE(i8_map[g_val] <= group_size);
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(
lastDistance <=
distance); //distances should be non-decreasing as metrics_type is L2
lastDistance = distance;
} else {
i8_map[-1] += 1;
}
auto distance = search_result->distances_.at(i);
ASSERT_TRUE(
lastDistance <=
distance); //distances should be non-decreasing as metrics_type is L2
lastDistance = distance;
}
ASSERT_TRUE(i8_map.size() == topK);
ASSERT_EQ(i8_map.size(), 6);
for (const auto& it : i8_map) {
ASSERT_TRUE(it.second == group_size);
ASSERT_TRUE(it.second == group_size)
<< "unexpected count on group " << it.first;
}
}
}
@ -537,7 +541,7 @@ TEST(GroupBY, Reduce) {
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugField(
"fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true);
auto fp16_fid = schema->AddDebugField(
"fakevec_fp16", DataType::VECTOR_FLOAT16, dim, knowhere::metric::L2);
auto bf16_fid = schema->AddDebugField(
@ -572,7 +576,7 @@ TEST(GroupBY, Reduce) {
prepareSegmentSystemFieldData(segment1, N, raw_data1);
//load segment2 raw data
for (auto field_data : raw_data2.raw_->fields_data()) {
for (auto&& field_data : raw_data2.raw_->fields_data()) {
int64_t field_id = field_data.field_id();
auto info = FieldDataInfo(field_data.field_id(), N);
auto field_meta = fields.at(FieldId(field_id));
@ -735,8 +739,8 @@ TEST(GroupBY, GrowingRawData) {
std::unordered_set<int32_t> i32_set;
float lastDistance = 0.0;
for (int j = 0; j < expected_group_count; j++) {
if (std::holds_alternative<int32_t>(group_by_values[idx])) {
int32_t g_val = std::get<int32_t>(group_by_values[idx]);
if (std::holds_alternative<int32_t>(group_by_values[idx].value())) {
int32_t g_val = std::get<int32_t>(group_by_values[idx].value());
ASSERT_FALSE(
i32_set.count(g_val) >
0); //as the group_size is 1, there should be no duplicate group_by values
@ -833,8 +837,8 @@ TEST(GroupBY, GrowingIndex) {
std::unordered_map<int32_t, int> i32_map;
float lastDistance = 0.0;
for (int j = 0; j < expected_group_count * group_size; j++) {
if (std::holds_alternative<int32_t>(group_by_values[idx])) {
int32_t g_val = std::get<int32_t>(group_by_values[idx]);
if (std::holds_alternative<int32_t>(group_by_values[idx].value())) {
int32_t g_val = std::get<int32_t>(group_by_values[idx].value());
i32_map[g_val] += 1;
ASSERT_TRUE(i32_map[g_val] <= group_size);
auto distance = search_result->distances_.at(idx);

View File

@ -439,10 +439,6 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
fields := schema.GetFields()
for _, field := range fields {
if field.Name == groupByFieldName {
if field.GetNullable() {
ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
return ret
}
groupByFieldId = field.FieldID
break
}

View File

@ -2797,28 +2797,6 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Fields: fields,
}
searchInfo := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, searchInfo.planInfo)
assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
normalParam = getValidSearchParams()
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "string_field",
})
fields = make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(102),
Name: "null_field",
Nullable: true,
})
schema = &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo = parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, searchInfo.planInfo)
assert.NoError(t, searchInfo.parseError)
})