enhance: support nullable group by keys (#41313)
See #36264

Signed-off-by: Ted Xu <ted.xu@zilliz.com>
commit d50781c8cc
parent 62293cb582
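The heart of the change is visible in the hunks below: GroupByValueType is wrapped in std::optional, so a NULL key coming from a nullable group-by field travels through search as std::nullopt instead of being rejected up front. A minimal, self-contained sketch of that representation (the alias mirrors the diff; the sample values are purely illustrative):

#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>

// Mirrors the new alias below: the whole variant is wrapped in std::optional,
// so "no value" (a NULL group-by key) is std::nullopt rather than a sentinel.
using GroupByValueType = std::optional<std::variant<std::monostate,
                                                    int8_t,
                                                    int16_t,
                                                    int32_t,
                                                    int64_t,
                                                    bool,
                                                    std::string>>;

int main() {
    GroupByValueType key = std::string("tag-a");  // a present (non-null) key
    GroupByValueType null_key = std::nullopt;     // a key read from a NULL row

    assert(key.has_value());
    assert(std::holds_alternative<std::string>(key.value()));
    assert(!null_key.has_value());
    return 0;
}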
@@ -176,10 +176,10 @@ struct SearchResult {
 chunk_count, total_rows_until_chunk);
 vector_iterators.emplace_back(vector_iterator);
 }
-auto kw_iterator = kw_iterators[i];
+const auto& kw_iterator = kw_iterators[i];
 vector_iterators[vec_iter_idx++]->AddIterator(kw_iterator);
 }
-for (auto vector_iter : vector_iterators) {
+for (const auto& vector_iter : vector_iterators) {
 vector_iter->seal();
 }
 this->vector_iterators_ = vector_iterators;
@@ -263,13 +263,13 @@ CalcPksSize(const PkType* data, size_t n) {
 return size;
 }

-using GroupByValueType = std::variant<std::monostate,
-int8_t,
-int16_t,
-int32_t,
-int64_t,
-bool,
-std::string>;
+using GroupByValueType = std::optional<std::variant<std::monostate,
+int8_t,
+int16_t,
+int32_t,
+int64_t,
+bool,
+std::string>>;
 using ContainsType = proto::plan::JSONContainsExpr_JSONOp;
 using NullExprType = proto::plan::NullExpr_NullOp;
@@ -180,7 +180,7 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
 //2. do iteration until fill the whole map or run out of all data
 //note it may enumerate all data inside a segment and can block following
 //query and search possibly
-std::vector<std::tuple<int64_t, float, T>> res;
+std::vector<std::tuple<int64_t, float, std::optional<T>>> res;
 while (iterator->HasNext() && !groupMap.IsGroupResEnough()) {
 auto offset_dis_pair = iterator->Next();
 AssertInfo(
@@ -189,7 +189,7 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
 "tells hasNext, terminate groupBy operation");
 auto offset = offset_dis_pair.value().first;
 auto dis = offset_dis_pair.value().second;
-T row_data = data_getter.Get(offset);
+std::optional<T> row_data = data_getter.Get(offset);
 if (groupMap.Push(row_data)) {
 res.emplace_back(offset, dis, row_data);
 }
@@ -16,7 +16,9 @@

 #pragma once

+#include <optional>
 #include "common/QueryInfo.h"
 #include "common/Types.h"
 #include "knowhere/index/index_node.h"
 #include "segcore/SegmentInterface.h"
 #include "segcore/SegmentGrowingImpl.h"
@@ -31,25 +33,30 @@ namespace exec {
 template <typename T>
 class DataGetter {
 public:
-virtual T
+virtual std::optional<T>
 Get(int64_t idx) const = 0;
 };

 template <typename T>
 class GrowingDataGetter : public DataGetter<T> {
 public:
-const segcore::ConcurrentVector<T>* growing_raw_data_;
 GrowingDataGetter(const segcore::SegmentGrowingImpl& segment,
 FieldId fieldId) {
 growing_raw_data_ = segment.get_insert_record().get_data<T>(fieldId);
+valid_data_ = segment.get_insert_record().is_valid_data_exist(fieldId)
+? segment.get_insert_record().get_valid_data(fieldId)
+: nullptr;
 }

 GrowingDataGetter(const GrowingDataGetter<T>& other)
 : growing_raw_data_(other.growing_raw_data_) {
 }

-T
+std::optional<T>
 Get(int64_t idx) const {
+if (valid_data_ && !valid_data_->is_valid(idx)) {
+return std::nullopt;
+}
 if constexpr (std::is_same_v<std::string, T>) {
 if (growing_raw_data_->is_mmap()) {
 // when scalar data is mapped, it's needed to get the scalar data view and reconstruct string from the view
@@ -58,6 +65,10 @@ class GrowingDataGetter : public DataGetter<T> {
 }
 return growing_raw_data_->operator[](idx);
 }

+protected:
+const segcore::ConcurrentVector<T>* growing_raw_data_;
+segcore::ThreadSafeValidDataPtr valid_data_;
 };

 template <typename T>
@@ -69,6 +80,7 @@ class SealedDataGetter : public DataGetter<T> {

 mutable std::unordered_map<int64_t, std::vector<std::string_view>>
 str_view_map_;
+mutable std::unordered_map<int64_t, FixedVector<bool>> valid_map_;
 // Getting str_view from segment is cpu-costly, this map is to cache this view for performance
 public:
 SealedDataGetter(const segcore::SegmentSealed& segment, FieldId& field_id)
@@ -84,7 +96,7 @@ class SealedDataGetter : public DataGetter<T> {
 }
 }

-T
+std::optional<T>
 Get(int64_t idx) const {
 if (from_data_) {
 auto id_offset_pair = segment_.get_chunk_by_offset(field_id_, idx);
@@ -92,22 +104,30 @@ class SealedDataGetter : public DataGetter<T> {
 auto inner_offset = id_offset_pair.second;
 if constexpr (std::is_same_v<T, std::string>) {
 if (str_view_map_.find(chunk_id) == str_view_map_.end()) {
-// for now, search_group_by does not handle null values
-auto [str_chunk_view, _] =
+auto [str_chunk_view, valid_data] =
 segment_.chunk_view<std::string_view>(field_id_,
 chunk_id);
+valid_map_[chunk_id] = std::move(valid_data);
 str_view_map_[chunk_id] = std::move(str_chunk_view);
 }
-auto& str_chunk_view = str_view_map_[chunk_id];
-std::string_view str_val_view =
-str_chunk_view.operator[](inner_offset);
+auto valid_data = valid_map_[chunk_id];
+if (!valid_data.empty()) {
+if (!valid_map_[chunk_id][inner_offset]) {
+return std::nullopt;
+}
+}
+auto str_val_view = str_view_map_[chunk_id][inner_offset];
 return std::string(str_val_view.data(), str_val_view.length());
 } else {
 Span<T> span = segment_.chunk_data<T>(field_id_, chunk_id);
+if (span.valid_data() && !span.valid_data()[inner_offset]) {
+return std::nullopt;
+}
 auto raw = span.operator[](inner_offset);
 return raw;
 }
 } else {
+// null is not supported for indexed fields
 auto& chunk_index = segment_.chunk_scalar_index<T>(field_id_, 0);
 auto raw = chunk_index.Reverse_Lookup(idx);
 AssertInfo(raw.has_value(), "field data not found");
@@ -160,7 +180,7 @@ GroupIteratorsByType(
 template <typename T>
 struct GroupByMap {
 private:
-std::unordered_map<T, int> group_map_{};
+std::unordered_map<std::optional<T>, int> group_map_{};
 int group_capacity_{0};
 int group_size_{0};
 int enough_group_count_{0};
@@ -185,7 +205,7 @@ struct GroupByMap {
 return enough;
 }
 bool
-Push(const T& t) {
+Push(const std::optional<T>& t) {
 if (group_map_.size() >= group_capacity_ &&
 group_map_.find(t) == group_map_.end()) {
 return false;
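A note on the GroupByMap change just above: keying the unordered_map on std::optional<T> needs no custom hasher, because <optional> specializes std::hash<std::optional<T>> since C++17 (an engaged optional hashes through std::hash<T>, a disengaged one hashes to an unspecified constant), and operator== treats two disengaged optionals as equal, so all NULL rows collapse into a single group. A small stand-alone sketch of that behaviour (illustrative names, not the Milvus types):

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

int main() {
    // std::hash<std::optional<T>> comes with <optional>, so this just works.
    std::unordered_map<std::optional<int64_t>, int> group_map;

    std::optional<int64_t> rows[] = {10, std::nullopt, 10, std::nullopt, 7};
    for (const auto& key : rows) {
        ++group_map[key];  // every NULL row lands on the single nullopt key
    }

    for (const auto& [key, count] : group_map) {
        std::cout << (key ? std::to_string(*key) : "NULL") << ": " << count << "\n";
    }
    // Expected counts: 10 -> 2, NULL -> 2, 7 -> 1 (iteration order unspecified).
    return 0;
}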
@@ -211,7 +231,6 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
 int64_t group_size,
 bool strict_group_size,
 const DataGetter<T>& data_getter,
 std::vector<GroupByValueType>& group_by_values,
 std::vector<int64_t>& offsets,
 std::vector<float>& distances,
 const knowhere::MetricType& metrics_type);
@@ -14,6 +14,8 @@
 //

 #include "ReduceUtils.h"
+#include "google/protobuf/repeated_field.h"
+#include "pb/schema.pb.h"

 namespace milvus::segcore {

@@ -26,9 +28,12 @@ AssembleGroupByValues(
 if (group_by_field_id.has_value() && group_by_vals.size() > 0) {
 auto group_by_values_field =
 std::make_unique<milvus::proto::schema::ScalarField>();
+auto valid_data =
+std::make_unique<google::protobuf::RepeatedField<bool>>();
+valid_data->Resize(group_by_vals.size(), true);
 auto group_by_field =
 plan->schema_.operator[](group_by_field_id.value());
-DataType group_by_data_type = group_by_field.get_data_type();
+auto group_by_data_type = group_by_field.get_data_type();

 int group_by_val_size = group_by_vals.size();
 switch (group_by_data_type) {
@@ -36,8 +41,13 @@ AssembleGroupByValues(
 auto field_data = group_by_values_field->mutable_int_data();
 field_data->mutable_data()->Resize(group_by_val_size, 0);
 for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
-int8_t val = std::get<int8_t>(group_by_vals[idx]);
-field_data->mutable_data()->Set(idx, val);
+if (group_by_vals[idx].has_value()) {
+int8_t val =
+std::get<int8_t>(group_by_vals[idx].value());
+field_data->mutable_data()->Set(idx, val);
+} else {
+valid_data->Set(idx, false);
+}
 }
 break;
 }
@@ -45,8 +55,13 @@ AssembleGroupByValues(
 auto field_data = group_by_values_field->mutable_int_data();
 field_data->mutable_data()->Resize(group_by_val_size, 0);
 for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
-int16_t val = std::get<int16_t>(group_by_vals[idx]);
-field_data->mutable_data()->Set(idx, val);
+if (group_by_vals[idx].has_value()) {
+int16_t val =
+std::get<int16_t>(group_by_vals[idx].value());
+field_data->mutable_data()->Set(idx, val);
+} else {
+valid_data->Set(idx, false);
+}
 }
 break;
 }
@@ -54,8 +69,13 @@ AssembleGroupByValues(
 auto field_data = group_by_values_field->mutable_int_data();
 field_data->mutable_data()->Resize(group_by_val_size, 0);
 for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
-int32_t val = std::get<int32_t>(group_by_vals[idx]);
-field_data->mutable_data()->Set(idx, val);
+if (group_by_vals[idx].has_value()) {
+int32_t val =
+std::get<int32_t>(group_by_vals[idx].value());
+field_data->mutable_data()->Set(idx, val);
+} else {
+valid_data->Set(idx, false);
+}
 }
 break;
 }
@@ -63,8 +83,13 @@ AssembleGroupByValues(
 auto field_data = group_by_values_field->mutable_long_data();
 field_data->mutable_data()->Resize(group_by_val_size, 0);
 for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
-int64_t val = std::get<int64_t>(group_by_vals[idx]);
-field_data->mutable_data()->Set(idx, val);
+if (group_by_vals[idx].has_value()) {
+int64_t val =
+std::get<int64_t>(group_by_vals[idx].value());
+field_data->mutable_data()->Set(idx, val);
+} else {
+valid_data->Set(idx, false);
+}
 }
 break;
 }
@@ -72,17 +97,25 @@ AssembleGroupByValues(
 auto field_data = group_by_values_field->mutable_bool_data();
 field_data->mutable_data()->Resize(group_by_val_size, 0);
 for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
-bool val = std::get<bool>(group_by_vals[idx]);
-field_data->mutable_data()->Set(idx, val);
+if (group_by_vals[idx].has_value()) {
+bool val = std::get<bool>(group_by_vals[idx].value());
+field_data->mutable_data()->Set(idx, val);
+} else {
+valid_data->Set(idx, false);
+}
 }
 break;
 }
 case DataType::VARCHAR: {
 auto field_data = group_by_values_field->mutable_string_data();
 for (std::size_t idx = 0; idx < group_by_val_size; idx++) {
-std::string val =
-std::move(std::get<std::string>(group_by_vals[idx]));
-*(field_data->mutable_data()->Add()) = val;
+if (group_by_vals[idx].has_value()) {
+std::string val =
+std::get<std::string>(group_by_vals[idx].value());
+*(field_data->mutable_data()->Add()) = val;
+} else {
+valid_data->Set(idx, false);
+}
 }
 break;
 }
@@ -94,11 +127,13 @@ AssembleGroupByValues(
 }
 }

-search_result->mutable_group_by_field_value()->set_type(
+auto group_by_field_value =
+search_result->mutable_group_by_field_value();
+group_by_field_value->set_type(
 milvus::proto::schema::DataType(group_by_data_type));
-search_result->mutable_group_by_field_value()
-->mutable_scalars()
-->MergeFrom(*group_by_values_field.get());
+group_by_field_value->mutable_valid_data()->MergeFrom(*valid_data);
+group_by_field_value->mutable_scalars()->MergeFrom(
+*group_by_values_field.get());
 return;
 }
 }
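On the wire, the hunk above reports NULL keys through a validity array that is merged into group_by_field_value alongside the scalar payload: every result row still occupies one scalar slot (left at its default where the key was NULL), and valid_data[i] == false tells the reader to treat slot i as NULL. A protobuf-free sketch of reading that parallel-array convention back into optionals (the container names here are stand-ins, not the generated proto API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

int main() {
    // Stand-ins for the repeated fields in the search result:
    // one scalar slot per row plus a parallel validity flag (false == NULL key).
    std::vector<int64_t> long_data = {7, 0, 7, 42};  // the 0 is just a default slot
    std::vector<bool> valid_data = {true, false, true, true};

    std::vector<std::optional<int64_t>> keys;
    for (std::size_t i = 0; i < long_data.size(); ++i) {
        keys.push_back(valid_data[i] ? std::optional<int64_t>(long_data[i])
                                     : std::nullopt);
    }

    for (const auto& k : keys) {
        std::cout << (k ? std::to_string(*k) : "NULL") << "\n";
    }
    return 0;
}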
@@ -64,16 +64,15 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan,
 plan->schema_.get_dynamic_field_id().value() == field_id &&
 !plan->target_dynamic_fields_.empty()) {
 auto& target_dynamic_fields = plan->target_dynamic_fields_;
-field_data = std::move(bulk_subscript(field_id,
-results.seg_offsets_.data(),
-size,
-target_dynamic_fields));
+field_data = bulk_subscript(field_id,
+results.seg_offsets_.data(),
+size,
+target_dynamic_fields);
 } else if (!is_field_exist(field_id)) {
-field_data =
-std::move(bulk_subscript_not_exist_field(field_meta, size));
+field_data = bulk_subscript_not_exist_field(field_meta, size);
 } else {
-field_data = std::move(
-bulk_subscript(field_id, results.seg_offsets_.data(), size));
+field_data =
+bulk_subscript(field_id, results.seg_offsets_.data(), size);
 }
 results.output_fields_data_[field_id] = std::move(field_data);
 }
@@ -251,7 +251,8 @@ class CachedSearchIteratorTest

 size_t offset = 0;
 for (size_t i = 0; i < num_chunks_; ++i) {
-const size_t rows = std::min(nb_ - offset, kSizePerChunk);
+const size_t rows =
+std::min(static_cast<size_t>(nb_ - offset), kSizePerChunk);
 const size_t buf_size = rows * dim_ * sizeof(float);
 auto& chunk_data = column_data_[i];
 chunk_data.resize(buf_size);
@@ -1408,7 +1408,7 @@ TEST_P(ExprTest, TestUnaryRangeJson) {

 {
 struct Testcase {
-int64_t val;
+double val;
 std::vector<std::string> nested_path;
 };
 std::vector<Testcase> testcases{{1.1, {"double"}},
@@ -104,7 +104,7 @@ TEST(GroupBY, SealedIndex) {
 //2. load raw data
 auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false);
 auto fields = schema->get_fields();
-for (auto field_data : raw_data.raw_->fields_data()) {
+for (const auto& field_data : raw_data.raw_->fields_data()) {
 int64_t field_id = field_data.field_id();

 auto info = FieldDataInfo(field_data.field_id(), N);
@@ -167,8 +167,8 @@ TEST(GroupBY, SealedIndex) {
 std::unordered_map<int8_t, int> i8_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<int8_t>(group_by_values[i])) {
-int8_t g_val = std::get<int8_t>(group_by_values[i]);
+if (std::holds_alternative<int8_t>(group_by_values[i].value())) {
+int8_t g_val = std::get<int8_t>(group_by_values[i].value());
 i8_map[g_val] += 1;
 ASSERT_TRUE(i8_map[g_val] <= group_size);
 //for every group, the number of hits should not exceed group_size
@@ -220,8 +220,8 @@ TEST(GroupBY, SealedIndex) {
 std::unordered_map<int16_t, int> i16_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<int16_t>(group_by_values[i])) {
-int16_t g_val = std::get<int16_t>(group_by_values[i]);
+if (std::holds_alternative<int16_t>(group_by_values[i].value())) {
+int16_t g_val = std::get<int16_t>(group_by_values[i].value());
 i16_map[g_val] += 1;
 ASSERT_TRUE(i16_map[g_val] <= group_size);
 auto distance = search_result->distances_.at(i);
@@ -270,8 +270,8 @@ TEST(GroupBY, SealedIndex) {
 std::unordered_map<int32_t, int> i32_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<int32_t>(group_by_values[i])) {
-int16_t g_val = std::get<int32_t>(group_by_values[i]);
+if (std::holds_alternative<int32_t>(group_by_values[i].value())) {
+int16_t g_val = std::get<int32_t>(group_by_values[i].value());
 i32_map[g_val] += 1;
 ASSERT_TRUE(i32_map[g_val] <= group_size);
 auto distance = search_result->distances_.at(i);
@@ -320,8 +320,8 @@ TEST(GroupBY, SealedIndex) {
 std::unordered_map<int64_t, int> i64_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<int64_t>(group_by_values[i])) {
-int16_t g_val = std::get<int64_t>(group_by_values[i]);
+if (std::holds_alternative<int64_t>(group_by_values[i].value())) {
+int16_t g_val = std::get<int64_t>(group_by_values[i].value());
 i64_map[g_val] += 1;
 ASSERT_TRUE(i64_map[g_val] <= group_size);
 auto distance = search_result->distances_.at(i);
@@ -368,9 +368,10 @@ TEST(GroupBY, SealedIndex) {
 std::unordered_map<std::string, int> strs_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<std::string>(group_by_values[i])) {
-std::string g_val =
-std::move(std::get<std::string>(group_by_values[i]));
+if (std::holds_alternative<std::string>(
+group_by_values[i].value())) {
+std::string g_val = std::move(
+std::get<std::string>(group_by_values[i].value()));
 strs_map[g_val] += 1;
 ASSERT_TRUE(strs_map[g_val] <= group_size);
 auto distance = search_result->distances_.at(i);
@@ -420,8 +421,8 @@ TEST(GroupBY, SealedIndex) {
 std::unordered_map<bool, int> bools_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<bool>(group_by_values[i])) {
-bool g_val = std::get<bool>(group_by_values[i]);
+if (std::holds_alternative<bool>(group_by_values[i].value())) {
+bool g_val = std::get<bool>(group_by_values[i].value());
 bools_map[g_val] += 1;
 ASSERT_TRUE(bools_map[g_val] <= group_size);
 auto distance = search_result->distances_.at(i);
@@ -445,7 +446,7 @@ TEST(GroupBY, SealedData) {
 auto schema = std::make_shared<Schema>();
 auto vec_fid = schema->AddDebugField(
 "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
-auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
+auto int8_fid = schema->AddDebugField("int8", DataType::INT8, true);
 auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
 auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
 auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
@@ -456,9 +457,9 @@ TEST(GroupBY, SealedData) {
 size_t N = 100;

 //2. load raw data
-auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false);
+auto raw_data = DataGen(schema, N, 42, 0, 20, 10, false, false);
 auto fields = schema->get_fields();
-for (auto field_data : raw_data.raw_->fields_data()) {
+for (auto&& field_data : raw_data.raw_->fields_data()) {
 int64_t field_id = field_data.field_id();

 auto info = FieldDataInfo(field_data.field_id(), N);
@@ -503,26 +504,29 @@ TEST(GroupBY, SealedData) {

 auto& group_by_values = search_result->group_by_values_.value();
 int size = group_by_values.size();
-//as the repeated is 8, so there will be 13 groups and enough 10 * 5 = 50 results
-ASSERT_EQ(50, size);
+// groups are: (0, 1, 2, 3, 4, null), counts are: (10, 10, 10, 10 ,10, 50)
+ASSERT_EQ(30, size);

 std::unordered_map<int8_t, int> i8_map;
 float lastDistance = 0.0;
 for (size_t i = 0; i < size; i++) {
-if (std::holds_alternative<int8_t>(group_by_values[i])) {
-int8_t g_val = std::get<int8_t>(group_by_values[i]);
+if (group_by_values[i].has_value()) {
+int8_t g_val = std::get<int8_t>(group_by_values[i].value());
 i8_map[g_val] += 1;
 ASSERT_TRUE(i8_map[g_val] <= group_size);
-auto distance = search_result->distances_.at(i);
-ASSERT_TRUE(
-lastDistance <=
-distance); //distance should be decreased as metrics_type is L2
-lastDistance = distance;
-}
+} else {
+i8_map[-1] += 1;
+}
+auto distance = search_result->distances_.at(i);
+ASSERT_TRUE(
+lastDistance <=
+distance); //distance should be decreased as metrics_type is L2
+lastDistance = distance;
 }
-ASSERT_TRUE(i8_map.size() == topK);
+ASSERT_EQ(i8_map.size(), 6);

 for (const auto& it : i8_map) {
-ASSERT_TRUE(it.second == group_size);
+ASSERT_TRUE(it.second == group_size)
+<< "unexpected count on group " << it.first;
 }
 }
 }
@@ -537,7 +541,7 @@ TEST(GroupBY, Reduce) {
 auto schema = std::make_shared<Schema>();
 auto vec_fid = schema->AddDebugField(
 "fakevec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2);
-auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
+auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true);
 auto fp16_fid = schema->AddDebugField(
 "fakevec_fp16", DataType::VECTOR_FLOAT16, dim, knowhere::metric::L2);
 auto bf16_fid = schema->AddDebugField(
@@ -572,7 +576,7 @@ TEST(GroupBY, Reduce) {
 prepareSegmentSystemFieldData(segment1, N, raw_data1);

 //load segment2 raw data
-for (auto field_data : raw_data2.raw_->fields_data()) {
+for (auto&& field_data : raw_data2.raw_->fields_data()) {
 int64_t field_id = field_data.field_id();
 auto info = FieldDataInfo(field_data.field_id(), N);
 auto field_meta = fields.at(FieldId(field_id));
@@ -735,8 +739,8 @@ TEST(GroupBY, GrowingRawData) {
 std::unordered_set<int32_t> i32_set;
 float lastDistance = 0.0;
 for (int j = 0; j < expected_group_count; j++) {
-if (std::holds_alternative<int32_t>(group_by_values[idx])) {
-int32_t g_val = std::get<int32_t>(group_by_values[idx]);
+if (std::holds_alternative<int32_t>(group_by_values[idx].value())) {
+int32_t g_val = std::get<int32_t>(group_by_values[idx].value());
 ASSERT_FALSE(
 i32_set.count(g_val) >
 0); //as the group_size is 1, there should not be any duplication for group_by value
@@ -833,8 +837,8 @@ TEST(GroupBY, GrowingIndex) {
 std::unordered_map<int32_t, int> i32_map;
 float lastDistance = 0.0;
 for (int j = 0; j < expected_group_count * group_size; j++) {
-if (std::holds_alternative<int32_t>(group_by_values[idx])) {
-int32_t g_val = std::get<int32_t>(group_by_values[idx]);
+if (std::holds_alternative<int32_t>(group_by_values[idx].value())) {
+int32_t g_val = std::get<int32_t>(group_by_values[idx].value());
 i32_map[g_val] += 1;
 ASSERT_TRUE(i32_map[g_val] <= group_size);
 auto distance = search_result->distances_.at(idx);
@@ -439,10 +439,6 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
 fields := schema.GetFields()
 for _, field := range fields {
 if field.Name == groupByFieldName {
-if field.GetNullable() {
-ret.err = merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
-return ret
-}
 groupByFieldId = field.FieldID
 break
 }
@@ -2797,28 +2797,6 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
 Fields: fields,
 }
 searchInfo := parseSearchInfo(normalParam, schema, nil)
 assert.Nil(t, searchInfo.planInfo)
 assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
-
-normalParam = getValidSearchParams()
-normalParam = append(normalParam, &commonpb.KeyValuePair{
-Key: GroupByFieldKey,
-Value: "string_field",
-})
-fields = make([]*schemapb.FieldSchema, 0)
-fields = append(fields, &schemapb.FieldSchema{
-FieldID: int64(101),
-Name: "string_field",
-})
-fields = append(fields, &schemapb.FieldSchema{
-FieldID: int64(102),
-Name: "null_field",
-Nullable: true,
-})
-schema = &schemapb.CollectionSchema{
-Fields: fields,
-}
-searchInfo = parseSearchInfo(normalParam, schema, nil)
-assert.NotNil(t, searchInfo.planInfo)
-assert.NoError(t, searchInfo.parseError)
 })